import itertools

import numpy as np
import torch

from dcrl.algos.base import BaseAlgo


class PPOAlgo(BaseAlgo):
    """PPO (clipped surrogate objective) trainer driving separate actor and critic models.

    Both models are optimized jointly by a single Adam optimizer on the loss
    ``policy_loss - entropy_coef * entropy + value_loss_coef * value_loss``.
    """

    def __init__(
        self,
        envs,
        actor_model,
        critic_model,
        device,
        num_frames=512,
        gamma=0.99,
        lr=3e-4,
        gae_lambda=0.95,
        entropy_coef=0.01,
        value_loss_coef=0.5,
        max_grad_norm=0.5,
        reshape_reward_fn=None,
        reshape_adv_fn=None,
        adam_eps=1e-5,
        clip_eps=0.2,
        epochs=4,
        num_mini_batches=16,
    ):
        """Set up the PPO hyper-parameters and the shared optimizer.

        All arguments from ``envs`` through ``reshape_adv_fn`` are forwarded
        positionally, unchanged, to ``BaseAlgo.__init__``.

        Args:
            lr: Learning rate; forwarded to the base class and also used for Adam.
            adam_eps: ``eps`` passed to ``torch.optim.Adam``.
            clip_eps: Half-width of the interval ``[1 - clip_eps, 1 + clip_eps]``
                the probability ratio is clamped to.
            epochs: Number of passes over the experience batch per update.
            num_mini_batches: Number of mini-batches the batch is split into
                on every epoch.
        """
        super().__init__(
            envs,
            actor_model,
            critic_model,
            device,
            num_frames,
            gamma,
            lr,
            gae_lambda,
            entropy_coef,
            value_loss_coef,
            max_grad_norm,
            reshape_reward_fn,
            reshape_adv_fn,
        )

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.num_mini_batches = num_mini_batches

        # Materialize the parameters into a list: an itertools.chain object is a
        # single-pass iterator that the optimizer would exhaust, leaving nothing
        # for gradient clipping to iterate over during updates.
        self.parameters = list(
            itertools.chain(*(model.parameters() for model in [self.actor_model, self.critic_model]))
        )
        self.optimizer = torch.optim.Adam(self.parameters, lr, eps=adam_eps)

    def update_parameters(self, exps):
        """Run ``self.epochs`` epochs of mini-batch PPO updates on ``exps``.

        Args:
            exps: Experience container. It is indexed with an array of indices
                and the result must expose ``obs``, ``action``, ``log_prob``,
                ``advantage``, ``returnn``, ``actor_memory``, ``critic_memory``
                and ``memory_mask``.

        Returns:
            dict with keys ``entropy``, ``value``, ``policy_loss``,
            ``value_loss`` and ``grad_norm``, each the mean over the
            mini-batches of the final epoch.

        Note:
            ``self.epochs`` must be at least 1, otherwise no statistics exist
            to report.
        """
        for _ in range(self.epochs):
            # The log lists are reset at the start of every epoch, so the
            # statistics returned below describe the last epoch only.
            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []

            for inds in self._get_mini_batches_recurrence_starting_indexes():
                sb = exps[inds]

                # Stored memories are multiplied by the mask before being fed
                # back in; the memories returned by the models are not used.
                dist, _ = self.actor_model(sb.obs, sb.actor_memory * sb.memory_mask)
                value, _ = self.critic_model(sb.obs, sb.critic_memory * sb.memory_mask)

                entropy = dist.entropy().mean()

                # Clipped surrogate objective: take the more pessimistic of the
                # unclipped and clipped ratio-weighted advantages.
                ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                surr1 = ratio * sb.advantage
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * sb.advantage
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = (value - sb.returnn).pow(2).sum(dim=-1).mean()

                loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                # clip_grad_norm_ returns the total gradient norm measured
                # before clipping and ignores parameters without a gradient,
                # so it doubles as the logged value.
                grad_norm = float(torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm))
                self.optimizer.step()

                log_entropies.append(entropy.item())
                log_values.append(value.mean().item())
                log_policy_losses.append(policy_loss.item())
                log_value_losses.append(value_loss.item())
                log_grad_norms.append(grad_norm)

        logs = {
            "entropy": np.mean(log_entropies),
            "value": np.mean(log_values),
            "policy_loss": np.mean(log_policy_losses),
            "value_loss": np.mean(log_value_losses),
            "grad_norm": np.mean(log_grad_norms),
        }
        return logs

    def _get_mini_batches_recurrence_starting_indexes(self):
        """Return a random partition of ``range(self.batch_size)`` into mini-batches.

        The indices are shuffled and cut into consecutive chunks of
        ``self.batch_size // self.num_mini_batches`` elements (at least one).
        When the batch size is not a multiple of that chunk size, a shorter
        trailing chunk holds the remainder.

        Returns:
            list of 1-D integer numpy arrays, one per mini-batch.
        """
        starting_indexes = np.arange(0, self.batch_size, 1)
        starting_indexes = np.random.permutation(starting_indexes)

        # Clamp to 1 so that a batch smaller than num_mini_batches does not
        # produce a zero slice step (range() rejects a step of 0).
        mini_batch_size = max(1, self.batch_size // self.num_mini_batches)

        return [
            starting_indexes[start : start + mini_batch_size]
            for start in range(0, len(starting_indexes), mini_batch_size)
        ]
