import gym.spaces as spaces
import numpy as np
from rlberry.agents import AgentWithSimplePolicy
from .tree import MDPTreePartition

import rlberry

# Reuse the logger object exposed by the top-level rlberry package.
logger = rlberry.logger


class AdaptiveRandQLAgent(AgentWithSimplePolicy):
    """
    Adaptive Randomized Q-Learning algorithm implemented for environments
    with continuous (Box) states and **discrete actions**.

    Q-values are stored in the nodes of an ``MDPTreePartition`` whose nodes
    are split as they get visited. Every update draws ``bootstrap_samples``
    Beta-distributed random learning rates, and the Q-value of a node is the
    maximum over the resulting ensemble of estimates.

    Parameters
    ----------
    env : gym.Env
        Environment with continuous (``spaces.Box``) observations and
        discrete (``spaces.Discrete``) actions.
    gamma : double, default: 1.0
        Discount factor in [0, 1].
    horizon : int, default: 50
        Horizon of the objective function. Must be at least 1.
    bootstrap_samples : int, default: 10
        Number of random weights drawn per update, i.e. the size of the
        ensemble ``qvalue_tilde`` kept in each node. Also forwarded to
        ``MDPTreePartition``.
    prior_transitions : float, default: 1
        Second shape parameter of the Beta distribution that mixes the
        observed target with the optimistic value ``v_max[h]``.
    kappa : float, default: 1.0
        Divides both shape parameters of the Beta distribution from which
        the random learning rates are drawn.
    **kwargs
        Forwarded to ``AgentWithSimplePolicy.__init__``.

    References
    ----------
    NA

    Notes
    ------
    Uses the metric induced by the l-infinity norm.
    """

    name = "AdaptiveRandomizedQLearning"

    def __init__(
        self,
        env,
        gamma=1.0,
        horizon=50,
        bootstrap_samples=10,
        prior_transitions=1,
        kappa=1.0,
        **kwargs
    ):
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        assert isinstance(
            self.env.observation_space, spaces.Box
        ), "observation_space must be a gym.spaces.Box"
        assert isinstance(
            self.env.action_space, spaces.Discrete
        ), "action_space must be a gym.spaces.Discrete"

        # v_max is indexed by stage; an empty array would fail below with an
        # unhelpful IndexError, so reject it up front.
        if horizon < 1:
            raise ValueError(
                "{}: horizon must be at least 1, got {}.".format(self.name, horizon)
            )

        self.gamma = gamma
        self.horizon = horizon
        self.bootstrap_samples = bootstrap_samples
        self.prior_transitions = prior_transitions
        self.kappa = kappa

        # maximum value
        r_range = self.env.reward_range[1] - self.env.reward_range[0]
        # np.isfinite also catches NaN (e.g. inf - inf) and -inf, which a
        # plain comparison with np.inf lets through.
        if not np.isfinite(r_range) or r_range == 0.0:
            logger.warning(
                "{}: Reward range is  zero or infinity. ".format(self.name)
                + "Setting it to 1."
            )
            r_range = 1.0

        # v_max[h]: upper bound on the discounted sum of rewards from stage h,
        # computed by backward recursion from the last stage.
        self.v_max = np.zeros(self.horizon)
        self.v_max[-1] = r_range
        for hh in reversed(range(self.horizon - 1)):
            self.v_max[hh] = r_range + self.gamma * self.v_max[hh + 1]

        self.reset()

    def reset(self):
        """Create a fresh partition tree and reset the episode counter."""
        self.Qtree = MDPTreePartition(
            self.env.observation_space,
            self.env.action_space,
            self.horizon,
            self.bootstrap_samples,
        )

        # info
        self.episode = 0

    def policy(self, observation):
        """Return the action chosen by the tree for ``observation`` at stage 0."""
        action, _ = self.Qtree.get_argmax_and_node(observation, 0)
        return action

    def _get_action_and_node(self, observation, hh):
        """Return the tree's ``(action, node)`` for ``observation`` at stage ``hh``."""
        action, node = self.Qtree.get_argmax_and_node(observation, hh)
        return action, node

    def _update(self, node, state, action, next_state, reward, hh):
        """
        Update the Q-value ensemble of ``node`` from one observed transition.

        Parameters
        ----------
        node
            Tree node returned by ``_get_action_and_node`` for
            ``(state, hh)``; must be the node whose counts get updated.
        state, action, next_state, reward
            The observed transition.
        hh : int
            Stage (time step within the episode) of the transition.
        """
        # split node if necessary
        node_to_check = self.Qtree.update_counts(state, action, hh)
        if node_to_check.n_visits >= (self.Qtree.dmax / node_to_check.radius) ** 2.0:
            node_to_check.split()
        # internal invariant: the counted node is the one that selected the action
        assert id(node_to_check) == id(node)

        # number of visits to the selected state-action node
        # NOTE(review): presumably >= 1 because update_counts ran just above;
        # the Beta draws below need strictly positive shape parameters.
        tt = node.n_visits

        # value at next_state (zero after the last stage), clipped at v_max
        value_next_state = 0
        if hh < self.horizon - 1:
            value_next_state = min(
                self.v_max[hh + 1],
                self.Qtree.get_argmax_and_node(next_state, hh + 1)[1].qvalue,
            )

        # compute target: random convex combination of the observed
        # one-step target and the optimistic value v_max[hh]
        alpha_prior = tt
        beta_prior = self.prior_transitions
        weights_prior = self.rng.beta(
            alpha_prior,
            beta_prior,
            size=self.bootstrap_samples,
        )

        target = weights_prior * (reward + self.gamma * value_next_state)
        target += (1.0 - weights_prior) * self.v_max[hh]

        # sample random learning rates, one per ensemble member
        alpha = (self.horizon + 1.0) / self.kappa
        beta = tt / self.kappa
        weights = self.rng.beta(
            alpha,
            beta,
            size=self.bootstrap_samples,
        )

        # update Q: per-member update, then take the max over the ensemble
        node.qvalue_tilde = weights * target + (1 - weights) * node.qvalue_tilde
        node.qvalue = node.qvalue_tilde.max()

    def _run_episode(self):
        """
        Interact with the environment for at most ``horizon`` steps,
        updating the tree after every transition.

        Returns
        -------
        Sum of the rewards collected during the episode.
        """
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for hh in range(self.horizon):
            action, node = self._get_action_and_node(state, hh)
            next_state, reward, done, _ = self.env.step(action)
            episode_rewards += reward

            self._update(node, state, action, next_state, reward, hh)

            state = next_state
            if done:
                break

        # update info
        self.episode += 1

        # writer
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards, self.episode)

        # return sum of rewards collected in the episode
        return episode_rewards

    def fit(self, budget: int, **kwargs):
        """
        Train the agent using the provided environment.

        Parameters
        ----------
        budget: int
            number of episodes. Each episode runs for self.horizon unless it
            encounters a terminal state in which case it stops early.
        """
        del kwargs
        n_episodes_to_run = budget
        count = 0
        while count < n_episodes_to_run:
            self._run_episode()
            count += 1