from abc import ABC, abstractmethod
from collections import deque

import torch

from dcrl.utils.format_utils import (
    DictList,
    batch_tensor_obs_squeeze,
    vec_obs_as_tensor,
)


class BaseAlgo(ABC):
    """Base class for actor-critic algorithms driven by a vectorized environment.

    The class owns fixed-size rollout buffers of shape ``(num_frames, num_envs, ...)``,
    fills them in :meth:`collect_experiences`, and computes GAE advantages with
    :meth:`calc_advantages_gae`. Subclasses implement :meth:`update_parameters`.
    """

    def __init__(
        self,
        envs,
        actor_model,
        critic_model,
        device,
        num_frames,
        gamma,
        lr,
        gae_lambda,
        entropy_coef,
        value_loss_coef,
        max_grad_norm,
        reshape_reward_fn=None,
        reshape_adv_fn=None,
    ):
        """Store hyper-parameters, move the models to ``device`` and allocate rollout buffers.

        Args:
            envs: Vectorized environment exposing ``num_envs``, ``reset()`` (returning an
                ``(obs, info)`` pair) and ``step(actions)`` (returning a 5-tuple
                ``(obs, reward, terminated, truncated, info)``).
            actor_model: Callable ``(obs, memory) -> (dist, memory)`` exposing ``recurrent``
                and ``memory_size``.
            critic_model: Callable ``(obs, memory) -> (value, memory)`` exposing ``recurrent``,
                ``memory_size`` and ``value_size``.
            device: Torch device on which models and buffers live.
            num_frames: Number of environment steps collected per env in each rollout.
            gamma: Discount factor passed to the GAE computation.
            lr: Learning rate; only stored here (presumably consumed by subclasses).
            gae_lambda: GAE lambda passed to the GAE computation.
            entropy_coef: Only stored here (presumably used in the subclass loss).
            value_loss_coef: Only stored here (presumably used in the subclass loss).
            max_grad_norm: Only stored here (presumably used for gradient clipping).
            reshape_reward_fn: Optional function applied to the per-step reward tensor
                before it is stored. Defaults to the identity.
            reshape_adv_fn: Optional function applied to the flattened advantages of shape
                ``(batch, value_size)``. Defaults to selecting component 0 of the last axis.

        Raises:
            AssertionError: If the actor and critic disagree on ``recurrent``.
        """
        self.vec_env = envs
        self.actor_model = actor_model
        self.critic_model = critic_model
        self.value_size = self.critic_model.value_size
        self.device = device
        self.num_frames = num_frames
        self.gamma = gamma
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        if reshape_reward_fn is None:
            reshape_reward_fn = lambda reward: reward
        self.reshape_reward_fn = reshape_reward_fn
        if reshape_adv_fn is None:
            reshape_adv_fn = lambda adv: adv[..., 0]
        self.reshape_adv_fn = reshape_adv_fn

        assert (
            self.actor_model.recurrent == self.critic_model.recurrent
        ), "actor_model and critic_model must agree on `recurrent`"
        self.model_recurrent = self.actor_model.recurrent

        self.actor_model.to(self.device)
        self.actor_model.train()
        self.critic_model.to(self.device)
        self.critic_model.train()

        self.num_envs = envs.num_envs
        self.batch_size = self.num_frames * self.num_envs

        # Rollout buffers are laid out time-major: (num_frames, num_envs, ...).
        shape = (self.num_frames, self.num_envs)
        self.obs, _ = self.vec_env.reset()
        self.obss = [None] * self.num_frames
        # `last_*` hold the state that feeds the NEXT step (mask / recurrent memory).
        self.last_not_done = torch.ones(self.num_envs, device=self.device)
        self.not_dones = torch.zeros(*shape, device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)
        self.values = torch.zeros((*shape, self.value_size), device=self.device)
        self.advantages = torch.zeros((*shape, self.value_size), device=self.device)
        self.last_actor_memory = torch.zeros((self.num_envs, self.actor_model.memory_size), device=self.device)
        self.actor_memories = torch.zeros((*shape, self.actor_model.memory_size), device=self.device)
        self.last_critic_memory = torch.zeros((self.num_envs, self.critic_model.memory_size), device=self.device)
        self.critic_memories = torch.zeros((*shape, self.critic_model.memory_size), device=self.device)

        # Running per-env accumulators for the episode currently in progress.
        self.log_episode_return = torch.zeros(self.num_envs, device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_envs, device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_envs, device=self.device)

        # Finished-episode statistics. The lists start with `num_envs` zeros so that the
        # `[-keep:]` slices taken in `collect_experiences` always have at least that many entries.
        self.log_done_counter = 0
        self.log_return = [0] * self.num_envs
        self.log_reshaped_return = [0] * self.num_envs
        self.log_num_frames = [0] * self.num_envs
        self.log_avg_return = deque([0] * 100, maxlen=100)

    def collect_experiences(self, use_dual=False, drop_func=None):
        """Run ``num_frames`` steps in every env and package the rollout.

        Args:
            use_dual: When true, additionally return a time-major view of the rollout.
            drop_func: Optional function applied to each observation tensor before it is
                fed to the models and stored.

        Returns:
            A tuple ``(exps, logs, dual_exps)``:

            * ``exps``: ``DictList`` whose tensor fields are flattened env-major
              (all frames of env 0, then env 1, ...), obtained via ``transpose(0, 1)``
              followed by ``reshape``.
            * ``logs``: dict of episode statistics plus ``"num_frames"``. Note that
              ``"avg_return"`` is the live ``deque`` owned by this object, not a copy.
            * ``dual_exps``: ``None`` unless ``use_dual``; otherwise a ``DictList`` of
              time-major fields. Most of its fields reference this object's internal
              buffers directly, and those buffers are overwritten in place by the next
              call to this method.
        """
        for i in range(self.num_frames):
            tensor_obs = vec_obs_as_tensor(self.obs, device=self.device)
            if drop_func is not None:
                tensor_obs = drop_func(tensor_obs)
            with torch.no_grad():
                dist, actor_memory = self.actor_model(tensor_obs, self.last_actor_memory)
                value, critic_memory = self.critic_model(tensor_obs, self.last_critic_memory)
            action = dist.sample()

            obs, reward, terminated, truncated, _ = self.vec_env.step(action.cpu().numpy())
            # Termination and truncation are treated identically as "episode ended".
            done = terminated | truncated
            not_done = 1 - torch.tensor(done, device=self.device, dtype=torch.float)
            reward = torch.tensor(reward, device=self.device, dtype=torch.float)

            # Store the inputs that produced this step's action/value. `not_dones[i]` is the
            # mask from the PREVIOUS step, i.e. whether the memory fed at step i was carried over.
            self.obss[i] = tensor_obs
            self.actor_memories[i] = self.last_actor_memory
            self.critic_memories[i] = self.last_critic_memory
            self.not_dones[i] = self.last_not_done
            self.actions[i] = action
            self.values[i] = value
            self.rewards[i] = self.reshape_reward_fn(reward)
            self.log_probs[i] = dist.log_prob(action)

            # Advance the carried state; recurrent memory is zeroed for envs that just finished.
            self.obs = obs
            self.last_not_done = not_done
            self.last_actor_memory = actor_memory * self.last_not_done.unsqueeze(-1)
            self.last_critic_memory = critic_memory * self.last_not_done.unsqueeze(-1)

            self.log_episode_return += reward
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += 1

            for j, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[j].item())
                    self.log_reshaped_return.append(self.log_episode_reshaped_return[j].item())
                    self.log_num_frames.append(self.log_episode_num_frames[j].item())
                    self.log_avg_return.append(self.log_episode_return[j].item())

            # Reset the running accumulators of envs whose episode just ended.
            self.log_episode_return *= self.last_not_done
            self.log_episode_reshaped_return *= self.last_not_done
            self.log_episode_num_frames *= self.last_not_done

        # Bootstrap value for the observation that follows the final collected step.
        tensor_obs = vec_obs_as_tensor(self.obs, device=self.device)
        if drop_func is not None:
            tensor_obs = drop_func(tensor_obs)
        with torch.no_grad():
            last_value, _ = self.critic_model(tensor_obs, self.last_critic_memory)

        self.advantages = self.calc_advantages_gae(
            self.rewards,
            self.not_dones,
            self.values,
            self.last_not_done,
            last_value,
            self.gamma,
            self.gae_lambda,
        )
        self.returnns = self.values + self.advantages

        # Flatten (num_frames, num_envs, ...) -> (num_envs * num_frames, ...), env-major.
        exps = DictList()
        exps.obs = batch_tensor_obs_squeeze(self.obss)
        exps.actor_memory = self.actor_memories.transpose(0, 1).reshape(-1, *self.actor_memories.shape[2:])
        exps.critic_memory = self.critic_memories.transpose(0, 1).reshape(-1, *self.critic_memories.shape[2:])
        exps.memory_mask = self.not_dones.transpose(0, 1).reshape(-1, 1)
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1, self.value_size)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1, self.value_size)
        exps.advantage = self.reshape_adv_fn(exps.advantage)
        exps.returnn = self.returnns.transpose(0, 1).reshape(-1, self.value_size)

        dual_exps = None
        if use_dual:
            # Time-major view. These fields alias the internal buffers (no copy is made).
            dual_exps = DictList()
            dual_exps.obs = self.obss
            dual_exps.actor_memory = self.actor_memories
            dual_exps.critic_memory = self.critic_memories
            dual_exps.action = self.actions
            dual_exps.reward = self.rewards
            dual_exps.value = self.values
            dual_exps.returnn = self.returnns
            # Shift the mask by one step so entry i says whether step i itself ended the episode.
            dual_exps.not_done = torch.cat((self.not_dones[1:], self.last_not_done.unsqueeze(0)), dim=0)
            dual_exps.memory_mask = self.not_dones

        # Report every episode finished during this rollout, but never fewer than `num_envs` entries.
        keep = max(self.log_done_counter, self.num_envs)
        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.batch_size,
            "avg_return": self.log_avg_return,
        }

        # Keep only the most recent `num_envs` entries so the lists do not grow without bound.
        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_envs :]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_envs :]
        self.log_num_frames = self.log_num_frames[-self.num_envs :]

        return exps, logs, dual_exps

    @abstractmethod
    def update_parameters(self, exps):
        """Update the models from collected experiences; implemented by subclasses."""
        pass

    @staticmethod
    def calc_advantages_gae(rewards, not_dones, values, last_not_done, last_value, gamma=1.0, gae_lambda=1.0):
        """Compute Generalized Advantage Estimation over a time-major rollout.

        For each step ``i`` (iterating backwards)::

            delta_i = rewards[i] + gamma * V[i + 1] * mask[i + 1] - V[i]
            A_i     = delta_i + gamma * gae_lambda * A[i + 1] * mask[i + 1]

        where ``mask[i + 1]`` is ``not_dones[i + 1]`` (or ``last_not_done`` after the final
        step) and ``V`` past the final step is ``last_value``.

        Args:
            rewards: Tensor whose first axis is time. May have one axis fewer than
                ``values`` (one scalar reward shared by all value components).
            not_dones: Tensor whose first axis is time; ``not_dones[i]`` is the mask in
                effect when step ``i`` began. May have one axis fewer than ``values``.
            values: Value estimates; the result has the same shape.
            last_not_done: Mask following the final step (``not_dones`` without the time axis).
            last_value: Bootstrap value following the final step.
            gamma: Discount factor.
            gae_lambda: GAE lambda.

        Returns:
            Tensor of advantages with the same shape, dtype and device as ``values``.
        """
        # Give every input that lacks the trailing value axis a broadcastable one. Each
        # input is checked on its own so that per-component rewards combined with
        # per-env masks still broadcast over the value axis instead of over envs.
        value_rank = values.dim()
        if rewards.dim() < value_rank:
            rewards = rewards.unsqueeze(-1)
        if not_dones.dim() < value_rank:
            not_dones = not_dones.unsqueeze(-1)
        if last_not_done.dim() < value_rank - 1:
            last_not_done = last_not_done.unsqueeze(-1)

        advantages = torch.zeros_like(values)
        num_frames = rewards.shape[0]

        # Quantities belonging to step i + 1, carried backwards through time.
        next_value = last_value
        next_not_done = last_not_done
        next_advantage = 0
        for i in reversed(range(num_frames)):
            delta = rewards[i] + gamma * next_value * next_not_done - values[i]
            advantages[i] = delta + gamma * gae_lambda * next_advantage * next_not_done
            next_value = values[i]
            next_not_done = not_dones[i]
            next_advantage = advantages[i]
        return advantages
