import numpy as np
import torch

from dcrl.utils.format_utils import DictList, batch_tensor_obs_squeeze


class DualAlgo:
    """Auxiliary actor/critic update computed from a fixed-size rollout buffer.

    Experiences are written into ``(num_frames, num_envs, ...)`` buffers with
    :meth:`store_experiences`, flattened (together with discounted Monte-Carlo
    returns and a per-sample flag) by :meth:`collect_experiences`, and consumed
    by :meth:`update_parameters`.  The update only uses samples whose
    Monte-Carlo return exceeds the critic's first value output and whose
    ``pos_rew_flag`` is set.
    """

    def __init__(
        self,
        actor_model,
        critic_model,
        optimizer,
        parameters,
        device,
        gamma,
        max_grad_norm,
        num_envs,
        num_frames,
        dual_epochs,
        dual_coef,
        dual_num_mini_batches,
        max_nlogp,
    ):
        """Allocate the rollout buffers and store the hyper-parameters.

        Args:
            actor_model: Callable as ``actor_model(obs, memory)`` returning
                ``(dist, memory)``; must expose ``recurrent`` and ``memory_size``.
            critic_model: Callable as ``critic_model(obs, memory)`` returning
                ``(value, memory)``; must expose ``recurrent``, ``memory_size``
                and ``value_size``.
            optimizer: Optimizer stepped once per mini-batch.
            parameters: Iterable of the parameters whose gradients are measured
                and clipped.  It is copied into a list, so a one-shot iterator
                such as ``model.parameters()`` is accepted.
            device: Device on which the buffers are allocated.
            gamma: Discount factor used for the Monte-Carlo returns.
            max_grad_norm: Maximum gradient norm passed to ``clip_grad_norm_``.
            num_envs: Number of parallel environments (second buffer dimension).
            num_frames: Number of time steps held by the buffer (first dimension).
            dual_epochs: Number of passes over the buffer per update.
            dual_coef: Multiplier applied to the total loss.
            dual_num_mini_batches: Target number of mini-batches per epoch.
            max_nlogp: Upper clamp applied to the negative log-probability.
        """
        self.actor_model = actor_model
        self.critic_model = critic_model
        self.value_size = self.critic_model.value_size
        self.optimizer = optimizer
        # Materialize once: the parameters are iterated several times per
        # mini-batch, and a generator would be exhausted after the first pass.
        self.parameters = list(parameters)
        self.device = device
        self.gamma = gamma
        self.max_grad_norm = max_grad_norm
        self.num_envs = num_envs
        self.num_frames = num_frames
        self.batch_size = self.num_frames * self.num_envs
        self.max_nlogp = max_nlogp
        self.epochs = dual_epochs
        self.dual_coef = dual_coef
        self.num_mini_batches = dual_num_mini_batches
        self.clip = 1  # bound used when clamping the return/value gap
        self.eps = 1e-11  # avoids division by zero when no sample is valid

        assert self.actor_model.recurrent == self.critic_model.recurrent
        self.model_recurrent = self.actor_model.recurrent

        shape = (self.num_frames, self.num_envs)
        self.obss = [None] * self.num_frames
        self.actor_memories = torch.zeros((*shape, self.actor_model.memory_size), device=self.device)
        self.critic_memories = torch.zeros((*shape, self.critic_model.memory_size), device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.values = torch.zeros((*shape, self.value_size), device=self.device)
        self.returns = torch.zeros((*shape, self.value_size), device=self.device)
        self.not_dones = torch.zeros(*shape, device=self.device)
        self.memory_mask = torch.zeros(*shape, device=self.device)

        # Position along the time axis where the next chunk will be written.
        self._next_idx = 0

    def store_experiences(self, exps):
        """Copy a chunk of ``len(exps)`` time steps into the buffers.

        The chunk is written starting at the current write position, which then
        advances modulo ``num_frames``.

        Args:
            exps: Container exposing ``obs``, ``actor_memory``, ``critic_memory``,
                ``action``, ``reward``, ``value``, ``returnn``, ``not_done`` and
                ``memory_mask``, and supporting ``len()``.
        """
        # NOTE(review): a chunk that would cross the end of the buffer is not
        # split; presumably callers always pass chunk lengths that divide
        # ``num_frames`` -- confirm against the caller.
        start = self._next_idx
        stop = start + len(exps)
        self._next_idx = stop % self.num_frames

        self.obss[start:stop] = exps.obs
        self.actor_memories[start:stop] = exps.actor_memory
        self.critic_memories[start:stop] = exps.critic_memory
        self.actions[start:stop] = exps.action
        self.rewards[start:stop] = exps.reward
        self.values[start:stop] = exps.value
        self.returns[start:stop] = exps.returnn
        self.not_dones[start:stop] = exps.not_done
        self.memory_mask[start:stop] = exps.memory_mask

    def collect_experiences(self):
        """Flatten the buffers into a ``DictList`` ready for the update.

        In addition to the stored fields, two quantities are derived:

        * ``mc_returnn``: discounted return computed backwards over the buffer,
          with the running return multiplied by ``not_dones`` at every step.
        * ``pos_rew_flag``: true where the reward is positive, then spread
          forward in time through steps with non-zero ``memory_mask`` and
          backward in time through steps with non-zero ``not_dones``.

        Returns:
            DictList whose tensors are laid out environment-major (time and
            environment axes are transposed before flattening).
        """
        running_return = 0
        mc_returns = torch.zeros_like(self.rewards)
        for t in reversed(range(self.num_frames)):
            running_return = self.rewards[t] + self.gamma * running_return * self.not_dones[t]
            mc_returns[t] = running_return

        pos_rew_flag = self.rewards > 0
        # Forward pass: inherit the flag from the previous step when linked by memory_mask.
        for t in range(1, self.num_frames):
            carried = torch.logical_and(self.memory_mask[t], pos_rew_flag[t - 1])
            pos_rew_flag[t] = torch.logical_or(pos_rew_flag[t], carried)
        # Backward pass: inherit the flag from the next step when this step is not terminal.
        for t in reversed(range(self.num_frames - 1)):
            carried = torch.logical_and(self.not_dones[t], pos_rew_flag[t + 1])
            pos_rew_flag[t] = torch.logical_or(pos_rew_flag[t], carried)

        exps = DictList()
        exps.obs = batch_tensor_obs_squeeze(self.obss)
        exps.actor_memory = self.actor_memories.transpose(0, 1).reshape(-1, *self.actor_memories.shape[2:])
        exps.critic_memory = self.critic_memories.transpose(0, 1).reshape(-1, *self.critic_memories.shape[2:])
        exps.memory_mask = self.memory_mask.transpose(0, 1).reshape(-1, 1)
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1, self.value_size)
        exps.returnn = self.returns.transpose(0, 1).reshape(-1, self.value_size)
        exps.mc_returnn = mc_returns.transpose(0, 1).reshape(-1)
        exps.pos_rew_flag = pos_rew_flag.transpose(0, 1).reshape(-1)

        return exps

    def update_parameters(self, exps):
        """Run ``dual_epochs`` passes of mini-batch updates over ``exps``.

        Args:
            exps: Indexable container (as produced by
                :meth:`collect_experiences`) providing ``obs``, ``actor_memory``,
                ``critic_memory``, ``memory_mask``, ``action``, ``mc_returnn``
                and ``pos_rew_flag``.

        Returns:
            dict with the keys ``loss``, ``policy_loss``, ``value_loss``,
            ``entropy_loss``, ``grad_norm`` (means over the mini-batches of the
            last epoch) and ``num_valid`` (sum over those mini-batches).  When no
            mini-batch was processed, every entry is zero.
        """
        log_keys = ("loss", "policy_loss", "value_loss", "entropy_loss", "num_valid", "grad_norm")
        # Defined up front so that ``dual_epochs == 0`` still yields a valid log dict.
        epoch_logs = {key: [] for key in log_keys}

        for _ in range(self.epochs):
            # Logs are restarted every epoch: only the last epoch is reported.
            epoch_logs = {key: [] for key in log_keys}

            for inds in self._get_mini_batches_recurrence_starting_indexes():
                sb = exps[inds]
                dist, _ = self.actor_model(sb.obs, sb.actor_memory * sb.memory_mask)
                critic_value, _ = self.critic_model(sb.obs, sb.critic_memory * sb.memory_mask)
                value = critic_value[..., 0]
                gap = sb.mc_returnn - value

                # A sample is valid when the return beats the value estimate
                # and its positive-reward flag is set.
                mask = (gap > 0).to(sb.mc_returnn.dtype) * sb.pos_rew_flag
                num_valid = torch.sum(mask)
                normalizer = num_valid + self.eps

                nlogp = -dist.log_prob(sb.action)
                # Clamp the value while keeping the gradient of the unclamped term.
                nlogp_clipped = nlogp + (torch.clamp(nlogp, max=self.max_nlogp) - nlogp).detach()
                adv = torch.clamp(gap, min=0, max=self.clip).detach()
                policy_loss = (nlogp_clipped * adv * mask).sum() / normalizer

                entropy = (dist.entropy() * mask).sum() / normalizer

                # ``delta`` is a constant weight; the gradient flows only through ``value``.
                delta = (torch.clamp(value - sb.mc_returnn, min=-self.clip, max=0) * mask).detach()
                value_loss = (value * delta * mask).sum() / normalizer

                loss = policy_loss - entropy * 0.01 + value_loss * 0.5 * 0.01
                loss = loss * self.dual_coef

                self.optimizer.zero_grad()
                loss.backward()
                # Parameters that did not take part in the loss have no gradient; skip them.
                grad_norm = (
                    sum(p.grad.data.norm(2).item() ** 2 for p in self.parameters if p.grad is not None) ** 0.5
                )
                torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
                self.optimizer.step()

                epoch_logs["loss"].append(loss.item())
                epoch_logs["policy_loss"].append(policy_loss.item())
                epoch_logs["value_loss"].append(value_loss.item())
                epoch_logs["entropy_loss"].append(entropy.item())
                epoch_logs["num_valid"].append(num_valid.item())
                epoch_logs["grad_norm"].append(grad_norm)

        def _mean(values):
            # np.mean([]) would emit a warning and return nan.
            return np.mean(values) if values else 0.0

        logs = {
            "loss": _mean(epoch_logs["loss"]),
            "policy_loss": _mean(epoch_logs["policy_loss"]),
            "value_loss": _mean(epoch_logs["value_loss"]),
            "entropy_loss": _mean(epoch_logs["entropy_loss"]),
            "num_valid": np.sum(epoch_logs["num_valid"]),
            "grad_norm": _mean(epoch_logs["grad_norm"]),
        }
        return logs

    def _get_mini_batches_recurrence_starting_indexes(self):
        """Return a random partition of ``range(batch_size)`` into index arrays.

        Each array holds ``batch_size // num_mini_batches`` indexes (at least
        one); when the batch size is not an exact multiple, a final shorter
        array holds the remainder.
        """
        shuffled = np.random.permutation(np.arange(0, self.batch_size, 1))

        # Guard against a zero step when there are more mini-batches than samples.
        mini_batch_size = max(1, self.batch_size // self.num_mini_batches)

        return [shuffled[start : start + mini_batch_size] for start in range(0, len(shuffled), mini_batch_size)]
