import abc
import copy
# Visualization
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from rlkit.torch import pytorch_util as ptu

import gtimer as gt
from rlkit.core.rl_algorithm import BaseRLAlgorithm, BaseRLAlgorithm2
from rlkit.core.rl_algorithm import eval_util
from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.samplers.data_collector import PathCollector
from rlkit.samplers.data_collector.path_collector import MdpPathCollector
import numpy as np
from rlkit.torch.core import np_to_pytorch_batch
from scipy import special

import torch
from tqdm import tqdm


def get_flat_params(model):
    """Return all parameters of ``model`` packed into one flat numpy array.

    Parameters are visited in ``model.parameters()`` order; each one is
    copied to the CPU, flattened to 1-D and the pieces are concatenated.

    Args:
        model: Object exposing ``parameters()`` that yields torch parameters.

    Returns:
        A 1-D ``numpy.ndarray`` holding every parameter value.
    """
    flat_chunks = [
        weights.data.cpu().numpy().reshape(-1)
        for weights in model.parameters()
    ]
    return np.concatenate(flat_chunks)


def set_flat_params(model, flat_params, trainable_only=True):
    """Overwrite the parameters of ``model`` with values from a flat array.

    ``flat_params`` is consumed front to back in ``model.parameters()``
    order, i.e. the layout produced by ``get_flat_params``. Each parameter's
    ``.data`` is replaced by a new tensor built with ``ptu.tensor``.

    Args:
        model: Object exposing ``parameters()`` that yields torch parameters.
        flat_params: 1-D array holding the new values for every parameter.
        trainable_only: Accepted for interface compatibility; the current
            implementation does not read it.

    Returns:
        The same ``model`` object, after modification.
    """
    offset = 0
    for param in model.parameters():
        shape = param.data.shape
        count = int(np.prod(list(shape)))
        chunk = flat_params[offset:offset + count]
        offset += count

        # Zero-dimensional parameters take a single scalar instead of a
        # reshaped slice.
        new_value = chunk.reshape(*shape) if len(shape) else chunk[0]
        param.data = ptu.tensor(new_value)
    return model


class EvalAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Epoch-based RL loop with tools for probing the policy's neighbourhood.

    Besides the evaluate / collect / train loop in ``_train``, the class can
    perturb the policy weights along random directions and plot how the
    Q-estimates and rollout returns react (``_visualize`` and
    ``plot_visualized_data``).
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
    ):
        """Store the loop hyper-parameters and build a spare path collector.

        The trainer, environments, collectors and replay buffer are forwarded
        unchanged to ``BaseRLAlgorithm``.

        Args:
            batch_size: Size requested from ``replay_buffer.random_batch``.
            max_path_length: Path length limit handed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for each evaluation pass.
            num_expl_steps_per_train_loop: Step budget for each exploration
                pass (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when not running in batch mode.
            q_learning_alg: Evaluate with ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: In batch mode, also roll out the Q-greedy policy
                function with the exploration collector in every train loop.
            batch_rl: Train from the existing replay buffer only; no new
                exploration data is added.
            num_actions_sample: Number of candidate actions drawn by
                ``policy_fn``.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample

        # Separate collector so that visualisation rollouts do not go through
        # the regular evaluation collector.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Pick the best of several sampled actions according to ``qf1``.

        The observation is repeated ``num_actions_sample`` times, the policy
        proposes one action per copy, and the action with the largest
        ``qf1`` value is returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy returns an 8-tuple; only its first entry (the
            # action) is needed here.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the action with the highest ``qf1`` value."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Seed the replay buffer before the first gradient step.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = (
                    self.policy_fn_discrete if self.trainer.discrete
                    else self.policy_fn
                )
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Batch mode: additionally roll out the Q-greedy policy
                    # function; the paths stay inside the collector.
                    policy_fn = (
                        self.policy_fn_discrete if self.trainer.discrete
                        else self.policy_fn
                    )
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # This stamp repeats on every train loop of an epoch, so
                    # it must not demand a unique name (same as the stamps
                    # in the branch above).
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``q_function`` under ``custom_model`` on one random batch.

        A fresh batch of ``batch_size`` transitions is drawn from the replay
        buffer, converted to torch and passed to ``trainer.eval_q_custom``.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return the average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

    def plot_visualized_data(self, array_plus, array_minus, base_val, fig_label='None',
                             plot_dir='plots_hopper_correct_online_3e-4_n10_viz_sac_again'):
        """Save two figures describing the objective around the current weights.

        (1) A scatter plot of ``L(theta + alpha*d) - L(theta)`` against
            ``L(theta - alpha*d) - L(theta)`` with the identity line.
        (2) Histograms of the half-difference (labelled "Gradient Value") and
            the half-sum (labelled "Curvature Value") of those two offsets.

        Args:
            array_plus: Objective values at ``theta + alpha * d``.
            array_minus: Objective values at ``theta - alpha * d``.
            base_val: Objective value at ``theta``; subtracted from both arrays.
            fig_label: Suffix used in the plot title and both file names.
            plot_dir: Directory receiving the PNG files; created if missing.
        """
        import os

        # savefig does not create missing directories.
        os.makedirs(plot_dir, exist_ok=True)

        array_plus = array_plus - base_val
        array_minus = array_minus - base_val
        print(fig_label)

        # Type (1)
        fig, ax = plt.subplots()
        ax.scatter(array_minus, array_plus)
        lims = [
            np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
            np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
        ]
        ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_ylabel('L (theta + alpha * d) - L(theta)')
        ax.set_xlabel('L (theta - alpha * d) - L(theta)')
        ax.set_title('Loss vs Loss %s' % fig_label)
        fig.savefig(os.path.join(plot_dir, 'type1_' + fig_label + '.png'))
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

        # Type (2)
        spectra_fig = plt.figure(figsize=(5, 4))
        plt.subplot(211)
        grad_projections = (array_plus - array_minus) * 0.5
        plt.hist(grad_projections, bins=50)
        plt.xlabel('Gradient Value')
        plt.ylabel('Count')

        plt.subplot(212)
        curvature_projections = (array_plus + array_minus) * 0.5
        plt.hist(curvature_projections, bins=50)
        plt.xlabel('Curvature Value')
        plt.ylabel('Count')
        plt.tight_layout()
        spectra_fig.savefig(os.path.join(plot_dir, 'spectra_joined_' + fig_label + '.png'))
        plt.close(spectra_fig)

    def _visualize(self, policy=False, q_function=False, num_dir=50, alpha=0.1, iter=None):
        """Probe Q-estimates and returns along random policy-weight directions.

        For each of ``num_dir`` Gaussian directions ``d`` the policy clone is
        set to ``theta + alpha * d`` and ``theta - alpha * d``; ``qf1``,
        ``qf2``, their minimum and the rollout return are recorded for both
        and passed to ``plot_visualized_data`` together with the unperturbed
        values.

        Args:
            policy: Flag; at least one of ``policy`` / ``q_function`` must be
                truthy. Only policy weights are perturbed here.
            q_function: See ``policy``.
            num_dir: Number of random directions to sample.
            alpha: Step size applied to each direction.
            iter: Value embedded in the figure labels.
        """
        assert policy or q_function, "Both are false, need something to visualize"
        policy_weights = get_flat_params(self.trainer.policy)
        policy_dim = policy_weights.shape[0]

        # Work on a copy so the trainer's policy is never modified.
        policy_clone = copy.deepcopy(self.trainer.policy)

        # Reference values at the unperturbed weights.
        policy_eval_qf1 = self._eval_q_custom_policy(self.trainer.policy, self.trainer.qf1)
        policy_eval_qf2 = self._eval_q_custom_policy(self.trainer.policy, self.trainer.qf2)
        policy_eval_q_min = min(policy_eval_qf1, policy_eval_qf2)
        policy_eval_returns = self.eval_policy_custom(self.trainer.policy)

        def probe(theta):
            """Load ``theta`` into the clone and measure (q1, q2, qmin, return)."""
            set_flat_params(policy_clone, theta)
            q_1 = self._eval_q_custom_policy(policy_clone, self.trainer.qf1)
            q_2 = self._eval_q_custom_policy(policy_clone, self.trainer.qf2)
            return q_1, q_2, min(q_1, q_2), self.eval_policy_custom(policy_clone)

        plus_records = []
        minus_records = []
        for _ in range(num_dir):
            # Cast to the weight dtype so the perturbed weights are not
            # silently promoted to float64.
            random_dir = np.random.normal(size=policy_dim).astype(policy_weights.dtype)
            # Step along the sampled direction. The previous code added the
            # scalar ``alpha * policy_dim`` and never used the direction.
            plus_records.append(probe(policy_weights + alpha * random_dir))
            minus_records.append(probe(policy_weights - alpha * random_dir))

        base_values = (policy_eval_qf1, policy_eval_qf2, policy_eval_q_min, policy_eval_returns)
        label_prefixes = ('q1', 'q2', 'qmin', 'returns')
        for column, (prefix, base_val) in enumerate(zip(label_prefixes, base_values)):
            self.plot_visualized_data(
                np.array([record[column] for record in plus_records]),
                np.array([record[column] for record in minus_records]),
                base_val,
                fig_label=prefix + '_policy_params_iter_' + str(iter),
            )


class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Epoch-based RL loop that can run online or purely from a fixed buffer."""

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            env_mean=0.0,
            env_std=1.0,
    ):
        """Store the loop hyper-parameters and build a spare path collector.

        The trainer, environments, collectors and replay buffer are forwarded
        unchanged to ``BaseRLAlgorithm``.

        Args:
            batch_size: Size requested from ``replay_buffer.random_batch``.
            max_path_length: Path length limit handed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for each evaluation pass.
            num_expl_steps_per_train_loop: Step budget for each exploration
                pass (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when not running in batch mode.
            q_learning_alg: Evaluate with ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: In batch mode, also roll out the Q-greedy policy
                function with the exploration collector in every train loop.
            batch_rl: Train from the existing replay buffer only.
            num_actions_sample: Number of candidate actions drawn by
                ``policy_fn``.
            exp_point: Stored as ``self.exp_point``; not read in this class.
            env_mean: Subtracted from the observation in ``policy_fn``.
            env_std: Divides the shifted observation in ``policy_fn``. The
                defaults (0.0, 1.0) leave observations unchanged.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        # policy_fn reads these two attributes; without them it raised
        # AttributeError on its first call.
        self.env_mean = env_mean
        self.env_std = env_std
        self.exp_point = exp_point

        # Separate collector so that ad-hoc policy evaluation does not go
        # through the regular evaluation collector.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Pick the best of several sampled actions according to ``qf1``.

        The observation is repeated ``num_actions_sample`` times, normalised
        with ``env_mean`` / ``env_std``, the policy proposes one action per
        copy, and the action with the largest ``qf1`` value is returned as a
        flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            state = (state - self.env_mean) / self.env_std
            # The policy returns an 8-tuple; only its first entry (the
            # action) is needed here.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the action with the highest ``qf1`` value."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Seed the replay buffer before the first gradient step.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = (
                    self.policy_fn_discrete if self.trainer.discrete
                    else self.policy_fn
                )
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Batch mode: additionally roll out the Q-greedy policy
                    # function; the paths stay inside the collector.
                    policy_fn = (
                        self.policy_fn_discrete if self.trainer.discrete
                        else self.policy_fn
                    )
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # This stamp repeats on every train loop of an epoch, so
                    # it must not demand a unique name (same as the stamps
                    # in the branch above).
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``q_function`` under ``custom_model`` on one random batch.

        A fresh batch of ``batch_size`` transitions is drawn from the replay
        buffer, converted to torch and passed to ``trainer.eval_q_custom``.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return the average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)


class BatchRLAlgorithmA(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Batch RL loop that periodically dumps Q-values for the whole buffer.

    Behaves like the plain epoch loop (evaluate, optionally collect, train)
    and additionally writes ``qf1`` estimates and target Q-values to
    ``path`` on every 20th epoch (epoch 0 excluded).
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            path=None,
            env_mean=0.0,
            env_std=1.0,
    ):
        """Store the loop hyper-parameters and build a spare path collector.

        The trainer, environments, collectors and replay buffer are forwarded
        unchanged to ``BaseRLAlgorithm``.

        Args:
            batch_size: Size requested from ``replay_buffer.random_batch``.
            max_path_length: Path length limit handed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for each evaluation pass.
            num_expl_steps_per_train_loop: Step budget for each exploration
                pass (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when not running in batch mode.
            q_learning_alg: Evaluate with ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: In batch mode, also roll out the Q-greedy policy
                function with the exploration collector in every train loop.
            batch_rl: Train from the existing replay buffer only.
            num_actions_sample: Number of candidate actions drawn by
                ``policy_fn``.
            exp_point: Stored as ``self.exp_point``; not read in this class.
            path: Directory for the periodic Q-value dumps. ``None`` disables
                the dumps.
            env_mean: Subtracted from the observation in ``policy_fn``.
            env_std: Divides the shifted observation in ``policy_fn``. The
                defaults (0.0, 1.0) leave observations unchanged.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        # policy_fn reads these two attributes; without them it raised
        # AttributeError on its first call.
        self.env_mean = env_mean
        self.env_std = env_std
        self.exp_point = exp_point
        self.path = path

        # Separate collector so that ad-hoc policy evaluation does not go
        # through the regular evaluation collector.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Pick the best of several sampled actions according to ``qf1``.

        The observation is repeated ``num_actions_sample`` times, normalised
        with ``env_mean`` / ``env_std``, the policy proposes one action per
        copy, and the action with the largest ``qf1`` value is returned as a
        flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            state = (state - self.env_mean) / self.env_std
            # The policy returns an 8-tuple; only its first entry (the
            # action) is needed here.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the action with the highest ``qf1`` value."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _dump_q_values(self, epoch):
        """Write ``qf1`` and target Q-values for the whole buffer to ``self.path``.

        Two files are produced: ``q_vals_<epoch>.npy`` with ``qf1(s, a)`` and
        ``tar_q_vals_<epoch>.npy`` with the element-wise minimum of both
        target critics, each maximised over 10 policy actions for the next
        observation. Nothing is written when ``self.path`` is ``None``.
        """
        if self.path is None:
            # No output directory configured; formatting None into the file
            # name would only produce a bogus 'None/...' path.
            return

        networks = (
            self.trainer.qf1,
            self.trainer.policy,
            self.trainer.target_qf1,
            self.trainer.target_qf2,
        )
        # The buffer tensors below are created on the CPU, so the networks
        # are moved there for the duration of the dump.
        for network in networks:
            network.cpu()
        try:
            s = torch.from_numpy(self.replay_buffer._observations).float()
            a = torch.from_numpy(self.replay_buffer._actions).float()
            ns = torch.from_numpy(self.replay_buffer._next_obs).float()

            with torch.no_grad():
                q_vals = ptu.get_numpy(self.trainer.qf1(s, a))

                next_actions_temp, _ = self.trainer._get_policy_actions(
                    ns, num_actions=10, network=self.trainer.policy)
                target_qf1_values = self.trainer._get_tensor_values(
                    ns, next_actions_temp, network=self.trainer.target_qf1,
                ).max(1)[0].view(-1, 1)
                target_qf2_values = self.trainer._get_tensor_values(
                    ns, next_actions_temp, network=self.trainer.target_qf2,
                ).max(1)[0].view(-1, 1)
                target_q_values = ptu.get_numpy(
                    torch.min(target_qf1_values, target_qf2_values))

            np.save(f'{self.path}/q_vals_{epoch}.npy', q_vals)
            np.save(f'{self.path}/tar_q_vals_{epoch}.npy', target_q_values)
        finally:
            # Always put the networks back on the training device, even if
            # the evaluation or the file write failed.
            for network in networks:
                network.to(ptu.device)

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, train, dump."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Seed the replay buffer before the first gradient step.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = (
                    self.policy_fn_discrete if self.trainer.discrete
                    else self.policy_fn
                )
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Batch mode: additionally roll out the Q-greedy policy
                    # function; the paths stay inside the collector.
                    policy_fn = (
                        self.policy_fn_discrete if self.trainer.discrete
                        else self.policy_fn
                    )
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # This stamp repeats on every train loop of an epoch, so
                    # it must not demand a unique name (same as the stamps
                    # in the branch above).
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)

                # Dump on every 20th epoch, skipping epoch 0.
                if epoch % 20 == 0 and epoch != 0:
                    self._dump_q_values(epoch)

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``q_function`` under ``custom_model`` on one random batch.

        A fresh batch of ``batch_size`` transitions is drawn from the replay
        buffer, converted to torch and passed to ``trainer.eval_q_custom``.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return the average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)


