import time

import gymnasium as gym
import torch

from dcrl.utils.env_utils import make_minigrid_env, make_miniworld_env
from dcrl.utils.format_utils import list_obs_as_tensor


class Evaluation:
    """Roll out an actor in vectorized MiniGrid/MiniWorld envs and log episode statistics.

    Args:
        env_name: Environment id; must contain "MiniGrid" or "MiniWorld" to select
            the env factory.
        seed: Base seed; env ``i`` is created with ``seed + 1000 * i``.
        num_envs: Number of environments stepped together by ``SyncVectorEnv``.
        view_size: Agent view size, forwarded to the MiniGrid env factory only.
        actor_model: Policy called as ``actor_model(obs, memories)`` and returning
            ``(dist, memories)``. It must expose ``recurrent`` and, when recurrent,
            ``memory_size``.
        device: Torch device for observations, memories and logging tensors.
        episodes: Number of completed episodes collected per ``eval`` call.
        argmax: If True, pick the most probable action; otherwise sample from ``dist``.
    """

    def __init__(
        self,
        env_name,
        seed,
        num_envs,
        view_size,
        actor_model,
        device,
        episodes=100,
        argmax=False,
    ):
        self.env_name = env_name
        self.seed = seed
        self.num_envs = num_envs
        self.view_size = view_size
        self.actor_model = actor_model
        self.device = device
        self.episodes = episodes
        self.argmax = argmax

        # One memory row per env for recurrent actors; feed-forward actors pass None.
        if self.actor_model.recurrent:
            self.memories = torch.zeros((self.num_envs, self.actor_model.memory_size), device=device)
        else:
            self.memories = None

    def make_envs(self):
        """Build a ``SyncVectorEnv`` holding ``num_envs`` copies of ``env_name``.

        Raises:
            ValueError: If ``env_name`` names neither a MiniGrid nor a MiniWorld env.
        """
        if "MiniGrid" in self.env_name:
            env_fns = [
                make_minigrid_env(
                    self.env_name,
                    self.seed + 1000 * i,
                    i,
                    view_size=self.view_size,
                )
                for i in range(self.num_envs)
            ]
        elif "MiniWorld" in self.env_name:
            env_fns = [
                make_miniworld_env(
                    self.env_name,
                    self.seed + 1000 * i,
                    i,
                )
                for i in range(self.num_envs)
            ]
        else:
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError(f"Unrecognized type of env name {self.env_name}")
        return gym.vector.SyncVectorEnv(env_fns)

    def get_actions(self, obss, drop_func=None):
        """Choose one action per env and advance the recurrent memory.

        Args:
            obss: Batched observations from the vector env; converted with
                ``list_obs_as_tensor``.
            drop_func: Optional transform applied to the preprocessed observations.
                Its output's "state" entry must be all zeros (asserted as a sanity check).

        Returns:
            A numpy array of actions, one per env.
        """
        preprocessed_obss = list_obs_as_tensor(obss, device=self.device)
        if drop_func is not None:
            preprocessed_obss = drop_func(preprocessed_obss)
            assert torch.count_nonzero(preprocessed_obss["state"]) == 0

        with torch.no_grad():
            dist, actor_memory = self.actor_model(preprocessed_obss, self.memories)

        if self.actor_model.recurrent:
            # Carry the hidden state forward. Previously it was discarded, so the
            # recurrent policy was always fed the initial all-zero memory.
            self.memories = actor_memory

        if self.argmax:
            # No keepdim, so the shape matches dist.sample() for a batched
            # categorical, i.e. one action per env (assumed (num_envs, n_actions) probs).
            actions = dist.probs.argmax(dim=1)
        else:
            actions = dist.sample()

        return actions.cpu().numpy()

    def update_memories(self, dones):
        """Zero the recurrent memory rows of envs whose episode just ended.

        Args:
            dones: Per-env done flags (bools or 0/1), one per env.
        """
        if self.actor_model.recurrent:
            not_dones = 1 - torch.tensor(dones, dtype=torch.float, device=self.device).unsqueeze(1)
            # Out-of-place so we never mutate a tensor the actor model might still alias.
            self.memories = self.memories * not_dones

    def eval(self, drop_func=None):
        """Run episodes until ``self.episodes`` have finished and return statistics.

        Args:
            drop_func: Forwarded to ``get_actions`` on every step.

        Returns:
            A dict with "return_per_episode" and "num_frames_per_episode" (lists of
            length ``self.episodes``), plus "FPS" and "duration" (whole seconds).

        Note:
            When several envs run in parallel, the first ``episodes`` completions are
            kept. This slightly favours shorter episodes.
        """
        logs = {"num_frames_per_episode": [], "return_per_episode": []}
        log_done_counter = 0
        log_episode_return = torch.zeros(self.num_envs, device=self.device)
        log_episode_num_frames = torch.zeros(self.num_envs, device=self.device)
        # Start every env from a blank recurrent memory.
        self.update_memories([1] * self.num_envs)

        start_time = time.perf_counter()
        envs = self.make_envs()
        try:
            obss, _ = envs.reset()
            while log_done_counter < self.episodes:
                actions = self.get_actions(obss, drop_func)
                obss, rewards, terminateds, truncateds, _ = envs.step(actions)
                dones = terminateds | truncateds
                self.update_memories(dones)

                log_episode_return += torch.tensor(rewards, device=self.device, dtype=torch.float)
                log_episode_num_frames += 1

                for i, done in enumerate(dones):
                    # Several envs may finish in the same step; keep exactly `episodes` results.
                    if done and log_done_counter < self.episodes:
                        log_done_counter += 1
                        logs["return_per_episode"].append(log_episode_return[i].item())
                        logs["num_frames_per_episode"].append(log_episode_num_frames[i].item())

                # Reset the running accumulators of finished envs.
                not_dones = 1 - torch.tensor(dones, device=self.device, dtype=torch.float)
                log_episode_return *= not_dones
                log_episode_num_frames *= not_dones
        finally:
            # Always release the envs, even if the rollout raised.
            envs.close()
        elapsed = time.perf_counter() - start_time

        num_frames = sum(logs["num_frames_per_episode"])
        logs["FPS"] = num_frames / elapsed if elapsed > 0 else 0.0
        logs["duration"] = int(elapsed)
        return logs
