from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


class OneStepSimpleTrainer(TorchTrainer):
    """
    One-step policy improvement against fixed, pre-trained data models.

    Only ``policy`` is optimized. Each step it is updated to maximize
    ``qf_data(s, a)`` for one reparameterized action ``a ~ policy(s)`` per
    state. ``qf_data`` and ``policy_data`` are used as-is and are never
    stepped by this trainer.

    With ``kl_reg=True`` the loss adds a Monte Carlo estimate of
    KL(policy || policy_data), weighted by ``alpha``. The estimate uses
    ``n_actions`` fresh samples per state:

        loss = alpha * mean_s mean_{a~pi}[log pi(a|s) - log pi_data(a|s)]
               - mean_s qf_data(s, a)

    With ``kl_reg=False`` the loss is ``-mean_s qf_data(s, a)``.
    """
    def __init__(
            self,
            env,
            policy,
            policy_data,
            qf_data,

            kl_reg=True,
            n_actions=10,
            alpha=1.0,
            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-4,
            optimizer_class=optim.Adam,
    ):
        """
        Args:
            env: Environment handle. Stored, but not used by the update below.
            policy: Policy being trained. Called as
                ``policy(obs, reparameterize=True, return_log_prob=True)``.
                Its return is unpacked as
                ``(actions, mean, log_std, log_prob, *rest)``.
            policy_data: Reference policy exposing ``log_prob(obs, actions)``.
                Only queried when ``kl_reg`` is True.
            qf_data: Q-function called as ``qf_data(obs, actions)``.
            kl_reg: Whether to add the KL penalty towards ``policy_data``.
            n_actions: Number of policy samples per state used for the KL
                estimate.
            alpha: Weight of the KL penalty.
            discount: Stored for interface compatibility; unused here.
            reward_scale: Stored for interface compatibility; unused here.
            policy_lr: Learning rate of the policy optimizer.
            optimizer_class: Optimizer constructor, called as
                ``optimizer_class(policy.parameters(), lr=policy_lr)``.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.policy_data = policy_data
        self.qf_data = qf_data

        self.kl_reg = kl_reg
        print('self.kl_reg: \t', self.kl_reg)

        self.n_actions = n_actions
        self.alpha = alpha

        # Not used by train_from_torch; kept so external code reading the
        # attribute keeps working.
        self.qf_criterion = nn.MSELoss()

        # Remembered so the optimizer can be rebuilt when set_snapshot()
        # swaps in a different policy module.
        self._optimizer_class = optimizer_class
        self._policy_lr = policy_lr
        self.policy_optimizer = self._make_policy_optimizer()

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.discrete = False

    def _make_policy_optimizer(self):
        """Build an optimizer over the *current* ``self.policy`` parameters."""
        return self._optimizer_class(
            self.policy.parameters(),
            lr=self._policy_lr,
        )

    def _estimate_kl(self, obs):
        """Monte Carlo estimate of KL(policy || policy_data), averaged over states.

        Draws ``n_actions`` reparameterized actions per state from ``policy``.
        It averages ``log pi(a|s) - log pi_data(a|s)`` first over the samples
        of each state and then over the batch. Gradients flow through the
        sampled actions.
        """
        # Tile each observation n_actions times consecutively along the batch
        # axis (o0, o0, ..., o1, o1, ...). Reshaping per-sample outputs to
        # (-1, n_actions) then puts one state's samples in one row. For 2-D
        # obs this equals unsqueeze(1).repeat(1, n, 1).reshape(-1, obs_dim),
        # and it also works for higher-rank observations.
        obs_stack = obs.repeat_interleave(self.n_actions, dim=0)
        actions_stack, _, _, log_pi_stack, *_ = self.policy(
            obs_stack, reparameterize=True, return_log_prob=True,
        )
        log_pi = log_pi_stack.reshape((-1, self.n_actions)).mean(dim=1)

        # Assumes policy_data.log_prob returns one value per (obs, action) row,
        # in any shape that reshapes to (-1, n_actions) -- TODO confirm.
        log_pi_data_stack = self.policy_data.log_prob(obs_stack, actions_stack)
        log_pi_data = log_pi_data_stack.reshape((-1, self.n_actions)).mean(dim=1)

        return (log_pi - log_pi_data).mean()

    def train_from_torch(self, batch):
        """Take one gradient step on ``policy`` using ``batch['observations']``."""
        obs = batch['observations']

        # Policy loss. One reparameterized action per state feeds the Q term.
        new_obs_actions, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.kl_reg:
            kl = self._estimate_kl(obs)
            policy_loss = self.alpha * kl - self.qf_data(obs, new_obs_actions).mean()
        else:
            policy_loss = -self.qf_data(obs, new_obs_actions).mean()

        # Update the policy only. NOTE: backward() also accumulates .grad on
        # qf_data / policy_data parameters if they require grad. Those grads
        # are neither used nor cleared here.
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Record statistics from a single batch: the flag starts True and is
        # re-armed by end_epoch(), so this runs once per epoch.
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
            if self.kl_reg:
                self.eval_statistics['KL'] = np.mean(ptu.get_numpy(kl))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics recorded by the most recent logged batch."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm statistics collection for the next training batch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All modules this trainer holds (trained policy plus fixed data models)."""
        return [
            self.policy,
            self.policy_data,
            self.qf_data,
        ]

    def get_snapshot(self):
        """Return the held networks keyed by name."""
        return dict(
            policy=self.policy,
            policy_data=self.policy_data,
            qf_data=self.qf_data,
        )

    def set_snapshot(self, snapshot):
        """Restore networks from a dict in the format of ``get_snapshot()``.

        If the snapshot carries a different policy object, the policy
        optimizer is rebuilt. Otherwise it would keep stepping the replaced
        module's parameters and the restored policy would never be updated.
        Optimizer state (e.g. Adam moments) is not carried over in that case.
        """
        new_policy = snapshot['policy']
        if new_policy is not self.policy:
            self.policy = new_policy
            self.policy_optimizer = self._make_policy_optimizer()
        self.policy_data = snapshot['policy_data']
        self.qf_data = snapshot['qf_data']