# bp
class BatchRLAlgorithm2(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop driving two trainers: ``pi_trainer`` and ``beta_trainer``.

    ``pi_trainer`` is trained on every step. When ``epoch % inner_step == 0``
    (and ``epoch != 0``) the replay-buffer sampling probabilities are
    recomputed, ``beta_trainer`` is trained for a fixed number of steps and
    its networks are copied into ``pi_trainer`` (see :meth:`_train`).

    The class also carries helpers that perturb the policy weights along
    random directions and plot the resulting Q-values / returns
    (:meth:`_visualize`, :meth:`plot_visualized_data`).
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            inner_step=20,
            use_vae=False,
    ):
        """Store the loop hyper-parameters and build the reserve collector.

        Args:
            pi_trainer, beta_trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base-class constructor.
            batch_size: size of every batch drawn from the replay buffer.
            max_path_length: path-length limit passed to the collectors.
            num_epochs: upper bound of the epoch range in ``_train``.
            num_eval_steps_per_epoch: steps requested from the evaluation
                collector each epoch.
            num_expl_steps_per_train_loop: steps requested from the
                exploration collector per train loop (online mode only).
            num_trains_per_train_loop: ``pi_trainer.train`` calls per loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected
                before the first epoch (skipped when ``batch_rl`` is set).
            q_learning_alg: if true, evaluation rolls out ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: in batch mode, additionally roll out the policy
                function through the exploration collector.
            batch_rl: if true, no new exploration data is added to the buffer.
            num_actions_sample: number of candidate actions in ``policy_fn``.
            inner_step: epoch period of the ``beta_trainer`` phase.
            use_vae: stored on the instance; not read inside this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.inner_step = inner_step
        self.use_vae = use_vae
        # Sampling probabilities passed to ``replay_buffer.random_batch``;
        # stays ``None`` until the first refresh in ``_train``.
        self.prob = None

        # Reserve path collector for evaluation, visualization
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """
        Used when sampling actions from the policy and doing max Q-learning.

        The observation is repeated ``num_actions_sample`` times, one action
        is drawn per copy, and the action with the largest ``qf1`` value is
        returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # Only the first of the policy's eight outputs (the action) is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max entry of ``qf1.q_vector(obs)`` as a flat numpy array."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the full training loop from ``_start_epoch`` to ``num_epochs``.

        Per epoch: collect evaluation paths, then for every train loop
        optionally collect exploration data, train ``pi_trainer`` and, on
        every ``inner_step``-th epoch, run the ``beta_trainer`` phase and copy
        its networks into ``pi_trainer``.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Online mode: seed the replay buffer before any training happens.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                # Evaluate the "sample actions, take the max-Q one" policy.
                policy_fn = self.policy_fn
                if self.pi_trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Now evaluate the policy here; the collected paths are
                    # not added to the replay buffer.
                    policy_fn = self.policy_fn
                    if self.pi_trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size, prob=self.prob)
                    self.pi_trainer.train(train_data)

                if epoch % self.inner_step == 0 and epoch != 0:
                    # Refresh the sampling probabilities, then train the beta
                    # trainer: the first 20000 steps use those probabilities,
                    # the remaining 20000 use the buffer's default sampling.
                    self.prob = self.pi_trainer.calculate_sampling_prob(self.replay_buffer)

                    for i in tqdm(range(40000), desc="Train inner loop"):
                        if i < 20000:
                            train_data = self.replay_buffer.random_batch(
                                self.batch_size, prob=self.prob)
                        else:
                            train_data = self.replay_buffer.random_batch(self.batch_size)
                        self.beta_trainer.train(train_data)

                    self.beta_trainer._n_train_steps_total = 0

                    # Copy the freshly trained beta networks into pi_trainer.
                    self.pi_trainer.behavior_policy.load_state_dict(self.beta_trainer.beta_prime_policy.state_dict())
                    if isinstance(self.beta_trainer.qf, list):
                        for i in range(self.beta_trainer.num_q):
                            self.pi_trainer.qf_prime[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
                    else:
                        self.pi_trainer.qf_prime.load_state_dict(self.beta_trainer.qf.state_dict())

                self.training_mode(False)
            self._end_epoch(epoch)

            # Optional: call ``self._visualize(policy=True, ..., iter=epoch)``
            # here every N epochs to dump the perturbation plots.

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and return its average return."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns

    def plot_visualized_data(self, array_plus, array_minus, base_val, fig_label='None',
                             save_dir='plots_hopper_correct_online_3e-4_n10_viz_sac_again'):
        """Plot two kinds of visualizations here:
           (1) Trend of loss_minus with respect to loss_plus
           (2) Histogram of different gradient directions

        Args:
            array_plus: values measured at ``theta + alpha * d``.
            array_minus: values measured at ``theta - alpha * d``.
            base_val: value at ``theta``; subtracted from both arrays.
            fig_label: suffix used in the title and in both file names.
            save_dir: directory receiving ``type1_<label>.png`` and
                ``spectra_joined_<label>.png``; created if missing.
        """
        import os

        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

        # Type (1)
        array_plus = array_plus - base_val
        array_minus = array_minus - base_val
        print(fig_label)
        fig, ax = plt.subplots()
        ax.scatter(array_minus, array_plus)
        lims = [
            np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
            np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
        ]
        # Diagonal reference line: points on it have equal plus/minus change.
        ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        plt.ylabel('L (theta + alpha * d) - L(theta)')
        plt.xlabel('L (theta - alpha * d) - L(theta)')
        plt.title('Loss vs Loss %s' % fig_label)
        plt.savefig(os.path.join(save_dir, 'type1_' + fig_label + '.png'))
        # Close explicitly: pyplot keeps figures alive until closed, and this
        # method may be called many times during training.
        plt.close(fig)

        # Type (2)
        hist_fig = plt.figure(figsize=(5, 4))
        plt.subplot(211)
        grad_projections = (array_plus - array_minus) * 0.5
        plt.hist(grad_projections, bins=50)
        plt.xlabel('Gradient Value')
        plt.ylabel('Count')
        plt.subplot(212)

        # Curvature
        curvature_projections = (array_plus + array_minus) * 0.5
        plt.hist(curvature_projections, bins=50)
        plt.xlabel('Curvature Value')
        plt.ylabel('Count')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'spectra_joined_' + fig_label + '.png'))
        plt.close(hist_fig)

    def _visualize(self, policy=False, q_function=False, num_dir=50, alpha=0.1, iter=None):
        """Probe the policy along random weight directions and plot the results.

        For each of ``num_dir`` random directions ``d`` the policy weights are
        moved to ``theta + alpha * d`` and ``theta - alpha * d``; at both
        points the Q-values (``qf1``, ``qf2``, their minimum) and the average
        rollout return are recorded, then plotted against the unperturbed
        values via :meth:`plot_visualized_data`.

        Args:
            policy, q_function: at least one must be truthy; only the policy
                is perturbed by this implementation.
            num_dir: number of random directions.
            alpha: step size along each direction.
            iter: label appended to the figure names.
        """
        assert policy or q_function, "Both are false, need something to visualize"
        policy_weights = get_flat_params(self.pi_trainer.policy)

        policy_dim = policy_weights.shape[0]

        # Create clones to assign weights
        policy_clone = copy.deepcopy(self.pi_trainer.policy)

        # Create arrays for storing data
        q1_plus_eval = []
        q1_minus_eval = []
        q2_plus_eval = []
        q2_minus_eval = []
        qmin_plus_eval = []
        qmin_minus_eval = []
        returns_plus_eval = []
        returns_minus_eval = []

        # Groundtruth policy params
        policy_eval_qf1 = self._eval_q_custom_policy(self.pi_trainer.policy, self.pi_trainer.qf1)
        policy_eval_qf2 = self._eval_q_custom_policy(self.pi_trainer.policy, self.pi_trainer.qf2)
        policy_eval_q_min = min(policy_eval_qf1, policy_eval_qf2)
        policy_eval_returns = self.eval_policy_custom(self.pi_trainer.policy)

        # These are the policy saddle point detection
        for _ in range(num_dir):
            # Cast to the weights' dtype so the perturbed vector (and the
            # tensors built from it) keeps the policy's parameter dtype.
            random_dir = np.random.normal(size=(policy_dim)).astype(policy_weights.dtype)
            # Step along the sampled direction (not by a constant offset).
            theta_plus = policy_weights + alpha * random_dir
            theta_minus = policy_weights - alpha * random_dir

            set_flat_params(policy_clone, theta_plus)
            q_plus_1 = self._eval_q_custom_policy(policy_clone, self.pi_trainer.qf1)
            q_plus_2 = self._eval_q_custom_policy(policy_clone, self.pi_trainer.qf2)
            q_plus_min = min(q_plus_1, q_plus_2)
            eval_return_plus = self.eval_policy_custom(policy_clone)

            set_flat_params(policy_clone, theta_minus)
            q_minus_1 = self._eval_q_custom_policy(policy_clone, self.pi_trainer.qf1)
            q_minus_2 = self._eval_q_custom_policy(policy_clone, self.pi_trainer.qf2)
            q_minus_min = min(q_minus_1, q_minus_2)
            eval_return_minus = self.eval_policy_custom(policy_clone)

            # Append to array
            q1_plus_eval.append(q_plus_1)
            q2_plus_eval.append(q_plus_2)
            q1_minus_eval.append(q_minus_1)
            q2_minus_eval.append(q_minus_2)
            qmin_plus_eval.append(q_plus_min)
            qmin_minus_eval.append(q_minus_min)
            returns_plus_eval.append(eval_return_plus)
            returns_minus_eval.append(eval_return_minus)

        # Now we visualize
        q1_plus_eval = np.array(q1_plus_eval)
        q1_minus_eval = np.array(q1_minus_eval)
        q2_plus_eval = np.array(q2_plus_eval)
        q2_minus_eval = np.array(q2_minus_eval)
        qmin_plus_eval = np.array(qmin_plus_eval)
        qmin_minus_eval = np.array(qmin_minus_eval)
        returns_plus_eval = np.array(returns_plus_eval)
        returns_minus_eval = np.array(returns_minus_eval)

        self.plot_visualized_data(q1_plus_eval, q1_minus_eval, policy_eval_qf1,
                                  fig_label='q1_policy_params_iter_' + (str(iter)))
        self.plot_visualized_data(q2_plus_eval, q2_minus_eval, policy_eval_qf2,
                                  fig_label='q2_policy_params_iter_' + (str(iter)))
        self.plot_visualized_data(qmin_plus_eval, qmin_minus_eval, policy_eval_q_min,
                                  fig_label='qmin_policy_params_iter_' + (str(iter)))
        self.plot_visualized_data(returns_plus_eval, returns_minus_eval, policy_eval_returns,
                                  fig_label='returns_policy_params_iter_' + (str(iter)))

        del policy_clone


