from collections import OrderedDict

import os
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import torch.nn.functional as F

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from rlkit.samplers.data_collector.path_collector import MdpPathCollector

import matplotlib.pyplot as plt
from tqdm import tqdm

class OneStepSimpleTrainer(TorchTrainer):
    """
    One-step offline actor-critic trainer with a KL regulariser.

    Each call to ``train_from_torch`` performs, on a single batch:

    * a VAE update reconstructing dataset actions from ``(obs, act)``;
    * TD updates of the ``num_q`` critics in ``qf_pi`` towards
      ``reward_scale * r + (1 - done) * discount * min_i target_qf_pi[i](s', a')``
      with ``a'`` sampled from the current policy;
    * a policy update minimising ``w * KL(pi || beta_prime) - Q(s, a_pi)``,
      where ``Q`` is ``qf_prime`` (or the min over ``target_qf_pi`` when
      ``ver == 2``), the KL is a Monte-Carlo estimate over ``n_actions``
      policy samples per state, and ``w`` is ``alpha`` or, when
      ``add_entropy`` is set, a learned entropy temperature;
    * a regression of ``prob_net(s)`` onto a log-mean-exp of
      ``qf_prime / temp`` over VAE-decoded (or dataset) actions.

    ``qf_prime``, ``qf_beta``, ``beta_policy`` and ``beta_prime_policy`` are
    used but never stepped by any optimizer in this class.
    """
    def __init__(
            self,
            env,
            exp_name,
            policy,
            beta_policy,
            beta_prime_policy,
            qf_pi,
            target_qf_pi,
            qf_prime,
            qf_beta,
            prob_net,
            vae,

            kl_reg='Inverse',
            n_actions=10,
            alpha=1.0,
            beta=1.0,
            discount=0.99,
            reward_scale=1.0,

            qf_lr=1e-4,
            policy_lr=1e-4,
            optimizer_class=optim.Adam,

            prob_ver=1,
            ver=1,
            cluster_idx_list=None,
            temp=1.0,
            prob_temp=1.0,
            std_scale=1.0,
            std=None,

            soft_target_tau=5e-3,
            target_update_period=2,

            add_entropy = False,
            num_q=1,
    ):
        """
        Args:
            env: Environment; ``action_space.shape`` sets the target entropy
                and it is used for the evaluation rollouts.
            exp_name: Output directory name (created under ``./``) for plots
                and saved rollout diagnostics.
            policy: Policy being trained.
            beta_policy: Behaviour policy; only used for diagnostics here.
            beta_prime_policy: Reference policy for the KL regulariser.
            qf_pi: Indexable collection of ``num_q`` critics trained by TD.
            target_qf_pi: Target copies of ``qf_pi`` (soft-updated).
            qf_prime: Critic for the policy objective (unless ``ver == 2``)
                and for the prob-net regression targets.
            qf_beta: Critic used only for diagnostics.
            prob_net: State network regressed onto log-mean-exp of ``qf_prime``.
            vae: Conditional VAE over actions.
            kl_reg, beta, prob_ver, cluster_idx_list, std, std_scale: stored
                (some printed) but not otherwise read by this class.
            n_actions: Policy samples per state for the KL estimate.
            alpha: Fixed KL weight, used when ``add_entropy`` is False.
            discount: TD discount (also used for rollout Monte-Carlo returns).
            reward_scale: Multiplier on rewards in the TD target.
            qf_lr: Learning rate for the ``qf_pi`` critics.
            policy_lr: Learning rate for policy, prob-net, VAE and temperature.
            optimizer_class: Optimizer constructor.
            ver: ``2`` uses min over ``target_qf_pi`` in the policy loss;
                any other value uses ``qf_prime``.
            temp: Temperature dividing ``qf_prime`` values.
            prob_temp: Sampling weights are raised to ``1 / prob_temp``.
            soft_target_tau: Polyak coefficient for target updates.
            target_update_period: Steps between target updates.
            add_entropy: Learn an entropy temperature and use it as KL weight.
            num_q: Number of critics in ``qf_pi`` / ``target_qf_pi``.
        """
        super().__init__()
        self.env = env
        self.exp_name = exp_name
        self.policy = policy
        self.beta_policy = beta_policy
        self.beta_prime_policy = beta_prime_policy
        self.qf_pi = qf_pi
        self.target_qf_pi = target_qf_pi
        self.qf_prime = qf_prime
        self.qf_beta = qf_beta
        self.prob_net = prob_net
        self.vae = vae
        self.num_q = num_q

        self.kl_reg = kl_reg
        print('self.kl_reg: \t', self.kl_reg)

        self.n_actions = n_actions
        self.alpha = alpha
        self.beta = beta

        self.qf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        # One optimizer per critic so each can be stepped independently.
        self.qf_optimizer = [optimizer_class(
            self.qf_pi[i].parameters(),
            lr=qf_lr,
        ) for i in range(self.num_q)]
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.prob_optimizer = optimizer_class(
            self.prob_net.parameters(),
            lr=policy_lr,
        )
        self.vae_optimizer = optimizer_class(
            self.vae.parameters(),
            lr=policy_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.discrete = False
        self.ver = ver
        self.prob_ver = prob_ver
        print('self.prob_ver: \t', self.prob_ver)

        self.cluster_idx_list = cluster_idx_list
        self.temp = temp
        self.prob_temp = prob_temp
        self.std = std
        self.std_scale = std_scale

        self.add_entropy = add_entropy
        print('self.add_entropy: \t', self.add_entropy)

        # SAC-style entropy target: -|A|.
        self.target_entropy = -np.prod(self.env.action_space.shape).item()
        self.ent_log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.ent_log_alpha],
            lr=policy_lr,
        )
        self.path = f'./{self.exp_name}'
        os.makedirs(self.path, exist_ok=True)

    def get_policy_actions(self, obs, num_actions, network=None):
        """
        Sample ``num_actions`` actions per state from ``network``.

        ``network`` must be supplied (the ``None`` default is not usable).
        Returns the flat ``(B * num_actions, ...)`` actions and, for
        continuous actions, log-probs reshaped to ``(B, num_actions, 1)``.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        if not self.discrete:
            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
        else:
            return new_obs_actions

    def calculate_sampling_prob(self, replay_buffer):
        """
        Compute normalised sampling weights for every transition in the buffer.

        weight ~ exp(clamp((Q'(s,a)/temp - max Q'/temp) - prob_net(s), -5, 5)) ** (1/prob_temp),
        normalised to sum to one over the whole buffer. The computation runs
        on CPU; ``qf_prime`` and ``prob_net`` are moved back to
        ``ptu.device`` afterwards, even if the computation raises.

        Returns:
            numpy array of weights with the trailing singleton axis squeezed.
        """
        # Whole-buffer forward passes are done on CPU to avoid device OOM.
        self.qf_prime.cpu()
        self.prob_net.cpu()
        try:
            s = torch.from_numpy(replay_buffer._observations).float()
            a = torch.from_numpy(replay_buffer._actions).float()

            with torch.no_grad():
                q_vals = self.qf_prime(s, a) / self.temp
                clamped_qvals = torch.clamp(((q_vals - torch.max(q_vals)) - self.prob_net(s)), min=-5.0, max=5.0)
                prob = torch.exp(clamped_qvals)

                prob = prob ** (1 / self.prob_temp)
                prob = prob / (prob.sum() + 1e-10)

            del s, a
        finally:
            self.qf_prime.to(ptu.device)
            self.prob_net.to(ptu.device)

        return ptu.get_numpy(prob.squeeze(-1))

    def _get_tensor_values(self, obs, actions, network=None):
        """
        Evaluate ``network(obs, action)`` for several actions per state.

        ``actions`` holds ``len(obs) * k`` rows grouped per state; each state
        is repeated ``k`` times to match. Returns values shaped ``(B, k, 1)``.
        """
        action_shape = actions.shape[0]
        obs_shape = obs.shape[0]
        num_repeat = int(action_shape / obs_shape)
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        preds = preds.view(obs.shape[0], num_repeat, 1)

        return preds

    def _sampling_weight_stats(self, obs, act):
        """
        Intermediate tensors of the sampling-weight formula on one batch.

        Same formula as ``calculate_sampling_prob`` (clamp to [-5, 5],
        ``1 / prob_temp`` exponent), but normalised over the batch only.
        Callers are expected to wrap this in ``torch.no_grad()``.
        """
        q_vals = self.qf_prime(obs, act) / self.temp
        mq = q_vals - torch.max(q_vals)
        prob_p = self.prob_net(obs)
        mmq = mq - prob_p
        clamped_mmq = torch.clamp(mmq, min=-5.0, max=5.0)
        exp_p = torch.exp(clamped_mmq)
        temped_expp = exp_p ** (1 / self.prob_temp)
        prob = temped_expp / (temped_expp.sum() + 1e-10)
        return OrderedDict(
            q_vals=q_vals,
            mq=mq,
            prob_p=prob_p,
            mmq=mmq,
            clamped_mmq=clamped_mmq,
            exp_p=exp_p,
            temped_expp=temped_expp,
            prob=prob,
        )

    def _save_stat_plot(self, values, name):
        """Plot ``values`` against batch index and save ``<path>/<name>_<step//10000>.png``."""
        v_mean = ptu.get_numpy(values.mean())
        v_min = ptu.get_numpy(values.min())
        v_max = ptu.get_numpy(values.max())

        plt.xlabel("Batch Number")
        plt.ylabel(name)
        plt.plot(ptu.get_numpy(values), label='mean:%.3f, min:%.3f, max:%.3f' % (v_mean, v_min, v_max))
        plt.legend()
        plt.savefig(f'{self.path}/{name}_{self._n_train_steps_total//10000}.png')
        plt.close()

    def _get_all_state_overestim(self, paths, q_net):
        """
        Compare ``q_net(s, a)`` with the discounted Monte-Carlo return per step.

        Returns an object ndarray with one dict per path whose lists
        ('path_overestim', 'gamm_return', 'q_val') are in chronological order;
        'path_overestim' is ``q_val - gamm_return``. Rewards are not scaled by
        ``reward_scale`` — NOTE(review): confirm this matches how ``q_net``
        was trained.
        """
        overestim = []
        for path in paths:
            # A fresh dict per path, so paths do not overwrite each other.
            path_over_info = {'path_overestim': [], 'gamm_return': [], 'q_val': []}
            gamma_return = 0
            for i in reversed(range(path["rewards"].size)):
                gamma_return = path["rewards"][i] + self.discount * gamma_return * (1 - path["terminals"][i])
                with torch.no_grad():
                    obs = ptu.from_numpy(path["observations"][i]).unsqueeze(0)
                    act = ptu.from_numpy(path["actions"][i]).unsqueeze(0)
                    q_val = q_net(obs, act).squeeze(0).cpu().numpy()

                path_over_info["gamm_return"].append(gamma_return.item())
                path_over_info['q_val'].append((q_val).item())
                path_over_info["path_overestim"].append((q_val - gamma_return).item())

            # Returns were accumulated backwards in time; restore time order.
            for values in path_over_info.values():
                values.reverse()
            overestim.append(path_over_info)

        return np.array(overestim)

    def _policy_with_timesteps(self, paths, policy_net):
        """
        Record ``policy_net`` outputs at every state of every rollout.

        Returns an object ndarray with one dict per path whose lists are in
        chronological order: 'action_diff' (MSE between the sampled action and
        the rollout action), 'pol_mean', 'pol_log_std', 'pol_log_prob',
        'pol_ent' and 'pol_std'.
        """
        pol_info = []
        with torch.no_grad():
            for path in paths:
                path_pol_info = {'action_diff': [], 'pol_mean': [], 'pol_log_std': [],
                                 'pol_log_prob': [], 'pol_ent': [], 'pol_std': []}
                for i in range(path["rewards"].size):
                    state = ptu.from_numpy(path["observations"][i]).unsqueeze(0)
                    act = ptu.from_numpy(path["actions"][i])

                    pol_act, pol_mean, pol_log_std, pol_log_prob, \
                        pol_ent, pol_std, *_ = policy_net(state, reparameterize=True, return_log_prob=True,)

                    pol_act = pol_act.squeeze(0)
                    pol_mean = pol_mean.squeeze(0)
                    pol_log_std = pol_log_std.squeeze(0)
                    pol_log_prob = pol_log_prob.squeeze(0)
                    pol_ent = pol_ent.squeeze(0)
                    pol_std = pol_std.squeeze(0)

                    path_pol_info["action_diff"].append(F.mse_loss(pol_act, act).item())
                    path_pol_info["pol_mean"].append(ptu.get_numpy(pol_mean))
                    path_pol_info["pol_log_std"].append(ptu.get_numpy(pol_log_std))
                    path_pol_info["pol_log_prob"].append(pol_log_prob.item())
                    path_pol_info["pol_ent"].append(pol_ent.item())
                    path_pol_info["pol_std"].append(ptu.get_numpy(pol_std))

                pol_info.append(path_pol_info)

        return np.array(pol_info)

    def train_from_torch(self, batch):
        """
        Run one optimisation step on ``batch``; refresh diagnostics once per epoch.

        Args:
            batch: dict of tensors keyed 'observations', 'actions',
                'next_observations', 'rewards', 'terminals'.
        """
        obs = batch['observations']
        act = batch['actions']
        next_obs = batch['next_observations']
        rewards = batch['rewards']
        terminals = batch['terminals']

        # ---- VAE loss: action reconstruction + 0.5 * Gaussian KL ----
        recon, vae_mean, vae_std = self.vae(obs, act)
        recon_loss = self.qf_criterion(recon, act)
        kl_loss = -0.5 * (1 + torch.log(vae_std.pow(2)) - vae_mean.pow(2) - vae_std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # ---- QF loss: TD towards clipped-double-Q target with policy next actions ----
        q_pred = [self.qf_pi[i](obs, act) for i in range(self.num_q)]

        # The TD target is treated as a constant, so no graph is built for it.
        with torch.no_grad():
            next_actions, *_ = self.policy(
                next_obs, reparameterize=False, return_log_prob=True,
            )
            target_q_values = torch.stack(
                [self.target_qf_pi[i](next_obs, next_actions) for i in range(self.num_q)], dim=0)
            target_q_values = torch.min(target_q_values, dim=0)[0]
            q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values

        qf_loss = [self.qf_criterion(q_pred[i], q_target) for i in range(self.num_q)]

        for optimizer, loss in zip(self.qf_optimizer, qf_loss):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---- Policy and entropy-temperature loss ----
        new_obs_actions, _, _, new_log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        # Monte-Carlo KL(pi || beta_prime) with n_actions samples per state;
        # rows are grouped per state, so reshape(-1, n_actions) averages per state.
        obs_stack = torch.unsqueeze(obs, 1).repeat(1, self.n_actions, 1).reshape((-1, obs.shape[1]))
        new_obs_actions_stack, _, _, log_pi_stack, _, policy_std, *_ = self.policy(
            obs_stack, reparameterize=True, return_log_prob=True,
        )
        log_pi = torch.mean(log_pi_stack.reshape((-1, self.n_actions)), dim=1)

        log_beta_stack = self.beta_prime_policy.log_prob(obs_stack, new_obs_actions_stack)
        log_beta = torch.mean(log_beta_stack.reshape((-1, self.n_actions)), dim=1)

        kl = (log_pi - log_beta).mean()

        # Open question from the original author: qf_prime is learned from
        # beta-prime data, so evaluating it on actions from a weakly
        # KL-regularised policy may overestimate. If a weak KL still performs
        # well, the dataset actions may be noisy enough that copying them is
        # not ideal and some variation helps -- but beta prime is conditioned
        # on (s, a).
        if self.ver == 2:
            q_pi_stack = torch.stack(
                [self.target_qf_pi[i](obs, new_obs_actions) for i in range(self.num_q)], dim=0)
            q_val = torch.min(q_pi_stack, dim=0)[0]
        else:
            q_val = self.qf_prime(obs, new_obs_actions)

        if self.add_entropy:
            alpha_loss = -(self.ent_log_alpha * (-log_pi - self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            ent_alpha = self.ent_log_alpha.exp()

            # The learned temperature replaces the fixed alpha as KL weight.
            policy_loss = ent_alpha * kl - q_val.mean()
        else:
            policy_loss = self.alpha * kl - q_val.mean()

        # ---- Quantities used only for diagnostics ----
        with torch.no_grad():
            q_pi = self.qf_pi[0](obs, new_obs_actions).view(-1, 1)
            q_pol = self.qf_prime(obs, new_obs_actions).view(-1, 1)
            q_beta = self.qf_beta(obs, new_obs_actions).view(-1, 1)
            q_diff = q_pol - q_beta

            # Distinct name so the policy std above (logged as 'Policy std')
            # is not overwritten by beta prime's std.
            bp_actions, _, _, log_bp, _, bp_std, *_ = self.beta_prime_policy(
                obs, reparameterize=True, return_log_prob=True,
            )

            log_b, _ = self.beta_policy.log_prob(obs, bp_actions, return_std=True)
            bp_kl = (log_bp - log_b).mean()

            # KL(pi || beta) must evaluate beta at the policy's own actions.
            log_b_pi, _ = self.beta_policy.log_prob(obs, new_obs_actions, return_std=True)
            p_kl = (new_log_pi - log_b_pi).mean()

        # ---- Probability-network loss ----
        with torch.no_grad():
            # NOTE(review): this is truthy whenever the step is NOT a multiple
            # of 10000, so VAE samples are used almost always and dataset
            # actions only on multiples of 10000 -- confirm this is intended.
            if self._n_train_steps_total % 10000:
                n_decode = 50
                beta_actions, *_ = self.vae.decode_multiple(obs, num_decode=n_decode)
                beta_actions = beta_actions.view(act.shape[0] * n_decode, act.shape[-1])
            else:
                n_decode = 1
                beta_actions = act

            q_betas = self._get_tensor_values(obs, beta_actions, network=self.qf_prime).squeeze(2) / self.temp
            max_q_betas = torch.max(q_betas)

            # log-mean-exp over decoded actions, shifted by the batch-wide max.
            expected_prob = torch.logsumexp(q_betas - max_q_betas, dim=1, keepdim=True) - np.log(n_decode)

        prob_net_value = self.prob_net(obs)
        prob_loss = F.mse_loss(expected_prob, prob_net_value)

        # ---- Update networks ----
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.prob_optimizer.zero_grad()
        prob_loss.backward()
        self.prob_optimizer.step()

        # ---- Soft target updates ----
        if self._n_train_steps_total % self.target_update_period == 0:
            for i in range(self.num_q):
                ptu.soft_update_from_to(
                    self.qf_pi[i], self.target_qf_pi[i], self.soft_target_tau
                )

        # ---- Diagnostics (computed for one batch per epoch; end_epoch re-arms) ----
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            with torch.no_grad():
                # Always computed, so the statistics below never reference
                # names that only exist on plotting steps.
                sampling = self._sampling_weight_stats(obs, act)

                if self._n_train_steps_total % 1000 == 0:
                    bp_act, *_ = self.beta_prime_policy(obs, reparameterize=True, return_log_prob=True, )
                    bq_vals = self.qf_prime(obs, bp_act) / self.temp
                    bmmq = (bq_vals - torch.max(bq_vals)) - self.prob_net(obs)

                    self._save_stat_plot(sampling['mq'], 'mq')
                    self._save_stat_plot(sampling['prob_p'], 'prob_p')
                    self._save_stat_plot(sampling['mmq'], 'mmq')
                    self._save_stat_plot(bmmq, 'bmmq')

            array_stats = [
                ('sampling_q_vals', sampling['q_vals']),
                ('sampling_mq', sampling['mq']),
                ('sampling_prob_net', sampling['prob_p']),
                ('sampling_mmq', sampling['mmq']),
                ('sampling_clamped_mmq', sampling['clamped_mmq']),
                ('sampling_exp_p', sampling['exp_p']),
                ('sampling_temped_expp', sampling['temped_expp']),
                ('sampling_prob', sampling['prob']),
                ('Expected Prob', expected_prob),
                ('Prob Net', prob_net_value),
                ('Q values', q_val),
            ]
            for stat_name, values in array_stats:
                self.eval_statistics.update(create_stats_ordered_dict(stat_name, ptu.get_numpy(values)))

            self.eval_statistics['Target Entropy'] = np.mean(self.target_entropy)
            self.eval_statistics['Entropy'] = np.mean(ptu.get_numpy(-log_pi))

            array_stats = [
                ('Policy std', policy_std),
                ('Q pi', q_pi),
                ('Q prime', q_pol),
                ('Q beta', q_beta),
                ('Q Difference', q_diff),
                ('Log pi', log_pi),
                ('Log beta', log_beta),
            ]
            for stat_name, values in array_stats:
                self.eval_statistics.update(create_stats_ordered_dict(stat_name, ptu.get_numpy(values)))

            self.eval_statistics['KL'] = np.mean(ptu.get_numpy(kl))
            self.eval_statistics['Beta prime KL'] = np.mean(ptu.get_numpy(bp_kl))
            self.eval_statistics['Policy KL'] = np.mean(ptu.get_numpy(p_kl))
            self.eval_statistics['Prob Loss'] = np.mean(ptu.get_numpy(prob_loss))
            if self.add_entropy:
                self.eval_statistics['Entropy parameter'] = np.mean(ptu.get_numpy(ent_alpha))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))

            # Roll out the current policy and analyse the visited states.
            self._reserve_path_collector = MdpPathCollector(
                env=self.env, policy=self.policy,
            )
            self._reserve_path_collector.update_policy(self.policy)

            eval_paths = self._reserve_path_collector.collect_new_paths(
                max_path_length=1000,
                num_steps=1000,
                discard_incomplete_paths=True,
            )

            path_pol_info = self._policy_with_timesteps(eval_paths, self.policy)
            path_bp_info = self._policy_with_timesteps(eval_paths, self.beta_prime_policy)
            path_beta_info = self._policy_with_timesteps(eval_paths, self.beta_policy)
            prime_overestim_info = self._get_all_state_overestim(eval_paths, self.qf_prime)
            pol_overestim_info = self._get_all_state_overestim(eval_paths, self.qf_pi[0])

            # Aggregate over every step of every collected path.
            all_prime_overestim = [
                value for info in prime_overestim_info for value in info['path_overestim']
            ]
            self.eval_statistics.update(create_stats_ordered_dict('Overestimation', all_prime_overestim))

            if self._n_train_steps_total % 10000 == 0:
                # Object arrays of dicts: load with np.load(..., allow_pickle=True).
                np.save(self.path + f'/path_pol_info_{self._n_train_steps_total//1000}th', path_pol_info)
                np.save(self.path + f'/path_bp_info_{self._n_train_steps_total//1000}th', path_bp_info)
                np.save(self.path + f'/path_beta_info_{self._n_train_steps_total//1000}th', path_beta_info)
                np.save(self.path + f'/prime_overestim_info_{self._n_train_steps_total//1000}th', prime_overestim_info)
                np.save(self.path + f'/pol_overestim_info_{self._n_train_steps_total // 1000}th', pol_overestim_info)

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the most recently computed evaluation statistics."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm diagnostics so the next training batch refreshes them."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """Networks managed by this trainer (``beta_policy`` is not listed)."""
        return [
            self.policy,
            self.beta_prime_policy,
            self.qf_pi,
            self.target_qf_pi,
            self.qf_prime,
            self.qf_beta,
            self.prob_net,
            self.vae
        ]

    def get_snapshot(self):
        """Return the networks to checkpoint (``qf_beta``/``beta_policy`` are not included)."""
        return dict(
            policy=self.policy,
            beta_prime_policy=self.beta_prime_policy,
            qf_pi=self.qf_pi,
            target_qf_pi=self.target_qf_pi,
            qf_prime=self.qf_prime,
            prob_net=self.prob_net,
            vae=self.vae,
        )

    def set_snapshot(self, snapshot):
        """Restore networks from a dict produced by ``get_snapshot``."""
        self.policy = snapshot['policy']
        self.beta_prime_policy = snapshot['beta_prime_policy']
        self.qf_pi = snapshot['qf_pi']
        self.target_qf_pi = snapshot['target_qf_pi']
        self.qf_prime = snapshot['qf_prime']
        self.prob_net = snapshot['prob_net']
        self.vae = snapshot['vae']
