import torch
import torch.nn as nn
import numpy as np
import common_utils


def policy_gradient_error(pi, q, action, behavior_pi, log_prob_func, max_log_ratio):
    """Compute policy gradient error of policy, optionally importance-weighted.

    params:
        pi: [batch_size, num_action] policy being evaluated
        q: [batch_size] estimate of state-action value (advantage); the caller
            in this file passes it detached
        action: [batch_size, num_action] action taken
        behavior_pi: [batch_size, num_action]
            behavior policy used during experience collection None if on-policy
        log_prob_func: callable ``(policy, action)`` returning the
            log-probability of ``action`` under ``policy``; presumably shaped
            [batch_size] like ``q`` (TODO confirm against the sampler)
        max_log_ratio: upper bound on log(pi(a) / behavior_pi(a)) applied
            before exponentiating, i.e. the importance weight is truncated
            at exp(max_log_ratio)

    returns:
        per-sample loss ``-q * w * log pi(a)`` where ``w`` is the truncated
        importance weight (``w == 1`` when ``behavior_pi`` is None)
    """
    action_logp = log_prob_func(pi, action)

    if behavior_pi is not None:
        behavior_action_logp = log_prob_func(behavior_pi, action)
        # The importance weight is a constant coefficient of the estimator:
        # detach the whole log ratio so no gradient flows through it into
        # either pi or behavior_pi.
        log_ratio = (action_logp - behavior_action_logp).detach()
        # Truncate large weights to bound the variance of the estimate.
        log_ratio = log_ratio.clamp(max=max_log_ratio)
        q = q * log_ratio.exp()

    return -q * action_logp


def ppo_gradient_error(pi, q, action, behavior_pi, log_prob_func, ratio_clamp):
    """Compute the PPO clipped-surrogate policy error.

    params:
        pi: [batch_size, num_action] policy being evaluated
        q: [batch_size] estimate of state-action value (advantage); the caller
            in this file passes it detached
        action: [batch_size, num_action] action taken
        behavior_pi: [batch_size, num_action]
            behavior policy used during experience collection; required
            (must not be None) for PPO
        log_prob_func: callable ``(policy, action)`` returning the
            log-probability of ``action`` under ``policy``
        ratio_clamp: clip range epsilon; the probability ratio is clipped to
            [1 - ratio_clamp, 1 + ratio_clamp]

    returns:
        per-sample loss ``-min(q * ratio, q * clip(ratio))``
    """
    assert behavior_pi is not None, 'ppo_gradient_error requires behavior_pi'

    action_logp = log_prob_func(pi, action)
    # The behavior policy is fixed data; only pi should receive gradient
    # through the probability ratio.
    behavior_action_logp = log_prob_func(behavior_pi, action).detach()
    ratio = (action_logp - behavior_action_logp).exp()
    clamped_ratio = ratio.clamp(min=1 - ratio_clamp, max=1 + ratio_clamp)

    # Pessimistic bound: take the smaller of the unclipped and clipped
    # surrogates so moving the ratio outside the clip range yields no gain.
    err1 = q * ratio
    err2 = q * clamped_ratio
    return -torch.min(err1, err2)


class DiscountedReward:
    """Running backward accumulator of discounted returns.

    Seed it with ``setR`` (a bootstrap value), then call ``feed`` once per
    time step, moving backwards in time; each call returns the updated
    discounted return.
    """

    def __init__(self, gamma):
        # Discount factor applied to the carried return on every step.
        self.gamma = gamma

    def setR(self, R, stats):
        """Set rewards and feed to stats.

        Stores ``R`` as the running return and records its mean under 'v'.
        """
        self.R = R
        stats['v'].feed(R.mean().detach())

    def feed(self, r, terminal, stats):
        """Fold one step's reward into the running return and return it.

        The previously carried return is scaled by ``gamma`` and masked by
        ``(1 - terminal)`` before ``r`` is added. The means of the step reward
        and of the new return are recorded under 'r' and 'g'.
        """
        continuing = 1 - terminal
        self.R = r + continuing * self.R * self.gamma
        for key, value in (('r', r), ('g', self.R)):
            stats[key].feed(value.mean().detach())
        return self.R