# bc_cls
class BatchRLAlgorithm3(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Training loop that passes the sampled replay indices to the trainer.

    Each training step draws a batch together with its buffer indices, stores
    the indices under ``train_data['index_set']`` and calls
    ``trainer.train(train_data, replay_buffer)``.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
    ):
        """Store the loop hyper-parameters and build the reserve collector.

        Args:
            trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base-class constructor.
            batch_size: size of every batch drawn from the replay buffer.
            max_path_length: path-length limit passed to the collectors.
            num_epochs: upper bound of the epoch range in ``_train``.
            num_eval_steps_per_epoch: steps requested from the evaluation
                collector each epoch.
            num_expl_steps_per_train_loop: steps requested from the
                exploration collector per train loop (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected
                before the first epoch (skipped when ``batch_rl`` is set).
            q_learning_alg: if true, evaluation rolls out ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: in batch mode, additionally roll out the policy
                function through the exploration collector.
            batch_rl: if true, no new exploration data is added to the buffer.
            num_actions_sample: number of candidate actions in ``policy_fn``.
            exp_point: stored on the instance; not read inside this class.
            cluster_idx_list: stored on the instance; not read inside this
                class.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.cluster_idx_list = cluster_idx_list
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.exp_point = exp_point
        # Reserve path collector for evaluation, visualization
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """
        Used when sampling actions from the policy and doing max Q-learning.

        The observation is repeated ``num_actions_sample`` times, normalised,
        one action is drawn per copy, and the action with the largest ``qf1``
        value is returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # NOTE(review): env_mean / env_std are not assigned anywhere in
            # this class; presumably set by the base class or the caller
            # before this method is used -- confirm.
            state = (state - self.env_mean) / self.env_std
            # Only the first of the policy's eight outputs (the action) is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max entry of ``qf1.q_vector(obs)`` as a flat numpy array."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the full training loop from ``_start_epoch`` to ``num_epochs``.

        Per epoch: collect evaluation paths, then for every train loop
        optionally collect exploration data and run
        ``num_trains_per_train_loop`` trainer updates.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Online mode: seed the replay buffer before any training happens.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                # Evaluate the "sample actions, take the max-Q one" policy.
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Now evaluate the policy here; the collected paths are
                    # not added to the replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    # The trainer receives both the batch and the buffer
                    # positions it was drawn from.
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)
                    train_data['index_set'] = index_set
                    self.trainer.train(train_data, self.replay_buffer)
                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and return its average return."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns


