from collections import OrderedDict

import os
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import RTorchTrainer
from rlkit.samplers.data_collector.path_collector import MdpPathCollector

from torch import autograd

class CQLTrainer(RTorchTrainer):
    """CQL trainer with behavior-density-ratio-weighted conservative penalties.

    Each call to :meth:`train_from_torch` updates the entropy temperature
    (optional), the policy, ``qf1``/``qf2`` and their targets. The Q loss is
    the (optionally ``soft_weight``-weighted) Bellman error plus a CQL
    log-sum-exp penalty over sampled actions. When ``use_ratio`` is set, every
    sampled action in that penalty is reweighted by ``(1 - beta/rho)_+ **
    ratio_temp``, where ``beta`` is the density of ``behavior_policy`` and
    ``rho`` is the uniform density on ``[-1, 1]^action_dim``.

    NOTE(review): the CQL penalty and logging always use ``qf2``/``q2_pred``,
    so ``num_qs == 1`` is not supported end to end.
    """

    def __init__(
            self,
            env,
            exp_name,
            policy,
            behavior_policy,
            qf1,
            qf2,
            qf_bp,
            target_qf1,
            target_qf2,

            discount=0.99,
            reward_scale=1.0,
            clamp_min=-5.0,
            clamp_max=5.0,

            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
            policy_eval_start=0,
            num_qs=2,

            # CQL
            ver=0,
            min_q_version=3,
            temp=1.0,
            ratio_temp=1.0,
            min_q_weight=1.0,

            # Backup variants
            max_q_backup=False,
            deterministic_backup=True,
            num_random=10,
            with_lagrange=False,
            lagrange_thresh=0.0,

            bellman_weight=False,
            use_ratio=False,
    ):
        """Store hyper-parameters and build the optimizers.

        Args:
            env: environment; ``env.action_space.shape`` sets the default
                target entropy and the env is used for overestimation rollouts.
            exp_name: directory (created if missing) for saved ``.npy`` files.
            policy: policy being trained.
            behavior_policy: model of the data-collecting policy; its
                ``log_prob`` provides ``beta`` for the ratio weights.
            qf1, qf2: Q-functions optimized by this trainer.
            qf_bp: Q-function that is only evaluated (its values are returned
                by :meth:`train_from_torch`); no optimizer is built for it.
            target_qf1, target_qf2: Polyak-averaged copies of qf1/qf2.
            discount, reward_scale: TD-target parameters.
            clamp_min, clamp_max: stored only; not used by this class's
                training step.
            policy_lr, qf_lr, optimizer_class: optimizer settings (the alpha
                optimizer uses ``policy_lr``, alpha_prime uses ``qf_lr``).
            soft_target_tau: Polyak coefficient for target updates.
            plotter, render_eval_paths: stored only.
            use_automatic_entropy_tuning, target_entropy: temperature tuning;
                ``target_entropy`` defaults to minus the action dimension.
            policy_eval_start: number of initial training steps that use a
                behavioral-cloning policy loss.
            num_qs: number of Q-functions used in policy/target minima.
            ver: transform of ``batch['soft_weight']`` (0: none, 1: clamp at
                0.1, 2: multiply by ``cls_len``, 3: multiply then clamp).
            min_q_version: CQL penalty variant, 3 (importance-sampled) or 2
                (includes the dataset action).
            temp: ``beta`` of the stored Softplus module.
            ratio_temp: exponent applied to the ``(1 - beta/rho)_+`` ratios.
            min_q_weight: scale of the conservative penalty terms.
            max_q_backup, deterministic_backup: TD-target variants.
            num_random: sampled actions per state for the penalty.
            with_lagrange, lagrange_thresh: learn the penalty scale
                ``alpha_prime`` against the action-gap threshold.
            bellman_weight: weight the Bellman error by ``soft_weight``.
            use_ratio: enable behavior-density ratio weights.
        """
        super().__init__()
        self.env = env
        self.exp_name = exp_name
        self.policy = policy
        self.behavior_policy = behavior_policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.qf_bp = qf_bp
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.bellman_weight = bellman_weight
        self.use_ratio = use_ratio
        print("self.bellman_weight: ", bellman_weight)
        print("self.use_ratio: ", use_ratio)

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # `is not None` so that an explicit target_entropy of 0.0 is honored.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # Default: minus the action dimensionality.
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.with_lagrange = with_lagrange
        print("self.with_lagrange: ", self.with_lagrange)
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            print("lagrange_thresh: ", lagrange_thresh)
            self.log_alpha_prime = ptu.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = optimizer_class(
                [self.log_alpha_prime],
                lr=qf_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        print("self.discount: ", discount)
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self.eval_wandb = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.policy_eval_start = policy_eval_start

        # NOTE: incremented once per train_from_torch call, i.e. it counts
        # gradient steps rather than epochs.
        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1

        self.num_qs = num_qs

        # CQL penalty settings
        self.temp = temp
        self.ratio_temp = ratio_temp
        self.min_q_version = min_q_version
        self.min_q_weight = min_q_weight
        self.ver = ver
        print("self.ver: ", ver)
        print("self.min_q_version: ", min_q_version)
        print("self.min_q_weight: ", min_q_weight)

        self.softmax = torch.nn.Softmax(dim=1)
        self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        # Continuous-action implementation; the discrete branches are unused.
        self.discrete = False

        self.clamp_max = clamp_max
        self.clamp_min = clamp_min
        print("self.clamp_max: \t", clamp_max)
        print("self.clamp_min: \t", clamp_min)

        self.path = f'{self.exp_name}'
        os.makedirs(self.path, exist_ok=True)

    def _get_tensor_values(self, obs, actions, network=None):
        """Evaluate ``network`` on each observation paired with its repeated actions.

        ``actions`` holds ``num_repeat`` actions per observation, laid out
        observation-major (as produced by :meth:`_get_policy_actions`).

        Returns:
            Tensor of shape ``(obs.shape[0], num_repeat, 1)``.
        """
        num_repeat = actions.shape[0] // obs.shape[0]
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        return preds.view(obs.shape[0], num_repeat, 1)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """Sample ``num_actions`` actions per observation from ``network``.

        Returns:
            ``(actions, log_probs)`` with actions of shape
            ``(obs.shape[0] * num_actions, action_dim)`` and log-probs reshaped
            to ``(obs.shape[0], num_actions, 1)``; only the actions when
            ``self.discrete`` is set.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        if not self.discrete:
            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
        return new_obs_actions

    def _transform_soft_weight(self, soft_weight, cls_len):
        """Apply the ``ver``-selected transform to the per-transition weights."""
        if self.ver == 1:
            return soft_weight.clamp(min=0.1)
        if self.ver == 2:
            return soft_weight * cls_len
        if self.ver == 3:
            return (soft_weight * cls_len).clamp(min=0.1)
        return soft_weight

    def _compute_q_target(self, rewards, terminals, next_obs, alpha):
        """Return the detached TD target ``scale * r + (1 - done) * gamma * Q'(s', a')``."""
        if self.max_q_backup:
            # Max over 10 policy samples for each target Q, then min over the two.
            next_actions_temp, _ = self._get_policy_actions(next_obs, num_actions=10, network=self.policy)
            target_qf1_values = self._get_tensor_values(
                next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(
                next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)
        else:
            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs, reparameterize=False, return_log_prob=True,
            )
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        return q_target.detach()

    def _compute_ratio_weights(self, obs, actions, beta_log_probs, pi_actions,
                               npi_actions, rand_actions, random_density):
        """Return ``(1 - beta/rho)_+ ** ratio_temp`` weights for each action set.

        ``beta`` is the ``behavior_policy`` density and ``rho`` the uniform
        density ``exp(random_density)``. Weights carry no gradient and are
        clamped to ``>= 1e-30`` so that their log stays finite. When
        ``use_ratio`` is False, every weight is 1.

        Returns:
            dict with ``'batch'`` of shape (B, 1) for the dataset actions, and
            ``'beta'``, ``'pi'``, ``'npi'``, ``'rand'`` of shape
            (B, num_random) for behavior-policy, current-policy,
            next-state-policy and uniform samples respectively.
        """
        batch_size = obs.shape[0]
        if not self.use_ratio:
            per_sample = obs.new_ones(batch_size, self.num_random)
            return {
                'batch': obs.new_ones(batch_size, 1),
                'beta': per_sample,
                'pi': per_sample,
                'npi': per_sample,
                'rand': per_sample,
            }

        def to_weight(log_beta):
            # log(beta / rho) clamped to [-20, 0]; the upper bound keeps
            # 1 - beta/rho non-negative (the positive part).
            log_ratio = (log_beta - random_density).clamp(min=-20., max=0.0)
            ratio = 1 - torch.exp(log_ratio)
            # ratio_temp softens/sharpens the ratio.
            return (ratio ** self.ratio_temp).clamp(min=1e-30)

        with torch.no_grad():
            obs_rep = obs.unsqueeze(1).repeat(1, self.num_random, 1).view(
                batch_size * self.num_random, obs.shape[1])
            # NOTE: the ratio is currently computed from beta alone
            # (1 - beta/rho); whether a separate beta' is needed is open.
            log_prob = self.behavior_policy.log_prob
            return {
                'batch': to_weight(log_prob(obs, actions).view(batch_size, -1)),
                'beta': to_weight(beta_log_probs.view(batch_size, -1)),
                'pi': to_weight(log_prob(obs_rep, pi_actions).view(batch_size, -1)),
                'npi': to_weight(log_prob(obs_rep, npi_actions).view(batch_size, -1)),
                'rand': to_weight(log_prob(obs_rep, rand_actions).view(batch_size, -1)),
            }

    def train_from_torch(self, batch):
        """Run one training step on a batch of transitions.

        Updates, in order: the entropy temperature (if tuned), the Lagrange
        multiplier (if enabled), the policy, ``qf1``/``qf2`` and the target
        networks. Once per epoch (re-armed by :meth:`end_epoch`) it also fills
        ``self.eval_statistics`` and rolls out the policy to log Q
        overestimation.

        Args:
            batch: mapping of tensors with keys ``'rewards'``, ``'terminals'``,
                ``'observations'``, ``'actions'``, ``'next_observations'``,
                ``'soft_weight'`` and ``'cls_len'``. Rewards, terminals and
                soft weights are presumably shaped (B, 1) — TODO confirm
                against the replay buffer.

        Returns:
            numpy array of ``qf_bp(observations, actions)``.
        """
        self._current_epoch += 1

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        cls_len = batch['cls_len']
        soft_weight = self._transform_soft_weight(batch['soft_weight'], cls_len)

        # ---- Policy and alpha loss ----
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        if self._current_epoch < self.policy_eval_start:
            # Optional behavioral-cloning warm-up for the first
            # policy_eval_start steps.
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        else:
            policy_loss = (alpha * log_pi - q_new_actions).mean()

        # ---- Bellman (QF) loss ----
        q1_pred = self.qf1(obs, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)

        q_target = self._compute_q_target(rewards, terminals, next_obs, alpha)

        if self.bellman_weight:
            qf1_loss = (((q1_pred - q_target) ** 2) * soft_weight).mean()
            qf2_loss = (((q2_pred - q_target) ** 2) * soft_weight).mean()
        else:
            qf1_loss = self.qf_criterion(q1_pred, q_target)
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        # ---- CQL penalty ----
        batch_size = obs.shape[0]
        _, curr_log_betas = self._get_policy_actions(
            obs, num_actions=self.num_random, network=self.behavior_policy)
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(
            obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
            next_obs, num_actions=self.num_random, network=self.policy)
        # Uniform actions in [-1, 1], created directly on the data's device.
        random_actions_tensor = torch.empty(
            batch_size * self.num_random, actions.shape[-1],
            dtype=actions.dtype, device=actions.device,
        ).uniform_(-1, 1)

        q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
        # NOTE(review): actions sampled at next_obs are scored at obs — confirm
        # this is intended.
        q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)
        q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
        q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)

        # Log-density of the uniform distribution on [-1, 1]^action_dim.
        random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])

        weights = self._compute_ratio_weights(
            obs, actions, curr_log_betas, curr_actions_tensor,
            new_curr_actions_tensor, random_actions_tensor, random_density,
        )
        batch_weight = weights['batch']
        beta_weight = weights['beta']
        pi_weight = weights['pi']
        npi_weight = weights['npi']
        rand_weight = weights['rand']

        log_rand_w = torch.log(rand_weight).view(-1, self.num_random, 1)
        log_npi_w = torch.log(npi_weight).view(-1, self.num_random, 1)
        log_pi_w = torch.log(pi_weight).view(-1, self.num_random, 1)

        if self.min_q_version == 3:
            # Importance-sampled version: subtract each proposal's log-density.
            q1_terms = [q1_rand - random_density,
                        q1_next_actions - new_log_pis.detach(),
                        q1_curr_actions - curr_log_pis.detach()]
            q2_terms = [q2_rand - random_density,
                        q2_next_actions - new_log_pis.detach(),
                        q2_curr_actions - curr_log_pis.detach()]
            log_w_terms = [log_rand_w, log_npi_w, log_pi_w]
        elif self.min_q_version == 2:
            q1_terms = [q1_rand, q1_next_actions, q1_curr_actions, q1_pred.unsqueeze(1)]
            q2_terms = [q2_rand, q2_next_actions, q2_curr_actions, q2_pred.unsqueeze(1)]
            # The dataset action is weighted like every other term (by its
            # *_weight, not the raw ratio, which may be exactly 0).
            log_w_terms = [log_rand_w, log_npi_w, log_pi_w,
                           torch.log(batch_weight).view(-1, 1, 1)]
        else:
            raise NotImplementedError

        # Weighted (cat_q*) and unweighted (nw_cat_q*) candidates, (B, K, 1).
        cat_q1 = torch.cat([w + q for w, q in zip(log_w_terms, q1_terms)], dim=1)
        cat_q2 = torch.cat([w + q for w, q in zip(log_w_terms, q2_terms)], dim=1)
        nw_cat_q1 = torch.cat(q1_terms, dim=1)
        nw_cat_q2 = torch.cat(q2_terms, dim=1)

        logsum_q1 = torch.logsumexp(cat_q1, dim=1)
        logsum_q2 = torch.logsumexp(cat_q2, dim=1)
        with torch.no_grad():
            # All log-weights are <= 0 (for ratio_temp > 0), so these ratios are <= 1.
            logsum_ratio_q1 = torch.exp(logsum_q1 - torch.logsumexp(nw_cat_q1, dim=1))
            logsum_ratio_q2 = torch.exp(logsum_q2 - torch.logsumexp(nw_cat_q2, dim=1))

        min_qf1_loss = soft_weight * (logsum_ratio_q1 * logsum_q1 - (batch_weight * q1_pred).mean(dim=-1, keepdim=True))
        min_qf2_loss = soft_weight * (logsum_ratio_q2 * logsum_q2 - (batch_weight * q2_pred).mean(dim=-1, keepdim=True))

        # -(E_pi[(1 - beta/rho)_+] - E_beta[(1 - beta/rho)_+]) * Q(s, a_data)
        ratio_diff = pi_weight.mean(dim=1, keepdim=True) - beta_weight.mean(dim=1, keepdim=True)

        s3_min_qf1_loss = -(soft_weight * (ratio_diff * q1_pred.mean(dim=1, keepdim=True))).mean() * self.min_q_weight
        s3_min_qf2_loss = -(soft_weight * (ratio_diff * q2_pred.mean(dim=1, keepdim=True))).mean() * self.min_q_weight

        min_qf1_loss = min_qf1_loss.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss.mean() * self.min_q_weight

        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        # Added in both cases so the (possibly Lagrange-scaled) penalty
        # actually reaches the Q-functions.
        qf1_loss = qf1_loss + min_qf1_loss + s3_min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss + s3_min_qf2_loss

        # ---- Update networks ----
        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        # ---- Soft target updates ----
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        if self.num_qs > 1:
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        # ---- Statistics (one batch per epoch; end_epoch() re-arms the flag) ----
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            # Entropy-free policy objective, for logging only.
            policy_loss = (log_pi - q_new_actions).mean()
            stats = self.eval_statistics

            def add_dist(name, tensor):
                stats.update(create_stats_ordered_dict(name, ptu.get_numpy(tensor)))

            add_dist('cls_len', cls_len)
            add_dist('Soft Weight', soft_weight)

            stats['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            stats['s3 min QF1 Loss'] = np.mean(ptu.get_numpy(s3_min_qf1_loss))
            stats['min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss))
            if self.num_qs > 1:
                stats['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
                stats['s3 min QF2 Loss'] = np.mean(ptu.get_numpy(s3_min_qf2_loss))
                stats['min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss))

            add_dist('Batch weight', batch_weight)
            add_dist('Beta weight', beta_weight)
            add_dist('Pi weight', pi_weight)
            add_dist('NPi weight', npi_weight)
            add_dist('Rand weight', rand_weight)
            add_dist('logsum_ratio', logsum_ratio_q1)

            if not self.discrete:
                add_dist('QF1 in-distribution values', q1_curr_actions)
                add_dist('QF2 in-distribution values', q2_curr_actions)
                add_dist('QF1 random values', q1_rand)
                add_dist('QF2 random values', q2_rand)
                add_dist('QF1 next_actions values', q1_next_actions)
                add_dist('QF2 next_actions values', q2_next_actions)
                add_dist('actions', actions)
                add_dist('rewards', rewards)

            stats['Num Q Updates'] = self._num_q_update_steps
            stats['Num Policy Updates'] = self._num_policy_update_steps
            stats['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
            add_dist('Q1 Predictions', q1_pred)
            if self.num_qs > 1:
                add_dist('Q2 Predictions', q2_pred)
            add_dist('Q Targets', q_target)
            add_dist('Log Pis', log_pi)
            if not self.discrete:
                add_dist('Policy mu', policy_mean)
                add_dist('Policy log std', policy_log_std)

            if self.use_automatic_entropy_tuning:
                stats['Alpha'] = alpha.item()
                stats['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                stats['Alpha_prime'] = alpha_prime.item()
                stats['min_q1_loss'] = ptu.get_numpy(min_qf1_loss).mean()
                stats['min_q2_loss'] = ptu.get_numpy(min_qf2_loss).mean()
                stats['threshold action gap'] = self.target_action_gap
                stats['alpha prime loss'] = alpha_prime_loss.item()

            self._log_overestimation()

        self._n_train_steps_total += 1

        # qf_bp is only evaluated here; no autograd graph is needed.
        with torch.no_grad():
            q_values = self.qf_bp(obs, actions)

        return ptu.get_numpy(q_values)

    def _get_all_state_overestim(self, paths):
        """Compare ``min(qf1, qf2)`` with discounted returns along each path.

        The return follows ``G_t = r_t + discount * (1 - d_t) * G_{t+1}``.

        Returns:
            1-D object ndarray with one dict per path holding time-ordered lists
            ``'path_overestim'`` (Q - return), ``'gamm_return'`` and ``'q_val'``.
        """
        overestim = []
        for path in paths:
            rewards = np.asarray(path["rewards"]).reshape(-1)
            not_done = 1 - np.asarray(path["terminals"]).reshape(-1).astype(np.float32)

            returns = []
            running = 0.0
            for reward, nd in zip(rewards[::-1], not_done[::-1]):
                running = reward + self.discount * running * nd
                returns.append(float(running))
            returns = np.asarray(returns[::-1])

            # One batched forward pass per path instead of one per step.
            with torch.no_grad():
                obs_t = ptu.from_numpy(np.asarray(path["observations"]))
                act_t = ptu.from_numpy(np.asarray(path["actions"]))
                q_val = torch.min(self.qf1(obs_t, act_t), self.qf2(obs_t, act_t))
            q_val = ptu.get_numpy(q_val).reshape(-1)

            # A fresh dict per path (a shared dict would mix and re-reverse paths).
            overestim.append({
                'path_overestim': (q_val - returns).tolist(),
                'gamm_return': returns.tolist(),
                'q_val': q_val.tolist(),
            })

        return np.array(overestim)

    def _log_overestimation(self):
        """Roll out the current policy and record Q overestimation.

        Collects up to 1000 environment steps with ``self.policy``, logs the
        distribution of ``Q - return`` over all visited states under
        ``'Overestimation'`` and, whenever ``(step // 1000) % 5 == 0``, saves
        the per-path data to ``{exp_name}/overestim_info_{step // 1000}.npy``.
        """
        self._reserve_path_collector = MdpPathCollector(
            env=self.env, policy=self.policy,
        )
        self._reserve_path_collector.update_policy(self.policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            max_path_length=1000,
            num_steps=1000,
            discard_incomplete_paths=True,
        )

        overestim_info = self._get_all_state_overestim(eval_paths)

        if len(overestim_info) > 0:
            all_overestim = np.concatenate(
                [np.asarray(info['path_overestim']) for info in overestim_info])
            self.eval_statistics.update(
                create_stats_ordered_dict('Overestimation', all_overestim))

        if (self._current_epoch // 1000) % 5 == 0:
            curr = self._current_epoch // 1000
            np.save(f'{self.path}/overestim_info_{curr}.npy', overestim_info)

    def get_diagnostics(self):
        """Return the statistics gathered by the last statistics-enabled step."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm statistics collection for the next training step."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All modules owned by this trainer (trained and auxiliary)."""
        return [
            self.policy,
            self.behavior_policy,
            self.qf1,
            self.qf2,
            self.qf_bp,
            self.target_qf1,
            self.target_qf2,
        ]

    def get_snapshot(self):
        """Return the policy and Q-networks for checkpointing."""
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
        )

    def set_snapshot(self, snapshot):
        """Restore the networks saved by :meth:`get_snapshot`."""
        self.policy = snapshot['policy']
        self.qf1 = snapshot['qf1']
        self.qf2 = snapshot['qf2']
        self.target_qf1 = snapshot['target_qf1']
        self.target_qf2 = snapshot['target_qf2']

