# coding=utf-8
from ..ppo_trainer import PPOTrainer


class PPOPbrsTrainer(PPOTrainer):
    """PPO trainer with potential-based reward shaping (PBRS).

    Each transition's environment reward ``r`` is replaced by a shaped reward
    ``r + gamma * phi(s') - phi(s)``, where the potentials are supplied by the
    caller via the ``phi_s`` / ``phi_sp`` keyword arguments of
    :meth:`experience`.

    Transitions are held back by one call in ``self.last_exp`` before being
    written to ``self.exp_mini_buffer``. This lets the "mujoco" mode fill in
    the ``gamma * phi(s')`` term from the *next* call's ``phi_s``.

    NOTE(review): relies on ``gamma``, ``last_exp``, ``exp_cnt``,
    ``truncation_size``, ``exp_mini_buffer`` and ``ppo_update`` being provided
    by ``PPOTrainer`` — presumably set up in its ``__init__``; confirm.
    """

    # Value of the ``phi_sp`` kwarg that selects the deferred-potential mode.
    _MUJOCO_MODE = "mujoco"

    def __init__(self, state_space, action_space, algo_name="ppo_pbrs", **kwargs):
        """Forward all arguments to ``PPOTrainer`` with a PBRS-specific default name."""
        super(PPOPbrsTrainer, self).__init__(state_space, action_space, algo_name, **kwargs)

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition with a PBRS-shaped reward.

        PPO keeps no replay memory. Transitions go into the fixed-size
        ``exp_mini_buffer`` (indexed modulo ``truncation_size``), and
        ``ppo_update`` runs each time ``truncation_size`` more transitions have
        been stored.

        Args:
            s: current state.
            a: action taken in ``s``.
            r: environment reward for the transition.
            sp: next state.
            terminal: whether this transition ends the episode.
            **kwargs: keys read:
                v_pred: value prediction stored with the transition and passed
                    to ``ppo_update`` as ``next_v_pred``.
                phi_s: potential of ``s``.
                phi_sp: potential of ``sp``, or the string ``"mujoco"`` to
                    defer the ``gamma * phi(s')`` term to the next call.
        """
        v_pred = kwargs.get("v_pred")
        phi_s = kwargs.get("phi_s")
        phi_sp = kwargs.get("phi_sp")

        # The isinstance guard matters when phi_sp is a numpy array: a bare
        # ``phi_sp == "mujoco"`` then compares elementwise, and ``if`` on a
        # multi-element result raises ValueError.
        mujoco_mode = isinstance(phi_sp, str) and phi_sp == self._MUJOCO_MODE

        if mujoco_mode:
            # Only -phi(s) is known now. gamma * phi(s') is added when the
            # next transition arrives (see below).
            shaped_r = r - phi_s
        else:
            # NOTE(review): phi_sp is not zeroed for terminal transitions here;
            # presumably callers pass phi_sp=0 for terminal states — confirm.
            f_ssp = self.gamma * phi_sp - phi_s
            shaped_r = r + f_ssp

        new_exp = [s, a, shaped_r, sp, terminal, v_pred]

        if self.last_exp is None:
            # First transition: nothing is held back yet, so nothing to store.
            self.last_exp = new_exp
            return

        # Index 4 is the terminal flag and index 2 the (shaped) reward.
        if mujoco_mode and not self.last_exp[4]:
            # Complete the held-back transition's shaping term with
            # gamma * phi(s'). phi(s') is taken from this call's phi_s, which
            # assumes this call's s is the held-back transition's sp. Terminal
            # transitions keep an implicit phi(s') of 0.
            self.last_exp[2] += self.gamma * phi_s

        slot = self.exp_cnt % self.truncation_size
        self.exp_mini_buffer[slot] = self.last_exp
        self.last_exp = new_exp

        self.exp_cnt += 1
        if self.exp_cnt % self.truncation_size == 0:
            # Update the policy from the experiences currently in the buffer.
            # The current transition's v_pred is the bootstrap value.
            self.ppo_update(next_v_pred=v_pred)