# cls
class BatchRLAlgorithm4(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Training loop that weights each sample by its Q-value within its cluster.

    Each training step reads the stored per-sample values
    ``replay_buffer._q_curr``, computes
    ``soft_weight = exp(q_curr - logsumexp(q_curr over the sample's cluster))``
    using ``cluster_idx_list`` to look up cluster members, trains on the
    augmented batch and writes the trainer's returned values back into
    ``replay_buffer._q_curr``.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
    ):
        """Store the loop hyper-parameters and build the reserve collector.

        Args:
            trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base-class constructor.
            batch_size: size of every batch drawn from the replay buffer.
            max_path_length: path-length limit passed to the collectors.
            num_epochs: upper bound of the epoch range in ``_train``.
            num_eval_steps_per_epoch: steps requested from the evaluation
                collector each epoch.
            num_expl_steps_per_train_loop: steps requested from the
                exploration collector per train loop (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected
                before the first epoch (skipped when ``batch_rl`` is set).
            q_learning_alg: if true, evaluation rolls out ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: in batch mode, additionally roll out the policy
                function through the exploration collector.
            batch_rl: if true, no new exploration data is added to the buffer.
            num_actions_sample: number of candidate actions in ``policy_fn``.
            exp_point: stored on the instance; not read inside this class.
            cluster_idx_list: indexed by a buffer index, yields the indices
                used to select that sample's cluster from ``_q_curr``.
            cluster_dist_list: stored on the instance; not read inside this
                class.
            obs_mean, obs_std: observation normalisation statistics; used by
                ``policy_fn`` when ``env_mean`` / ``env_std`` are not set.
            dist_temp: stored on the instance; not read inside this class.
            temp: stored on the instance; the division by it in ``_train``
                is currently disabled.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.exp_point = exp_point
        # Reserve path collector for evaluation, visualization
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """
        Used when sampling actions from the policy and doing max Q-learning.

        The observation is repeated ``num_actions_sample`` times, normalised,
        one action is drawn per copy, and the action with the largest ``qf1``
        value is returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # env_mean / env_std are never assigned in this class, so fall
            # back to the obs_mean / obs_std given to the constructor when
            # nobody has set them.
            mean = getattr(self, 'env_mean', self.obs_mean)
            std = getattr(self, 'env_std', self.obs_std)
            # as_tensor accepts floats, arrays and tensors alike and puts the
            # statistics on the same device / dtype as the observations.
            mean = torch.as_tensor(mean, dtype=state.dtype, device=state.device)
            std = torch.as_tensor(std, dtype=state.dtype, device=state.device)
            state = (state - mean) / std
            # Only the first of the policy's eight outputs (the action) is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max entry of ``qf1.q_vector(obs)`` as a flat numpy array."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the full training loop from ``_start_epoch`` to ``num_epochs``.

        Per epoch: collect evaluation paths, then for every train loop
        optionally collect exploration data and run
        ``num_trains_per_train_loop`` cluster-weighted trainer updates.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Online mode: seed the replay buffer before any training happens.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                # Evaluate the "sample actions, take the max-Q one" policy.
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Now evaluate the policy here; the collected paths are
                    # not added to the replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    train_data['q_curr'] = self.replay_buffer._q_curr[index_set]
                    # One log-sum-exp per sampled index, taken over the stored
                    # values of that sample's cluster (B x N x 1 -> B x 1).
                    # Subtracting np.log(n) as well would make the weights
                    # average to 1; without it they sum to 1.
                    train_data['logsum_batch'] = np.array([
                        special.logsumexp(
                            self.replay_buffer._q_curr[self.cluster_idx_list[idx]], axis=0)
                        for idx in index_set
                    ])

                    # Division by self.temp is intentionally disabled here.
                    train_data['diff'] = (train_data['q_curr'] - train_data['logsum_batch'])
                    train_data['soft_weight'] = np.exp(train_data['diff'])

                    q_curr = self.trainer.train(train_data)

                    # Store the trainer's returned values for the sampled rows.
                    self.replay_buffer._q_curr[index_set] = q_curr

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and return its average return."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns


class BatchRLAlgorithm4b(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Cluster-weighted training loop with a temperature and cluster sizes.

    Each training step reads the stored per-sample values
    ``replay_buffer._q_curr``, computes
    ``soft_weight = exp((q_curr - logsumexp(q_curr over the cluster)) / temp)``
    using ``cluster_idx_list`` to look up cluster members, also passes the
    cluster sizes as ``cls_len``, trains on the augmented batch and writes the
    trainer's returned values back into ``replay_buffer._q_curr``.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
    ):
        """Store the loop hyper-parameters and build the reserve collector.

        Args:
            trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base-class constructor.
            batch_size: size of every batch drawn from the replay buffer.
            max_path_length: path-length limit passed to the collectors.
            num_epochs: upper bound of the epoch range in ``_train``.
            num_eval_steps_per_epoch: steps requested from the evaluation
                collector each epoch.
            num_expl_steps_per_train_loop: steps requested from the
                exploration collector per train loop (online mode only).
            num_trains_per_train_loop: ``trainer.train`` calls per loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected
                before the first epoch (skipped when ``batch_rl`` is set).
            q_learning_alg: if true, evaluation rolls out ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: in batch mode, additionally roll out the policy
                function through the exploration collector.
            batch_rl: if true, no new exploration data is added to the buffer.
            num_actions_sample: number of candidate actions in ``policy_fn``.
            exp_point: stored on the instance; not read inside this class.
            cluster_idx_list: indexed by a buffer index, yields the indices
                used to select that sample's cluster from ``_q_curr``.
            cluster_dist_list: stored on the instance; not read inside this
                class.
            obs_mean, obs_std: observation normalisation statistics; used by
                ``policy_fn`` when ``env_mean`` / ``env_std`` are not set.
            dist_temp: stored on the instance; not read inside this class.
            temp: divisor applied to ``q_curr - logsum_batch`` before the
                exponential in ``_train``.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.exp_point = exp_point
        # Reserve path collector for evaluation, visualization
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """
        Used when sampling actions from the policy and doing max Q-learning.

        The observation is repeated ``num_actions_sample`` times, normalised,
        one action is drawn per copy, and the action with the largest ``qf1``
        value is returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # env_mean / env_std are never assigned in this class, so fall
            # back to the obs_mean / obs_std given to the constructor when
            # nobody has set them.
            mean = getattr(self, 'env_mean', self.obs_mean)
            std = getattr(self, 'env_std', self.obs_std)
            # as_tensor accepts floats, arrays and tensors alike and puts the
            # statistics on the same device / dtype as the observations.
            mean = torch.as_tensor(mean, dtype=state.dtype, device=state.device)
            std = torch.as_tensor(std, dtype=state.dtype, device=state.device)
            state = (state - mean) / std
            # Only the first of the policy's eight outputs (the action) is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max entry of ``qf1.q_vector(obs)`` as a flat numpy array."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the full training loop from ``_start_epoch`` to ``num_epochs``.

        Per epoch: collect evaluation paths, then for every train loop
        optionally collect exploration data and run
        ``num_trains_per_train_loop`` cluster-weighted trainer updates.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            # Online mode: seed the replay buffer before any training happens.
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                # Evaluate the "sample actions, take the max-Q one" policy.
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Now evaluate the policy here; the collected paths are
                    # not added to the replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    train_data['q_curr'] = self.replay_buffer._q_curr[index_set]
                    # One log-sum-exp per sampled index, taken over the stored
                    # values of that sample's cluster (B x N x 1 -> B x 1).
                    # Subtracting np.log(n) as well would make the weights
                    # average to 1; without it they sum to 1.
                    train_data['logsum_batch'] = np.array([
                        special.logsumexp(
                            self.replay_buffer._q_curr[self.cluster_idx_list[idx]], axis=0)
                        for idx in index_set
                    ])
                    # Number of stored values in each sample's cluster, as a
                    # column vector.
                    train_data['cls_len'] = np.expand_dims(np.array(
                        [len(self.replay_buffer._q_curr[self.cluster_idx_list[idx]]) for idx in index_set]), 1)

                    train_data['diff'] = (train_data['q_curr'] - train_data['logsum_batch']) / self.temp
                    train_data['soft_weight'] = np.exp(train_data['diff'])

                    q_curr = self.trainer.train(train_data)

                    # Store the trainer's returned values for the sampled rows.
                    self.replay_buffer._q_curr[index_set] = q_curr

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and return its average return."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns


class BatchRLAlgorithm4bt(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Training loop that feeds cluster-wise soft weights to a single trainer.

    Every gradient step samples a batch together with its buffer indices,
    reads the stored per-sample values from ``replay_buffer._q_curr`` and adds
    the keys ``q_curr``, ``logsum_batch``, ``diff`` and ``soft_weight`` to the
    batch before handing it to ``trainer.train``.  ``cluster_idx_list`` must be
    indexable by buffer index for ``_train`` to work.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
    ):
        """Store the loop settings and build the reserve path collector.

        Args:
            trainer: Object providing ``train``, ``calculate_q_vals``,
                ``policy``, ``qf1`` and ``discrete``.
            batch_size: Number of samples requested per gradient step.
            max_path_length: Passed to every ``collect_new_paths`` call.
            num_epochs: Upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts (only used when ``batch_rl`` is false).
            num_trains_per_train_loop: Gradient steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when ``batch_rl`` is false.
            q_learning_alg: If true, evaluation rollouts use ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: With ``batch_rl`` true, additionally roll out the
                policy function through the exploration collector.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: Per-sample collection of buffer indices used to
                compute ``logsum_batch``.
            obs_mean, obs_std: Normalisation statistics used by ``policy_fn``
                when no ``env_mean`` / ``env_std`` attribute exists.
            exp_point, cluster_dist_list, dist_temp, temp: Stored on the
                instance; not read by any method of this class.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.exp_point = exp_point
        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is repeated ``num_actions_sample`` times, normalised,
        passed through the policy, and the candidate with the largest Q-value
        is returned as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The original code read ``self.env_mean`` / ``self.env_std``,
            # which this class never sets.  Prefer them if something else
            # defines them, otherwise use the statistics given to __init__.
            # NOTE(review): assumes obs_mean/obs_std are the intended
            # normalisation statistics - confirm against the caller.
            mean = torch.as_tensor(
                getattr(self, 'env_mean', self.obs_mean),
                dtype=state.dtype, device=state.device,
            )
            std = torch.as_tensor(
                getattr(self, 'env_std', self.obs_std),
                dtype=state.dtype, device=state.device,
            )
            state = (state - mean) / std
            # The policy output is unpacked as an 8-tuple; only the first
            # element (the action) is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max action of ``qf1.q_vector`` for ``obs``."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _active_policy_fn(self):
        """Pick the discrete or continuous policy function for rollouts."""
        if self.trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _collect_eval_paths(self):
        """Run the evaluation collector for one epoch."""
        if self.q_learning_alg:
            self.eval_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        else:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )

    def _collect_expl_paths(self):
        """Gather exploration data, or roll out the policy fn in batch mode."""
        if not self.batch_rl:
            # Sample new paths only if not doing batch rl.
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp('data storing', unique=False)
        elif self.eval_both:
            self.expl_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            # unique=False: this runs once per train loop, so the stamp name
            # can repeat within one epoch (same as the stamps above).
            gt.stamp('policy fn evaluation', unique=False)

    def _sample_weighted_batch(self):
        """Sample a batch and attach the cluster-wise soft-weight entries."""
        train_data, index_set = self.replay_buffer.random_batch(
            self.batch_size, get_idx=True)
        stored_q = self.replay_buffer._q_curr

        train_data['q_curr'] = stored_q[index_set]
        # One log-sum-exp per sampled index, taken over that sample's cluster
        # (original note: B x N x 1 -> B x 1).  Subtracting np.log(n) would
        # make the weights average to 1; without it they sum to 1.
        train_data['logsum_batch'] = np.array([
            special.logsumexp(stored_q[self.cluster_idx_list[idx]], axis=0)
            for idx in index_set
        ])
        # No temperature is applied here (self.temp is not used).
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        return train_data

    def _train(self):
        """Run the full epoch loop: evaluate, optionally explore, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._collect_eval_paths()
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                self._collect_expl_paths()

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    self.trainer.train(self._sample_weighted_batch())

                # Let the trainer recompute the stored values every 5th epoch.
                if epoch % 5 == 0:
                    self.trainer.calculate_q_vals(self.replay_buffer)

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay-buffer batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Update policy and then look at how the returns under this policy look like."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)


class BatchRLAlgorithm4sbp(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Two-trainer training loop with temperature-scaled cluster soft weights.

    Each gradient step trains ``pi_trainer`` and then ``beta_trainer`` on the
    same augmented batch.  The values stored in ``replay_buffer._q_curr`` are
    blended with the values returned by ``pi_trainer.train`` on steps selected
    by ``update_step``, and once per epoch the ``beta_trainer`` networks are
    copied into ``pi_trainer``.  ``cluster_idx_list`` must be indexable by
    buffer index for ``_train`` to work.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop settings and build the reserve path collector.

        Args:
            pi_trainer: Trainer whose ``train`` returns new per-sample values;
                also provides ``policy``, ``qf1``, ``discrete``,
                ``beta_prime_policy`` and ``qf_bp``.
            beta_trainer: Second trainer; provides ``beta_prime_policy``,
                ``qf`` and (when ``qf`` is a list) ``num_q``.
            batch_size: Number of samples requested per gradient step.
            max_path_length: Passed to every ``collect_new_paths`` call.
            num_epochs: Upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts (only used when ``batch_rl`` is false).
            num_trains_per_train_loop: Gradient steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when ``batch_rl`` is false.
            q_learning_alg: If true, evaluation rollouts use ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: With ``batch_rl`` true, additionally roll out the
                policy function through the exploration collector.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: Per-sample collection of buffer indices used for
                ``logsum_batch`` and ``cls_len``.
            temp: Divisor applied to the stored values before the soft weights
                are computed.
            update_step: The stored values are refreshed when the step counter
                is a multiple of this value (never during epoch 0).
            q_tau: Blend factor between returned and stored values.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp:
                Stored on the instance; not read by any method of this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.update_step = update_step
        self.q_tau = q_tau
        self.exp_point = exp_point
        self.prob = None

        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is repeated ``num_actions_sample`` times and passed to
        the policy as-is (``obs_mean`` / ``obs_std`` are not applied here).
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy output is unpacked as an 8-tuple; only the first
            # element (the action) is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max action of ``qf1.q_vector`` for ``obs``."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _active_policy_fn(self):
        """Pick the discrete or continuous policy function for rollouts."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _collect_eval_paths(self):
        """Run the evaluation collector for one epoch."""
        if self.q_learning_alg:
            self.eval_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        else:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )

    def _collect_expl_paths(self):
        """Gather exploration data, or roll out the policy fn in batch mode."""
        if not self.batch_rl:
            # Sample new paths only if not doing batch rl.
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp('data storing', unique=False)
        elif self.eval_both:
            self.expl_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            # unique=False: this runs once per train loop, so the stamp name
            # can repeat within one epoch (same as the stamps above).
            gt.stamp('policy fn evaluation', unique=False)

    def _sample_weighted_batch(self):
        """Sample a batch, attach the soft-weight entries, return it with its indices."""
        train_data, index_set = self.replay_buffer.random_batch(
            self.batch_size, get_idx=True)
        stored_q = self.replay_buffer._q_curr

        train_data['q_curr'] = stored_q[index_set] / self.temp
        # One log-sum-exp per sampled index, taken over that sample's cluster
        # (original note: B x N x 1 -> B x 1).  Subtracting np.log(n) would
        # make the weights average to 1; without it they sum to 1.
        train_data['logsum_batch'] = np.array([
            special.logsumexp(stored_q[self.cluster_idx_list[idx]] / self.temp, axis=0)
            for idx in index_set
        ])
        # Original tuning note: use a strong temperature for HalfCheetah and
        # a milder one for Hopper.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        train_data['cls_len'] = np.expand_dims(
            np.array([len(self.cluster_idx_list[idx]) for idx in index_set]), 1)
        return train_data, index_set

    def _sync_beta_into_pi(self):
        """Copy the beta trainer's policy and Q-network weights into the pi trainer."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the full epoch loop: evaluate, optionally explore, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._collect_eval_paths()
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                self._collect_expl_paths()

                self.training_mode(True)
                for inner_step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self._sample_weighted_batch()

                    q_curr = self.pi_trainer.train(train_data)
                    self.beta_trainer.train(train_data)

                    # Refresh the stored values only on selected steps and
                    # never during epoch 0.
                    step = epoch * self.num_trains_per_train_loop + inner_step
                    if step % self.update_step == 0 and epoch != 0:
                        stored_q = self.replay_buffer._q_curr
                        stored_q[index_set] = (
                            self.q_tau * q_curr
                            + (1 - self.q_tau) * stored_q[index_set]
                        )

                self.training_mode(False)

            # Once per epoch, hand the beta trainer's networks to the pi trainer.
            self._sync_beta_into_pi()

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay-buffer batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Update policy and then look at how the returns under this policy look like."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

class BatchRLAlgorithm4sbpq(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Two-trainer training loop that refreshes the stored values every step.

    Each gradient step trains ``pi_trainer``, blends the values it returns
    into ``replay_buffer._q_curr`` using ``q_tau``, and then trains
    ``beta_trainer`` on the same batch.  On steps selected by ``update_step``
    the ``beta_trainer`` networks are copied into ``pi_trainer``.
    ``cluster_idx_list`` must be indexable by buffer index for ``_train`` to
    work.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop settings and build the reserve path collector.

        Args:
            pi_trainer: Trainer whose ``train`` returns new per-sample values;
                also provides ``policy``, ``qf1``, ``discrete``,
                ``beta_prime_policy`` and ``qf_bp``.
            beta_trainer: Second trainer; provides ``beta_prime_policy``,
                ``qf`` and (when ``qf`` is a list) ``num_q``.
            batch_size: Number of samples requested per gradient step.
            max_path_length: Passed to every ``collect_new_paths`` call.
            num_epochs: Upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts (only used when ``batch_rl`` is false).
            num_trains_per_train_loop: Gradient steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when ``batch_rl`` is false.
            q_learning_alg: If true, evaluation rollouts use ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: With ``batch_rl`` true, additionally roll out the
                policy function through the exploration collector.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: Per-sample collection of buffer indices used for
                ``logsum_batch`` and ``cls_len``.
            temp: Divisor applied to the stored values before the soft weights
                are computed; printed once at construction.
            update_step: The beta-to-pi weight copy happens when the step
                counter is a multiple of this value (never during epoch 0).
            q_tau: Blend factor between returned and stored values.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp:
                Stored on the instance; not read by any method of this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        print("self.temp: ", temp)
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None
        self.q_tau = q_tau

        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is repeated ``num_actions_sample`` times and passed to
        the policy as-is (``obs_mean`` / ``obs_std`` are not applied here).
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy output is unpacked as an 8-tuple; only the first
            # element (the action) is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max action of ``qf1.q_vector`` for ``obs``."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _active_policy_fn(self):
        """Pick the discrete or continuous policy function for rollouts."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _collect_eval_paths(self):
        """Run the evaluation collector for one epoch."""
        if self.q_learning_alg:
            self.eval_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        else:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )

    def _collect_expl_paths(self):
        """Gather exploration data, or roll out the policy fn in batch mode."""
        if not self.batch_rl:
            # Sample new paths only if not doing batch rl.
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp('data storing', unique=False)
        elif self.eval_both:
            self.expl_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            # unique=False: this runs once per train loop, so the stamp name
            # can repeat within one epoch (same as the stamps above).
            gt.stamp('policy fn evaluation', unique=False)

    def _sample_weighted_batch(self):
        """Sample a batch, attach the soft-weight entries, return it with its indices."""
        train_data, index_set = self.replay_buffer.random_batch(
            self.batch_size, get_idx=True)
        stored_q = self.replay_buffer._q_curr

        train_data['q_curr'] = stored_q[index_set] / self.temp
        # One log-sum-exp per sampled index, taken over that sample's cluster
        # (original note: B x N x 1 -> B x 1).  Subtracting a log(n) term
        # would make the weights average to 1; without it they sum to 1.
        train_data['logsum_batch'] = np.array([
            special.logsumexp(stored_q[self.cluster_idx_list[idx]] / self.temp, axis=0)
            for idx in index_set
        ])
        # Original tuning note: use a strong temperature for HalfCheetah and
        # a milder one for Hopper.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        train_data['cls_len'] = np.expand_dims(
            np.array([len(self.cluster_idx_list[idx]) for idx in index_set]), 1)
        return train_data, index_set

    def _sync_beta_into_pi(self):
        """Copy the beta trainer's policy and Q-network weights into the pi trainer."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the full epoch loop: evaluate, optionally explore, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._collect_eval_paths()
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                self._collect_expl_paths()

                self.training_mode(True)
                for inner_step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self._sample_weighted_batch()

                    # Blend the returned values into the stored ones on
                    # every step, before the beta trainer runs.
                    q_curr = self.pi_trainer.train(train_data)
                    stored_q = self.replay_buffer._q_curr
                    stored_q[index_set] = (
                        self.q_tau * q_curr
                        + (1 - self.q_tau) * stored_q[index_set]
                    )

                    self.beta_trainer.train(train_data)

                    # Copy beta -> pi only on selected steps and never
                    # during epoch 0.
                    step = epoch * self.num_trains_per_train_loop + inner_step
                    if step % self.update_step == 0 and epoch != 0:
                        self._sync_beta_into_pi()

                self.training_mode(False)
            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay-buffer batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Update policy and then look at how the returns under this policy look like."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

class BatchRLAlgorithm4sbpt(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Two-trainer training loop that overwrites the stored values every step.

    Each gradient step trains ``pi_trainer``, writes the values it returns
    straight into ``replay_buffer._q_curr`` (no blending), and then trains
    ``beta_trainer`` on the same batch.  On steps selected by ``update_step``
    the ``beta_trainer`` networks are copied into ``pi_trainer``.
    ``cluster_idx_list`` must be indexable by buffer index for ``_train`` to
    work.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop settings and build the reserve path collector.

        Args:
            pi_trainer: Trainer whose ``train`` returns new per-sample values;
                also provides ``policy``, ``qf1``, ``discrete``,
                ``beta_prime_policy`` and ``qf_bp``.
            beta_trainer: Second trainer; provides ``beta_prime_policy``,
                ``qf`` and (when ``qf`` is a list) ``num_q``.
            batch_size: Number of samples requested per gradient step.
            max_path_length: Passed to every ``collect_new_paths`` call.
            num_epochs: Upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts (only used when ``batch_rl`` is false).
            num_trains_per_train_loop: Gradient steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch when ``batch_rl`` is false.
            q_learning_alg: If true, evaluation rollouts use ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: With ``batch_rl`` true, additionally roll out the
                policy function through the exploration collector.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: Per-sample collection of buffer indices used for
                ``logsum_batch`` and ``cls_len``.
            temp: Divisor applied to the stored values before the soft weights
                are computed; printed once at construction.
            update_step: The beta-to-pi weight copy happens when the step
                counter is a multiple of this value (never during epoch 0).
            q_tau: Stored on the instance; ``_train`` in this class overwrites
                the stored values directly and does not read it.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp:
                Stored on the instance; not read by any method of this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        print("self.temp: ", temp)
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None
        self.q_tau = q_tau

        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is repeated ``num_actions_sample`` times and passed to
        the policy as-is (``obs_mean`` / ``obs_std`` are not applied here).
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy output is unpacked as an 8-tuple; only the first
            # element (the action) is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max action of ``qf1.q_vector`` for ``obs``."""
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _active_policy_fn(self):
        """Pick the discrete or continuous policy function for rollouts."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _collect_eval_paths(self):
        """Run the evaluation collector for one epoch."""
        if self.q_learning_alg:
            self.eval_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
        else:
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )

    def _collect_expl_paths(self):
        """Gather exploration data, or roll out the policy fn in batch mode."""
        if not self.batch_rl:
            # Sample new paths only if not doing batch rl.
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp('data storing', unique=False)
        elif self.eval_both:
            self.expl_data_collector.collect_new_paths(
                self._active_policy_fn(),
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            # unique=False: this runs once per train loop, so the stamp name
            # can repeat within one epoch (same as the stamps above).
            gt.stamp('policy fn evaluation', unique=False)

    def _sample_weighted_batch(self):
        """Sample a batch, attach the soft-weight entries, return it with its indices."""
        train_data, index_set = self.replay_buffer.random_batch(
            self.batch_size, get_idx=True)
        stored_q = self.replay_buffer._q_curr

        train_data['q_curr'] = stored_q[index_set] / self.temp
        # One log-sum-exp per sampled index, taken over that sample's cluster
        # (original note: B x N x 1 -> B x 1).  Subtracting a log(n) term
        # would make the weights average to 1; without it they sum to 1.
        train_data['logsum_batch'] = np.array([
            special.logsumexp(stored_q[self.cluster_idx_list[idx]] / self.temp, axis=0)
            for idx in index_set
        ])
        # Original tuning notes: use a strong temperature for HalfCheetah and
        # a milder one for Hopper; a log(cls_len) correction is optional and
        # is not applied here.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        # Cluster size is measured on the gathered stored values.
        train_data['cls_len'] = np.expand_dims(
            np.array([len(stored_q[self.cluster_idx_list[idx]]) for idx in index_set]), 1)
        return train_data, index_set

    def _sync_beta_into_pi(self):
        """Copy the beta trainer's policy and Q-network weights into the pi trainer."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the full epoch loop: evaluate, optionally explore, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._collect_eval_paths()
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                self._collect_expl_paths()

                self.training_mode(True)
                for inner_step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self._sample_weighted_batch()

                    # The returned values replace the stored ones outright
                    # (q_tau blending is not used in this variant).
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    self.beta_trainer.train(train_data)

                    # Copy beta -> pi only on selected steps and never
                    # during epoch 0.
                    step = epoch * self.num_trains_per_train_loop + inner_step
                    if step % self.update_step == 0 and epoch != 0:
                        self._sync_beta_into_pi()

                self.training_mode(False)
            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay-buffer batch."""
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Update policy and then look at how the returns under this policy look like."""
        self._reserve_path_collector.update_policy(policy)

        # Sampling
        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

class BatchRLAlgorithm4sbptt(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop that drives a policy trainer and a behavior (beta) trainer.

    On every gradient step the sampled batch is augmented with weights derived
    from the Q-values cached in ``replay_buffer._q_curr``, ``pi_trainer`` is
    trained on it, and the Q-values it returns are written back into that
    cache.  In this variant ``beta_trainer`` is trained only on the periodic
    update steps (every ``update_step`` gradient steps, skipped while
    ``epoch == 0``), immediately before its weights are copied into
    ``pi_trainer``.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            inner_step=1,
    ):
        """Store the loop configuration and build the reserve path collector.

        Args:
            pi_trainer, beta_trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base class.
            batch_size: number of transitions sampled per gradient step.
            max_path_length: path length limit passed to the collectors.
            num_epochs: upper bound (exclusive) of the epoch range.
            num_eval_steps_per_epoch: evaluation steps collected per epoch.
            num_expl_steps_per_train_loop: exploration steps collected per
                train loop when ``batch_rl`` is false.
            num_trains_per_train_loop: gradient steps per train loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected once
                before training when positive and ``batch_rl`` is false.
            q_learning_alg: when true, evaluation paths are collected with
                ``policy_fn`` / ``policy_fn_discrete``.
            eval_both: when true (and ``batch_rl`` is true), the exploration
                collector is also run with the policy function.
            batch_rl: when true, no new exploration data is added to the
                replay buffer.
            num_actions_sample: number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: indexed by replay-buffer index; each entry is
                used to index ``replay_buffer._q_curr`` (presumably the
                members of that transition's cluster -- confirm with caller).
            temp: divisor applied to cached Q-values before weighting.
            update_step: period, in gradient steps, of the behavior update.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp,
            inner_step: stored on the instance; not read in this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.update_step = update_step
        self.inner_step = inner_step
        self.exp_point = exp_point
        self.prob = None

        # Reserve path collector for evaluation and visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is flattened to one row and repeated
        ``num_actions_sample`` times, the policy proposes one action per row,
        and the row whose ``pi_trainer.qf1`` value is largest is returned as a
        flat numpy array.  Observations are used as given; no normalization
        is applied here.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy returns an 8-tuple; only the first element is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the largest entry of ``qf1.q_vector(obs)``.

        The result is a flat numpy array holding the arg-max over dimension 1
        of the Q-vector computed for the single flattened observation.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _pick_policy_fn(self):
        """Return the discrete or continuous policy function for ``pi_trainer``."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _attach_soft_weights(self, train_data, index_set):
        """Add the cached-Q weighting inputs to ``train_data`` in place.

        Writes the keys ``q_curr``, ``logsum_batch``, ``diff``,
        ``soft_weight`` and ``cls_len``.
        """
        q_cache = self.replay_buffer._q_curr
        # One lookup per sampled transition, shared by the log-sum-exp and
        # the cluster-size computations below.
        cluster_qs = [q_cache[self.cluster_idx_list[idx]] for idx in index_set]

        train_data['q_curr'] = q_cache[index_set] / self.temp
        # Original author's notes: shapes go B x N x 1 -> B x 1; subtracting
        # np.log(n) here would make the weights average to 1, whereas without
        # it they sum to 1.
        train_data['logsum_batch'] = np.array(
            [special.logsumexp(qs / self.temp, axis=0) for qs in cluster_qs])

        # Original author's tuning note: use a stronger temp for HalfCheetah
        # and a milder one for Hopper.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        train_data['cls_len'] = np.expand_dims(
            np.array([len(qs) for qs in cluster_qs]), 1)

    def _sync_behavior_networks(self):
        """Copy ``beta_trainer``'s policy and Q weights into ``pi_trainer``."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                self.eval_data_collector.collect_new_paths(
                    self._pick_policy_fn(),
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Evaluate the policy function with the exploration
                    # collector; the returned paths are not stored.
                    self.expl_data_collector.collect_new_paths(
                        self._pick_policy_fn(),
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # unique=False like the sibling stamps in this loop, so a
                    # second train loop in the same epoch does not trip
                    # gtimer's duplicate-name check.
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)
                    self._attach_soft_weights(train_data, index_set)

                    # The Q-values returned by the policy trainer replace the
                    # cached ones for the sampled transitions.
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    if (epoch * self.num_trains_per_train_loop + step) % self.update_step == 0 and epoch != 0:
                        # In this variant the behavior trainer is trained
                        # only on update steps, then mirrored into pi_trainer.
                        self.beta_trainer.train(train_data)
                        self._sync_behavior_networks()

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = np_to_pytorch_batch(
            self.replay_buffer.random_batch(self.batch_size))
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return its average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)


class BatchRLAlgorithm4sbp2(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop that drives a policy trainer and a behavior (beta) trainer.

    On every gradient step ``pi_trainer`` is trained on a replay batch, the
    Q-values it returns are written into ``replay_buffer._q_curr``, and
    ``beta_trainer`` is trained on the same batch.  Every ``update_step``
    gradient steps (skipped while ``epoch == 0``) the policy trainer's soft
    weights are refreshed via ``pi_trainer.update_soft_weight`` and the
    behavior trainer's weights are copied into ``pi_trainer``.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
    ):
        """Store the loop configuration and build the reserve path collector.

        Args:
            pi_trainer, beta_trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base class.
            batch_size: number of transitions sampled per gradient step.
            max_path_length: path length limit passed to the collectors.
            num_epochs: upper bound (exclusive) of the epoch range.
            num_eval_steps_per_epoch: evaluation steps collected per epoch.
            num_expl_steps_per_train_loop: exploration steps collected per
                train loop when ``batch_rl`` is false.
            num_trains_per_train_loop: gradient steps per train loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected once
                before training when positive and ``batch_rl`` is false.
            q_learning_alg: when true, evaluation paths are collected with
                ``policy_fn`` / ``policy_fn_discrete``.
            eval_both: when true (and ``batch_rl`` is true), the exploration
                collector is also run with the policy function.
            batch_rl: when true, no new exploration data is added to the
                replay buffer.
            num_actions_sample: number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: passed to ``pi_trainer.update_soft_weight``
                together with the replay buffer.
            update_step: period, in gradient steps, of the periodic update.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp,
            temp: stored on the instance; not read in this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None

        # Reserve path collector for evaluation and visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is flattened to one row and repeated
        ``num_actions_sample`` times, the policy proposes one action per row,
        and the row whose ``pi_trainer.qf1`` value is largest is returned as a
        flat numpy array.  Observations are used as given; no normalization
        is applied here.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy returns an 8-tuple; only the first element is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the largest entry of ``qf1.q_vector(obs)``.

        The result is a flat numpy array holding the arg-max over dimension 1
        of the Q-vector computed for the single flattened observation.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _pick_policy_fn(self):
        """Return the discrete or continuous policy function for ``pi_trainer``."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _sync_behavior_networks(self):
        """Copy ``beta_trainer``'s policy and Q weights into ``pi_trainer``."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                self.eval_data_collector.collect_new_paths(
                    self._pick_policy_fn(),
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Evaluate the policy function with the exploration
                    # collector; the returned paths are not stored.
                    self.expl_data_collector.collect_new_paths(
                        self._pick_policy_fn(),
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # unique=False like the sibling stamps in this loop, so a
                    # second train loop in the same epoch does not trip
                    # gtimer's duplicate-name check.
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    # The Q-values returned by the policy trainer replace the
                    # cached ones for the sampled transitions.
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    self.beta_trainer.train(train_data)

                    if (epoch * self.num_trains_per_train_loop + step) % self.update_step == 0 and epoch != 0:
                        # Refresh the soft weights before mirroring the
                        # behavior trainer into pi_trainer.
                        self.pi_trainer.update_soft_weight(self.replay_buffer, self.cluster_idx_list)
                        self._sync_behavior_networks()

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = np_to_pytorch_batch(
            self.replay_buffer.random_batch(self.batch_size))
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return its average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)


class BatchRLAlgorithm4sbp3(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop that drives a policy trainer and a behavior (beta) trainer.

    On every gradient step the sampled batch is augmented with weights derived
    from the Q-values cached in ``replay_buffer._q_curr``, ``pi_trainer`` is
    trained on it, the Q-values it returns are written back into that cache,
    and ``beta_trainer`` is trained on the same batch.  Every ``update_step``
    gradient steps (skipped while ``epoch == 0``) the behavior trainer's
    weights are copied into ``pi_trainer``.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
    ):
        """Store the loop configuration and build the reserve path collector.

        Args:
            pi_trainer, beta_trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base class.
            batch_size: number of transitions sampled per gradient step.
            max_path_length: path length limit passed to the collectors.
            num_epochs: upper bound (exclusive) of the epoch range.
            num_eval_steps_per_epoch: evaluation steps collected per epoch.
            num_expl_steps_per_train_loop: exploration steps collected per
                train loop when ``batch_rl`` is false.
            num_trains_per_train_loop: gradient steps per train loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected once
                before training when positive and ``batch_rl`` is false.
            q_learning_alg: when true, evaluation paths are collected with
                ``policy_fn`` / ``policy_fn_discrete``.
            eval_both: when true (and ``batch_rl`` is true), the exploration
                collector is also run with the policy function.
            batch_rl: when true, no new exploration data is added to the
                replay buffer.
            num_actions_sample: number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: indexed by replay-buffer index; each entry is
                used to index ``replay_buffer._q_curr`` (presumably the
                members of that transition's cluster -- confirm with caller).
            temp: divisor applied to cached Q-values before weighting.
            update_step: period, in gradient steps, of the weight sync.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp:
                stored on the instance; not read in this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None

        # Reserve path collector for evaluation and visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is flattened to one row and repeated
        ``num_actions_sample`` times, the policy proposes one action per row,
        and the row whose ``pi_trainer.qf1`` value is largest is returned as a
        flat numpy array.  Observations are used as given; no normalization
        is applied here.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy returns an 8-tuple; only the first element is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the largest entry of ``qf1.q_vector(obs)``.

        The result is a flat numpy array holding the arg-max over dimension 1
        of the Q-vector computed for the single flattened observation.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _pick_policy_fn(self):
        """Return the discrete or continuous policy function for ``pi_trainer``."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _attach_soft_weights(self, train_data, index_set):
        """Add the cached-Q weighting inputs to ``train_data`` in place.

        Writes the keys ``q_curr``, ``logsum_batch``, ``diff``,
        ``soft_weight`` and ``cls_len``.
        """
        q_cache = self.replay_buffer._q_curr
        # One lookup per sampled transition, shared by the log-sum-exp and
        # the cluster-size computations below.
        cluster_qs = [q_cache[self.cluster_idx_list[idx]] for idx in index_set]

        train_data['q_curr'] = q_cache[index_set] / self.temp
        # Original author's notes: shapes go B x N x 1 -> B x 1; subtracting
        # np.log(n) here would make the weights average to 1, whereas without
        # it they sum to 1.
        train_data['logsum_batch'] = np.array(
            [special.logsumexp(qs / self.temp, axis=0) for qs in cluster_qs])

        # Original author's tuning note: use a stronger temp for HalfCheetah
        # and a milder one for Hopper.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        train_data['cls_len'] = np.expand_dims(
            np.array([len(qs) for qs in cluster_qs]), 1)

    def _sync_behavior_networks(self):
        """Copy ``beta_trainer``'s policy and Q weights into ``pi_trainer``."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                self.eval_data_collector.collect_new_paths(
                    self._pick_policy_fn(),
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Evaluate the policy function with the exploration
                    # collector; the returned paths are not stored.
                    self.expl_data_collector.collect_new_paths(
                        self._pick_policy_fn(),
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # unique=False like the sibling stamps in this loop, so a
                    # second train loop in the same epoch does not trip
                    # gtimer's duplicate-name check.
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)
                    self._attach_soft_weights(train_data, index_set)

                    # The Q-values returned by the policy trainer replace the
                    # cached ones for the sampled transitions.
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    self.beta_trainer.train(train_data)

                    if (epoch * self.num_trains_per_train_loop + step) % self.update_step == 0 and epoch != 0:
                        self._sync_behavior_networks()

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = np_to_pytorch_batch(
            self.replay_buffer.random_batch(self.batch_size))
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return its average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

class BatchRLAlgorithm4sbpt(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop that drives a policy trainer and a behavior (beta) trainer.

    On every gradient step the sampled batch is augmented with weights derived
    from the Q-values cached in ``replay_buffer._q_curr``, ``pi_trainer`` is
    trained on it, the Q-values it returns are written back into that cache,
    and ``beta_trainer`` is trained on the same batch.  Every ``update_step``
    gradient steps (skipped while ``epoch == 0``) the behavior trainer's
    weights are copied into ``pi_trainer``.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop configuration and build the reserve path collector.

        Also prints the value of ``temp`` to stdout.

        Args:
            pi_trainer, beta_trainer, exploration_env, evaluation_env,
            exploration_data_collector, evaluation_data_collector,
            replay_buffer: forwarded unchanged to the base class.
            batch_size: number of transitions sampled per gradient step.
            max_path_length: path length limit passed to the collectors.
            num_epochs: upper bound (exclusive) of the epoch range.
            num_eval_steps_per_epoch: evaluation steps collected per epoch.
            num_expl_steps_per_train_loop: exploration steps collected per
                train loop when ``batch_rl`` is false.
            num_trains_per_train_loop: gradient steps per train loop.
            num_train_loops_per_epoch: train loops per epoch.
            min_num_steps_before_training: exploration steps collected once
                before training when positive and ``batch_rl`` is false.
            q_learning_alg: when true, evaluation paths are collected with
                ``policy_fn`` / ``policy_fn_discrete``.
            eval_both: when true (and ``batch_rl`` is true), the exploration
                collector is also run with the policy function.
            batch_rl: when true, no new exploration data is added to the
                replay buffer.
            num_actions_sample: number of candidate actions scored in
                ``policy_fn``.
            cluster_idx_list: indexed by replay-buffer index; each entry is
                used to index ``replay_buffer._q_curr`` (presumably the
                members of that transition's cluster -- confirm with caller).
            temp: divisor applied to cached Q-values before weighting.
            update_step: period, in gradient steps, of the weight sync.
            q_tau: stored on the instance; the blended cache update that
                would use it is currently disabled in ``_train``.
            exp_point, cluster_dist_list, obs_mean, obs_std, dist_temp:
                stored on the instance; not read in this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        print("self.temp: ", temp)
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None
        self.q_tau = q_tau

        # Reserve path collector for evaluation and visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the sampled action with the highest ``qf1`` value for ``obs``.

        The observation is flattened to one row and repeated
        ``num_actions_sample`` times, the policy proposes one action per row,
        and the row whose ``pi_trainer.qf1`` value is largest is returned as a
        flat numpy array.  Observations are used as given; no normalization
        is applied here.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy returns an 8-tuple; only the first element is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the index of the largest entry of ``qf1.q_vector(obs)``.

        The result is a flat numpy array holding the arg-max over dimension 1
        of the Q-vector computed for the single flattened observation.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _pick_policy_fn(self):
        """Return the discrete or continuous policy function for ``pi_trainer``."""
        if self.pi_trainer.discrete:
            return self.policy_fn_discrete
        return self.policy_fn

    def _attach_soft_weights(self, train_data, index_set):
        """Add the cached-Q weighting inputs to ``train_data`` in place.

        Writes the keys ``q_curr``, ``logsum_batch``, ``diff``,
        ``soft_weight`` and ``cls_len``.
        """
        q_cache = self.replay_buffer._q_curr
        # One lookup per sampled transition, shared by the log-sum-exp and
        # the cluster-size computations below.
        cluster_qs = [q_cache[self.cluster_idx_list[idx]] for idx in index_set]

        train_data['q_curr'] = q_cache[index_set] / self.temp
        # Original author's notes: shapes go B x N x 1 -> B x 1; subtracting
        # a log-count term (e.g. np.log(50)) here would make the weights
        # average to 1, whereas without it they sum to 1.
        train_data['logsum_batch'] = np.array(
            [special.logsumexp(qs / self.temp, axis=0) for qs in cluster_qs])

        # Original author's tuning notes: use a stronger temp for HalfCheetah
        # and a milder one for Hopper; applying an extra -log(cls_len) term
        # to the difference was left optional and is not done here.
        train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
        train_data['soft_weight'] = np.exp(train_data['diff'])
        train_data['cls_len'] = np.expand_dims(
            np.array([len(qs) for qs in cluster_qs]), 1)

    def _sync_behavior_networks(self):
        """Copy ``beta_trainer``'s policy and Q weights into ``pi_trainer``."""
        self.pi_trainer.beta_prime_policy.load_state_dict(
            self.beta_trainer.beta_prime_policy.state_dict())
        if isinstance(self.beta_trainer.qf, list):
            for i in range(self.beta_trainer.num_q):
                self.pi_trainer.qf_bp[i].load_state_dict(self.beta_trainer.qf[i].state_dict())
        else:
            self.pi_trainer.qf_bp.load_state_dict(self.beta_trainer.qf.state_dict())

    def _train(self):
        """Run the epoch loop: evaluate, optionally collect data, then train."""
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                self.eval_data_collector.collect_new_paths(
                    self._pick_policy_fn(),
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Evaluate the policy function with the exploration
                    # collector; the returned paths are not stored.
                    self.expl_data_collector.collect_new_paths(
                        self._pick_policy_fn(),
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )
                    # unique=False like the sibling stamps in this loop, so a
                    # second train loop in the same epoch does not trip
                    # gtimer's duplicate-name check.
                    gt.stamp('policy fn evaluation', unique=False)

                self.training_mode(True)
                for step in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)
                    self._attach_soft_weights(train_data, index_set)

                    # The Q-values returned by the policy trainer overwrite
                    # the cached ones for the sampled transitions.  A
                    # q_tau-weighted blend of new and cached values was tried
                    # here and is disabled, so q_tau currently has no effect.
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    self.beta_trainer.train(train_data)

                    if (epoch * self.num_trains_per_train_loop + step) % self.update_step == 0 and epoch != 0:
                        self._sync_behavior_networks()

                self.training_mode(False)
            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Score ``custom_model`` with ``q_function`` on one random replay batch."""
        data_batch = np_to_pytorch_batch(
            self.replay_buffer.random_batch(self.batch_size))
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector; return its average return."""
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        return eval_util.get_average_returns(eval_paths)

class BatchRLAlgorithmv4(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Training loop that alternates a policy trainer and a beta trainer.

    Every training step:

    1. samples a batch (with its buffer indices) from the replay buffer,
    2. attaches cluster-based soft weights derived from
       ``replay_buffer._q_curr`` and ``cluster_idx_list``,
    3. trains ``pi_trainer`` and writes the values it returns back into
       ``replay_buffer._q_curr`` at the sampled indices,
    4. trains ``beta_trainer`` on the same batch, and
    5. periodically copies ``beta_trainer``'s ``beta_prime_policy`` and Q
       network(s) into ``pi_trainer``.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            nn,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop configuration and build a reserve path collector.

        The trainers, environments, data collectors and replay buffer are
        forwarded to the base-class constructor; all remaining arguments are
        kept as attributes of the same name.

        Args:
            pi_trainer: Trainer whose ``train`` return value is cached in
                ``replay_buffer._q_curr``.
            beta_trainer: Second trainer, trained on the same batch.
            exploration_env: Forwarded to the base class.
            evaluation_env: Forwarded to the base class; also used for the
                reserve path collector.
            exploration_data_collector: Collector used for exploration paths.
            evaluation_data_collector: Collector used for evaluation paths.
            replay_buffer: Buffer providing ``random_batch(..., get_idx=True)``
                and a ``_q_curr`` array.
            batch_size: Number of samples per training batch.
            max_path_length: Path length limit passed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts per train loop.
            num_trains_per_train_loop: Training steps per train loop.
            nn: Stored as ``self.nn``; not read by this class.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch (only when ``batch_rl`` is false).
            q_learning_alg: If true, evaluation uses ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: If true (and ``batch_rl`` is true), the exploration
                collector additionally rolls out the max-Q policy function.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions in ``policy_fn``.
            exp_point: Stored; not read by this class.
            cluster_idx_list: Indexed by a buffer index; each entry is used
                to index ``replay_buffer._q_curr``.
            cluster_dist_list: Stored; not read by this class.
            obs_mean: Stored; not read by this class.
            obs_std: Stored; not read by this class.
            dist_temp: Stored; not read by this class.
            temp: Divisor applied to the cached Q values before the
                log-sum-exp.
            update_step: Period (in training steps) of the beta -> pi
                parameter copy.
            q_tau: Stored; not read by this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        print("self.temp: ", temp)
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None
        self.q_tau = q_tau
        self.nn = nn

        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the candidate action with the largest ``qf1`` value.

        ``obs`` is flattened to one row and repeated ``num_actions_sample``
        times; the first output of ``pi_trainer.policy`` is taken as the
        batch of candidate actions, and the one maximising
        ``pi_trainer.qf1`` along dimension 0 is returned.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            The selected action as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy is expected to return exactly eight values; only
            # the first one is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max index of ``pi_trainer.qf1.q_vector`` for ``obs``.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            Flat numpy array holding the index of the maximum along
            dimension 1 of the Q vector.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop from ``self._start_epoch`` to ``self.num_epochs``.

        Each epoch collects evaluation paths, then runs
        ``num_train_loops_per_epoch`` train loops (optional data collection
        followed by ``num_trains_per_train_loop`` training steps), and
        finally calls ``self._end_epoch(epoch)``.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = self.policy_fn
                if self.pi_trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Roll out the max-Q policy function with the exploration
                    # collector; the returned paths are not added to the
                    # replay buffer.
                    policy_fn = self.policy_fn
                    if self.pi_trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    # NOTE(review): unlike the exploration stamps this one is
                    # not marked unique=False -- confirm gtimer accepts it
                    # when num_train_loops_per_epoch > 1.
                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for inner_epochs in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    # Cached Q values of the sampled transitions, scaled by
                    # the temperature.
                    train_data['q_curr'] = self.replay_buffer._q_curr[index_set] / self.temp

                    # Cluster size per sampled index, with a trailing axis
                    # added (original note: "B x N x 1 -> B x 1").
                    train_data['cls_len'] = np.expand_dims(
                        np.array([len(self.cluster_idx_list[idx]) for idx in index_set]), 1)

                    # Log-sum-exp of the scaled cached Q values over each
                    # sample's cluster. A "- np.log(cls_len)" correction was
                    # left disabled by the original author.
                    train_data['logsum_batch'] = np.array([
                        special.logsumexp(
                            self.replay_buffer._q_curr[self.cluster_idx_list[idx]] / self.temp,
                            axis=0)
                        for idx in index_set
                    ])

                    # exp(q - logsumexp(q over the cluster)). Original
                    # (translated) note on the subtraction: "optional either
                    # way".
                    train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
                    train_data['soft_weight'] = np.exp(train_data['diff'])

                    # Refresh the cache with the values returned by the
                    # policy trainer for exactly the sampled indices.
                    q_curr = self.pi_trainer.train(train_data)
                    self.replay_buffer._q_curr[index_set] = q_curr

                    self.beta_trainer.train(train_data)

                    # Step counter used for the periodic copy; note that it
                    # does not account for num_train_loops_per_epoch.
                    global_step = epoch * self.num_trains_per_train_loop + inner_epochs
                    if global_step % self.update_step == 0 and epoch != 0:
                        self.pi_trainer.beta_prime_policy.load_state_dict(
                            self.beta_trainer.beta_prime_policy.state_dict())
                        if isinstance(self.beta_trainer.qf, list):
                            for i in range(self.beta_trainer.num_q):
                                self.pi_trainer.qf_bp[i].load_state_dict(
                                    self.beta_trainer.qf[i].state_dict())
                        else:
                            self.pi_trainer.qf_bp.load_state_dict(
                                self.beta_trainer.qf.state_dict())

                self.training_mode(False)
            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay batch.

        Args:
            custom_model: Forwarded to ``pi_trainer.eval_q_custom``.
            q_function: Forwarded as the ``q_function`` keyword argument.

        Returns:
            Whatever ``pi_trainer.eval_q_custom`` returns.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and average returns.

        Args:
            policy: Policy passed to the reserve collector's
                ``update_policy``.

        Returns:
            The value of ``eval_util.get_average_returns`` on the newly
            collected paths.
        """
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns

class BatchRLAlgorithmv4b(BaseRLAlgorithm2, metaclass=abc.ABCMeta):
    """Policy/beta training loop that never refreshes the cached Q values.

    Every training step samples a batch (with indices) from the replay
    buffer, attaches cluster-based soft weights derived from
    ``replay_buffer._q_curr`` and ``cluster_idx_list``, trains
    ``pi_trainer`` and then ``beta_trainer`` on it, and periodically copies
    ``beta_trainer``'s ``beta_prime_policy`` and Q network(s) into
    ``pi_trainer``. The value returned by ``pi_trainer.train`` is discarded,
    so ``replay_buffer._q_curr`` is only read, never written, by this loop.
    """

    def __init__(
            self,
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            nn,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
            temp=1.0,
            update_step=1,
            q_tau=1.0,
    ):
        """Store the loop configuration and build a reserve path collector.

        The trainers, environments, data collectors and replay buffer are
        forwarded to the base-class constructor; all remaining arguments are
        kept as attributes of the same name.

        Args:
            pi_trainer: First trainer run on each batch.
            beta_trainer: Second trainer, run on the same batch.
            exploration_env: Forwarded to the base class.
            evaluation_env: Forwarded to the base class; also used for the
                reserve path collector.
            exploration_data_collector: Collector used for exploration paths.
            evaluation_data_collector: Collector used for evaluation paths.
            replay_buffer: Buffer providing ``random_batch(..., get_idx=True)``
                and a ``_q_curr`` array.
            batch_size: Number of samples per training batch.
            max_path_length: Path length limit passed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts per train loop.
            num_trains_per_train_loop: Training steps per train loop.
            nn: Stored as ``self.nn``; not read by this class.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch (only when ``batch_rl`` is false).
            q_learning_alg: If true, evaluation uses ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: If true (and ``batch_rl`` is true), the exploration
                collector additionally rolls out the max-Q policy function.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions in ``policy_fn``.
            exp_point: Stored; not read by this class.
            cluster_idx_list: Indexed by a buffer index; each entry is used
                to index ``replay_buffer._q_curr``.
            cluster_dist_list: Stored; not read by this class.
            obs_mean: Stored; not read by this class.
            obs_std: Stored; not read by this class.
            dist_temp: Stored; not read by this class.
            temp: Divisor applied to the cached Q values before the
                log-sum-exp.
            update_step: Period (in training steps) of the beta -> pi
                parameter copy.
            q_tau: Stored; not read by this class.
        """
        super().__init__(
            pi_trainer,
            beta_trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.temp = temp
        print("self.temp: ", temp)
        self.update_step = update_step
        self.exp_point = exp_point
        self.prob = None
        self.q_tau = q_tau
        self.nn = nn

        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.pi_trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the candidate action with the largest ``qf1`` value.

        ``obs`` is flattened to one row and repeated ``num_actions_sample``
        times; the first output of ``pi_trainer.policy`` is taken as the
        batch of candidate actions, and the one maximising
        ``pi_trainer.qf1`` along dimension 0 is returned.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            The selected action as a flat numpy array.
        """
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            # The policy is expected to return exactly eight values; only
            # the first one is used.
            action, _, _, _, _, _, _, _ = self.pi_trainer.policy(state)
            q1 = self.pi_trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max index of ``pi_trainer.qf1.q_vector`` for ``obs``.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            Flat numpy array holding the index of the maximum along
            dimension 1 of the Q vector.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.pi_trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop from ``self._start_epoch`` to ``self.num_epochs``.

        Each epoch collects evaluation paths, then runs
        ``num_train_loops_per_epoch`` train loops (optional data collection
        followed by ``num_trains_per_train_loop`` training steps), and
        finally calls ``self._end_epoch(epoch)``.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = self.policy_fn
                if self.pi_trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Roll out the max-Q policy function with the exploration
                    # collector; the returned paths are not added to the
                    # replay buffer.
                    policy_fn = self.policy_fn
                    if self.pi_trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    # NOTE(review): unlike the exploration stamps this one is
                    # not marked unique=False -- confirm gtimer accepts it
                    # when num_train_loops_per_epoch > 1.
                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for inner_epochs in range(self.num_trains_per_train_loop):
                    train_data, index_set = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    # Cached Q values of the sampled transitions, scaled by
                    # the temperature.
                    train_data['q_curr'] = self.replay_buffer._q_curr[index_set] / self.temp

                    # Cluster size per sampled index, with a trailing axis
                    # added (original note: "B x N x 1 -> B x 1").
                    train_data['cls_len'] = np.expand_dims(
                        np.array([len(self.cluster_idx_list[idx]) for idx in index_set]), 1)

                    # Log-sum-exp of the scaled cached Q values over each
                    # sample's cluster. A "- np.log(cls_len)" correction was
                    # left disabled by the original author.
                    train_data['logsum_batch'] = np.array([
                        special.logsumexp(
                            self.replay_buffer._q_curr[self.cluster_idx_list[idx]] / self.temp,
                            axis=0)
                        for idx in index_set
                    ])

                    # exp(q - logsumexp(q over the cluster)). Original
                    # (translated) note on the subtraction: "optional either
                    # way".
                    train_data['diff'] = train_data['q_curr'] - train_data['logsum_batch']
                    train_data['soft_weight'] = np.exp(train_data['diff'])

                    # The trainer's return value is deliberately ignored: in
                    # this variant replay_buffer._q_curr is never updated.
                    self.pi_trainer.train(train_data)

                    self.beta_trainer.train(train_data)

                    # Step counter used for the periodic copy; note that it
                    # does not account for num_train_loops_per_epoch.
                    global_step = epoch * self.num_trains_per_train_loop + inner_epochs
                    if global_step % self.update_step == 0 and epoch != 0:
                        self.pi_trainer.beta_prime_policy.load_state_dict(
                            self.beta_trainer.beta_prime_policy.state_dict())
                        if isinstance(self.beta_trainer.qf, list):
                            for i in range(self.beta_trainer.num_q):
                                self.pi_trainer.qf_bp[i].load_state_dict(
                                    self.beta_trainer.qf[i].state_dict())
                        else:
                            self.pi_trainer.qf_bp.load_state_dict(
                                self.beta_trainer.qf.state_dict())

                self.training_mode(False)
            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay batch.

        Args:
            custom_model: Forwarded to ``pi_trainer.eval_q_custom``.
            q_function: Forwarded as the ``q_function`` keyword argument.

        Returns:
            Whatever ``pi_trainer.eval_q_custom`` returns.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.pi_trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and average returns.

        Args:
            policy: Policy passed to the reserve collector's
                ``update_policy``.

        Returns:
            The value of ``eval_util.get_average_returns`` on the newly
            collected paths.
        """
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns


class BatchRLAlgorithm4_2(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Single-trainer training loop that caches trainer outputs per sample.

    Every training step samples a batch (with its buffer indices) from the
    replay buffer, trains ``self.trainer`` on it and stores the returned
    values in ``replay_buffer._advs`` at the sampled indices.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
    ):
        """Store the loop configuration and build a reserve path collector.

        The trainer, environments, data collectors and replay buffer are
        forwarded to the base-class constructor; all remaining arguments are
        kept as attributes of the same name.

        Args:
            trainer: Trainer whose ``train`` return value is cached in
                ``replay_buffer._advs``.
            exploration_env: Forwarded to the base class.
            evaluation_env: Forwarded to the base class; also used for the
                reserve path collector.
            exploration_data_collector: Collector used for exploration paths.
            evaluation_data_collector: Collector used for evaluation paths.
            replay_buffer: Buffer providing ``random_batch(..., get_idx=True)``
                and an ``_advs`` array.
            batch_size: Number of samples per training batch.
            max_path_length: Path length limit passed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts per train loop.
            num_trains_per_train_loop: Training steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch (only when ``batch_rl`` is false).
            q_learning_alg: If true, evaluation uses ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: If true (and ``batch_rl`` is true), the exploration
                collector additionally rolls out the max-Q policy function.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions in ``policy_fn``.
            exp_point: Stored; not read by this class.
            cluster_idx_list: Stored; not read by this class.
            cluster_dist_list: Stored; not read by this class.
            obs_mean: Observation offset used by ``policy_fn`` when the
                instance has no ``env_mean`` attribute.
            obs_std: Observation scale used by ``policy_fn`` when the
                instance has no ``env_std`` attribute.
            dist_temp: Stored; not read by this class.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.exp_point = exp_point
        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the candidate action with the largest ``qf1`` value.

        ``obs`` is flattened to one row, repeated ``num_actions_sample``
        times and normalised; the first output of ``trainer.policy`` is
        taken as the batch of candidate actions, and the one maximising
        ``trainer.qf1`` along dimension 0 is returned.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            The selected action as a flat numpy array.
        """
        # This class only ever assigns obs_mean / obs_std, so reading
        # env_mean / env_std unconditionally fails unless something else set
        # them. Prefer env_* when present (keeps any external override
        # working) and fall back to the constructor statistics otherwise.
        # Assumes the statistics broadcast against a torch tensor (plain
        # floats by default) -- TODO confirm for array-valued statistics.
        mean = getattr(self, 'env_mean', self.obs_mean)
        std = getattr(self, 'env_std', self.obs_std)
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            state = (state - mean) / std
            # The policy is expected to return exactly eight values; only
            # the first one is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max index of ``trainer.qf1.q_vector`` for ``obs``.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            Flat numpy array holding the index of the maximum along
            dimension 1 of the Q vector.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop from ``self._start_epoch`` to ``self.num_epochs``.

        Each epoch collects evaluation paths, then runs
        ``num_train_loops_per_epoch`` train loops (optional data collection
        followed by ``num_trains_per_train_loop`` training steps), and
        finally calls ``self._end_epoch(epoch)``.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Roll out the max-Q policy function with the exploration
                    # collector; the returned paths are not added to the
                    # replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    # NOTE(review): unlike the exploration stamps this one is
                    # not marked unique=False -- confirm gtimer accepts it
                    # when num_train_loops_per_epoch > 1.
                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data, idx = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    # Cache the trainer's per-batch output at the sampled
                    # buffer positions.
                    advs = self.trainer.train(train_data)
                    self.replay_buffer._advs[idx] = advs
                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay batch.

        Args:
            custom_model: Forwarded to ``trainer.eval_q_custom``.
            q_function: Forwarded as the ``q_function`` keyword argument.

        Returns:
            Whatever ``trainer.eval_q_custom`` returns.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and average returns.

        Args:
            policy: Policy passed to the reserve collector's
                ``update_policy``.

        Returns:
            The value of ``eval_util.get_average_returns`` on the newly
            collected paths.
        """
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns


class BatchRLAlgorithm4_3(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Single-trainer training loop that caches two trainer outputs per sample.

    Every training step samples a batch (with its buffer indices) from the
    replay buffer, trains ``self.trainer`` on it and stores the two returned
    values in ``replay_buffer._advs`` and ``replay_buffer._z_diffs`` at the
    sampled indices.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            q_learning_alg=False,
            eval_both=False,
            batch_rl=False,
            num_actions_sample=10,
            exp_point=None,
            cluster_idx_list=None,
            cluster_dist_list=None,
            obs_mean=0.0,
            obs_std=1.0,
            dist_temp=600,
    ):
        """Store the loop configuration and build a reserve path collector.

        The trainer, environments, data collectors and replay buffer are
        forwarded to the base-class constructor; all remaining arguments are
        kept as attributes of the same name.

        Args:
            trainer: Trainer whose ``train`` call returns a pair that is
                cached in ``replay_buffer._advs`` / ``_z_diffs``.
            exploration_env: Forwarded to the base class.
            evaluation_env: Forwarded to the base class; also used for the
                reserve path collector.
            exploration_data_collector: Collector used for exploration paths.
            evaluation_data_collector: Collector used for evaluation paths.
            replay_buffer: Buffer providing ``random_batch(..., get_idx=True)``
                plus ``_advs`` and ``_z_diffs`` arrays.
            batch_size: Number of samples per training batch.
            max_path_length: Path length limit passed to the collectors.
            num_epochs: Exclusive upper bound of the epoch range.
            num_eval_steps_per_epoch: Step budget for evaluation rollouts.
            num_expl_steps_per_train_loop: Step budget for exploration
                rollouts per train loop.
            num_trains_per_train_loop: Training steps per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: Steps collected before the first
                epoch (only when ``batch_rl`` is false).
            q_learning_alg: If true, evaluation uses ``policy_fn`` /
                ``policy_fn_discrete`` instead of the collector's own policy.
            eval_both: If true (and ``batch_rl`` is true), the exploration
                collector additionally rolls out the max-Q policy function.
            batch_rl: If true, no new exploration data is added to the buffer.
            num_actions_sample: Number of candidate actions in ``policy_fn``.
            exp_point: Stored; not read by this class.
            cluster_idx_list: Stored; not read by this class.
            cluster_dist_list: Stored; not read by this class.
            obs_mean: Observation offset used by ``policy_fn`` when the
                instance has no ``env_mean`` attribute.
            obs_std: Observation scale used by ``policy_fn`` when the
                instance has no ``env_std`` attribute.
            dist_temp: Stored; not read by this class.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.batch_size = batch_size
        self.cluster_idx_list = cluster_idx_list
        self.cluster_dist_list = cluster_dist_list
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.batch_rl = batch_rl
        self.q_learning_alg = q_learning_alg
        self.eval_both = eval_both
        self.num_actions_sample = num_actions_sample
        self.obs_mean = obs_mean
        self.obs_std = obs_std
        self.dist_temp = dist_temp
        self.exp_point = exp_point
        # Reserve path collector for evaluation / visualization.
        self._reserve_path_collector = MdpPathCollector(
            env=evaluation_env, policy=self.trainer.policy,
        )

    def policy_fn(self, obs):
        """Return the candidate action with the largest ``qf1`` value.

        ``obs`` is flattened to one row, repeated ``num_actions_sample``
        times and normalised; the first output of ``trainer.policy`` is
        taken as the batch of candidate actions, and the one maximising
        ``trainer.qf1`` along dimension 0 is returned.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            The selected action as a flat numpy array.
        """
        # This class only ever assigns obs_mean / obs_std, so reading
        # env_mean / env_std unconditionally fails unless something else set
        # them. Prefer env_* when present (keeps any external override
        # working) and fall back to the constructor statistics otherwise.
        # Assumes the statistics broadcast against a torch tensor (plain
        # floats by default) -- TODO confirm for array-valued statistics.
        mean = getattr(self, 'env_mean', self.obs_mean)
        std = getattr(self, 'env_std', self.obs_std)
        with torch.no_grad():
            state = ptu.from_numpy(obs.reshape(1, -1)).repeat(self.num_actions_sample, 1)
            state = (state - mean) / std
            # The policy is expected to return exactly eight values; only
            # the first one is used.
            action, _, _, _, _, _, _, _ = self.trainer.policy(state)
            q1 = self.trainer.qf1(state, action)
            ind = q1.max(0)[1]
        return ptu.get_numpy(action[ind]).flatten()

    def policy_fn_discrete(self, obs):
        """Return the arg-max index of ``trainer.qf1.q_vector`` for ``obs``.

        Args:
            obs: Observation as a numpy array (it is ``reshape``-d here).

        Returns:
            Flat numpy array holding the index of the maximum along
            dimension 1 of the Q vector.
        """
        with torch.no_grad():
            obs = ptu.from_numpy(obs.reshape(1, -1))
            q_vector = self.trainer.qf1.q_vector(obs)
            action = q_vector.max(1)[1]
        return ptu.get_numpy(action).flatten()

    def _train(self):
        """Run the epoch loop from ``self._start_epoch`` to ``self.num_epochs``.

        Each epoch collects evaluation paths, then runs
        ``num_train_loops_per_epoch`` train loops (optional data collection
        followed by ``num_trains_per_train_loop`` training steps), and
        finally calls ``self._end_epoch(epoch)``.
        """
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if self.q_learning_alg:
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # Roll out the max-Q policy function with the exploration
                    # collector; the returned paths are not added to the
                    # replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    # NOTE(review): unlike the exploration stamps this one is
                    # not marked unique=False -- confirm gtimer accepts it
                    # when num_train_loops_per_epoch > 1.
                    gt.stamp('policy fn evaluation')

                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data, idx = self.replay_buffer.random_batch(
                        self.batch_size, get_idx=True)

                    # Cache both trainer outputs at the sampled buffer
                    # positions.
                    advs, z_diffs = self.trainer.train(train_data)

                    self.replay_buffer._advs[idx] = advs
                    self.replay_buffer._z_diffs[idx] = z_diffs

                self.training_mode(False)

            self._end_epoch(epoch)

    def _eval_q_custom_policy(self, custom_model, q_function):
        """Evaluate ``custom_model`` on one random replay batch.

        Args:
            custom_model: Forwarded to ``trainer.eval_q_custom``.
            q_function: Forwarded as the ``q_function`` keyword argument.

        Returns:
            Whatever ``trainer.eval_q_custom`` returns.
        """
        data_batch = self.replay_buffer.random_batch(self.batch_size)
        data_batch = np_to_pytorch_batch(data_batch)
        return self.trainer.eval_q_custom(custom_model, data_batch, q_function=q_function)

    def eval_policy_custom(self, policy):
        """Roll out ``policy`` with the reserve collector and average returns.

        Args:
            policy: Policy passed to the reserve collector's
                ``update_policy``.

        Returns:
            The value of ``eval_util.get_average_returns`` on the newly
            collected paths.
        """
        self._reserve_path_collector.update_policy(policy)

        eval_paths = self._reserve_path_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        eval_returns = eval_util.get_average_returns(eval_paths)
        return eval_returns