class ActorCritic:
    """Actor-critic loss with optional PPO clipping and entropy regularization.

    ``compute_gradient`` walks a rollout backwards in time, accumulates
    discounted returns, and calls ``backward()`` on each step's loss so that
    gradients accumulate into the parameters reachable from the model outputs.
    """

    def __init__(self, *,
                 ent_ratio,
                 min_prob,
                 max_importance_ratio,
                 ratio_clamp,
                 gamma,
                 ppo,
                 use_xent):
        """
        params:
            ent_ratio: weight of the entropy (or negative cross-entropy) bonus
            min_prob: floor passed to ``sampler.clamp_prob`` and
                ``sampler.get_cross_entropy``
            max_importance_ratio: (non-PPO only) cap on the importance weight;
                must be >= 1
            ratio_clamp: (PPO only) clip range epsilon; must be < 0.5
            gamma: discount factor for returns
            ppo: use the PPO clipped surrogate instead of the truncated
                importance-weighted policy gradient
            use_xent: regularize with ``-xent`` instead of the entropy
        """
        self.ent_ratio = ent_ratio
        self.min_prob = min_prob
        self.gamma = gamma
        self.ppo = ppo
        self.use_xent = use_xent
        if ppo:
            assert ratio_clamp < 0.5
            self.ratio_clamp = ratio_clamp
        else:
            assert max_importance_ratio >= 1
            # Stored in log space: policy_gradient_error clamps the log ratio.
            self.max_log_ratio = np.log(max_importance_ratio)

        self.discounted_reward = DiscountedReward(gamma)

    def compute_gradient(self, model, critic, batch, stats):
        """Actor critic model update.

        Calls ``backward()`` once per time step (gradients accumulate; the
        caller is responsible for zeroing and stepping the optimizer) and
        feeds stats for later summarization.

        params:
            model: policy/value model exposing ``sampler``, ``get_reward``,
                ``get_terminal``, ``get_policy``, ``get_action`` and
                ``get_value``
            critic: separate value model, or None to use ``model``'s value.
                NOTE: the per-step loop currently rejects a non-None critic.
            batch: rollout indexed along dim 1 by time; index ``T`` (one past
                the last reward) is used for the bootstrap value
            stats: mapping from stat name to an object with a ``feed`` method
        """
        sampler = model.sampler
        # Fetch each rollout field once instead of on every time step.
        rewards = model.get_reward(batch)
        terminals = model.get_terminal(batch)
        behavior_policies = model.get_policy(batch)
        actions = model.get_action(batch)
        T = rewards.size(1)

        # The bootstrap value is only used as a detached target, so there is
        # no need to build an autograd graph for it.
        inputs = common_utils.tensor_index(batch, 1, T)
        with torch.no_grad():
            if critic is None:
                bootstrap_v = model.get_value(model(inputs))
            else:
                bootstrap_v = critic.get_value(critic(inputs))
        self.discounted_reward.setR(bootstrap_v.detach(), stats)

        for t in range(T - 1, -1, -1):
            r = common_utils.tensor_index(rewards, 1, t).squeeze(1)
            # NOTE(review): negative rewards are clipped to zero here -
            # confirm this is intended.
            r = r.clamp(min=0)
            terminal = common_utils.tensor_index(terminals, 1, t).squeeze(1)
            g = self.discounted_reward.feed(r, terminal, stats)
            inputs = common_utils.tensor_index(batch, 1, t)

            if critic is None:
                outputs = model(inputs)
                pi = model.get_policy(outputs)
                v = model.get_value(outputs)
            else:
                assert False, 'compute_gradient does not support a separate critic'
                pi = model.get_policy(model(inputs))
                v = critic.get_value(critic(inputs))

            # Advantage estimate: g is a fixed target, so the squared error
            # only trains the value head.
            q = g.detach() - v
            value_err = 0.5 * q.pow(2)

            pi = sampler.clamp_prob(pi, self.min_prob)
            behavior_pi = common_utils.tensor_index(behavior_policies, 1, t)
            action = common_utils.tensor_index(actions, 1, t)

            # q is detached so the policy loss does not train the value head.
            if self.ppo:
                policy_err = ppo_gradient_error(
                    pi,
                    q.detach(),
                    action,
                    behavior_pi,
                    sampler.get_log_prob,
                    self.ratio_clamp)
            else:
                policy_err = policy_gradient_error(
                    pi,
                    q.detach(),
                    action,
                    behavior_pi,
                    sampler.get_log_prob,
                    self.max_log_ratio)

            xent, kl = sampler.get_cross_entropy(pi, inputs, self.min_prob)
            ent = sampler.get_entropy(pi)
            stats['xent'].feed(xent.mean().detach().item())
            stats['ent'].feed(ent.mean().detach().item())
            stats['kl'].feed(kl.mean().detach().item())

            ent_reg = -xent if self.use_xent else ent

            common_utils.assert_eq(value_err.dim(), 1)
            common_utils.assert_eq(value_err.size(), policy_err.size())
            common_utils.assert_eq(value_err.size(), ent_reg.size())
            value_err = value_err.mean()
            policy_err = policy_err.mean()
            ent_reg = ent_reg.mean()

            stats["err_val"].feed(value_err.detach().item())
            stats["err_pi"].feed(policy_err.detach().item())

            # Scale by 1/T so the gradients accumulated over all steps average
            # over time; the regularizer is subtracted to encourage it.
            err = 1 / T * (value_err + policy_err - self.ent_ratio * ent_reg)
            err.backward()
            stats["cost"].feed(err.detach().item())
