import tensorflow as tf
import tensorflow_probability as tfp
import tree


from .base_algorithm import BaseAlgorithm


# Fallback optimizer spec, in Keras ``{'class_name', 'config'}`` form, used
# by ``BBORandomizedPrior`` when no ``optimizer_params`` are supplied.
DEFAULT_OPTIMIZER = dict(class_name='Adam', config={})


class BBORandomizedPrior(BaseAlgorithm):
    """Value learner pairing a TD-trained network with a tracking copy.

    Each call to :meth:`update_V` does two things:

    1. It trains ``V_phi`` for ``num_phi_steps`` steps on one-step TD
       targets bootstrapped from ``V_omega``. Each step also adds a
       penalty on the squared distance between phi and a freshly drawn
       sample from ``Normal(prior_loc, prior_scale)``.
    2. It moves ``V_omega`` toward ``V_phi`` with an SGD step whose
       "gradient" is ``omega - phi``.

    Args:
        V_omega: Value function exposed through :attr:`V`. It is used here
            via ``values(states)``, ``model.dtype`` and
            ``trainable_variables``.
        V_phi: Value function trained on the TD targets. It is used via
            ``values(states)`` and ``trainable_variables``. Its variables
            are paired with ``V_omega``'s by position.
        phi_lr: Learning rate for the ``V_phi`` optimizer.
        omega_lr: Learning rate for the ``V_omega`` SGD optimizer.
        gamma: Discount factor multiplying the bootstrapped next-state value.
        prior_loc: Mean of the Normal distribution phi is pulled toward.
        prior_scale: Standard deviation of that Normal distribution.
        gradient_limit: Element-wise clip bound for phi gradients. The
            default ``inf`` makes the clip a no-op.
        num_phi_steps: Number of sequential ``V_phi`` updates per call.
        prior_loss_weight: Weight of the prior penalty in the phi loss.
        optimizer_params: Keras optimizer spec ``{'class_name', 'config'}``
            for the phi optimizer. Defaults to ``DEFAULT_OPTIMIZER``.
            ``config`` is also forwarded to the omega SGD optimizer.

    Raises:
        ValueError: If ``optimizer_params`` sets ``learning_rate``, either
            at the top level or inside ``config``. Learning rates come only
            from ``phi_lr`` and ``omega_lr``.
    """

    def __init__(self,
                 V_omega,
                 V_phi,
                 phi_lr=5e-1,
                 omega_lr=5e-2,
                 gamma=0.9,
                 prior_loc=0.0,
                 prior_scale=1.0,
                 gradient_limit=float('inf'),
                 num_phi_steps=1,
                 prior_loss_weight=1.0,
                 optimizer_params=None):
        self._phi_lr = phi_lr
        self._omega_lr = omega_lr
        self._gamma = gamma
        self._prior_scale = prior_scale
        self._prior_loc = prior_loc
        self._gradient_limit = gradient_limit
        self._num_phi_steps = num_phi_steps
        self._prior_loss_weight = prior_loss_weight
        optimizer_params = optimizer_params or DEFAULT_OPTIMIZER
        # A learning rate would normally sit inside ``config``, where the
        # explicit value injected below would silently override it.
        # Reject it at either level. Raise instead of ``assert`` so the
        # check still runs under ``python -O``.
        if ('learning_rate' in optimizer_params
                or 'learning_rate' in optimizer_params.get('config', {})):
            raise ValueError(
                "Pass learning rates via `phi_lr` / `omega_lr`, not "
                "`optimizer_params`.")

        self.V_phi = V_phi
        self._phi_optimizer = tf.optimizers.get({
            'class_name': optimizer_params['class_name'],
            'config': {
                **optimizer_params['config'],
                'learning_rate': phi_lr,
            },
        })

        self.V_omega = V_omega
        # NOTE(review): the phi optimizer's config is forwarded to SGD
        # unchanged. Keys SGD does not accept (e.g. Adam's ``beta_1``) will
        # make this call fail.
        self._omega_optimizer = tf.optimizers.get({
            'class_name': 'SGD',
            'config': {
                **optimizer_params['config'],
                'learning_rate': omega_lr,
            },
        })

        self.prior = tfp.distributions.Normal(
            loc=prior_loc,
            scale=prior_scale)

    @property
    def V(self):
        """Return ``V_omega``, the network used for bootstrapped targets."""
        return self.V_omega

    @tf.function(experimental_relax_shapes=True)
    def update_V(self, state_0s, actions, state_1s, rewards, terminals, rhos):
        """Run ``num_phi_steps`` phi updates followed by one omega update.

        Args:
            state_0s: Start states, evaluated with ``V_phi.values``.
            actions: Unused by this algorithm.
            state_1s: Next states, evaluated with ``V_omega.values``.
            rewards: Per-transition rewards, cast to ``V_omega.model.dtype``.
            terminals: Per-transition terminal flags. They are cast to the
                reward dtype, and a value of 1 zeroes the bootstrap term.
            rhos: Unused by this algorithm.

        Returns:
            Dict of scalar diagnostics with these keys: ``td_loss``,
            ``prior_loss``, ``omega_loss``, ``phi_norm`` and ``omega_norm``.

        Raises:
            ValueError: At trace time, if ``V_omega`` and ``V_phi`` have
                different numbers of trainable variables.
        """
        rewards = tf.cast(rewards, self.V_omega.model.dtype)

        V_omega_s1 = self.V_omega.values(state_1s)
        # Use the same dtype as the reward cast above, not a hard-coded
        # float32.
        continuation_probs = self._gamma * (
            1.0 - tf.cast(terminals, rewards.dtype))
        # Computed outside the tape below, so no gradient flows into omega.
        target = rewards + continuation_probs * V_omega_s1

        def train_V_phi(i):
            # ``i`` is only the map_fn loop index. Each call is one phi step.
            with tf.GradientTape() as tape:
                V_phi_s0 = self.V_phi.values(state_0s)
                # A new prior sample is drawn for every variable on every
                # step.
                prior_loss = 0.5 * tf.reduce_sum([
                    tf.reduce_sum(tf.math.squared_difference(
                        phi, self.prior.sample(tf.shape(phi))))
                    for phi in tree.flatten(self.V_phi.trainable_variables)
                ])
                # NOTE(review): assumes V_phi_s0 and target have matching
                # shapes. A mismatch would broadcast silently. TODO confirm.
                td_loss = 0.5 * tf.reduce_sum(
                    tf.math.squared_difference(V_phi_s0, target))
                phi_loss = td_loss + self._prior_loss_weight * prior_loss

            phi_gradients = tape.gradient(
                phi_loss, self.V_phi.trainable_variables)
            phi_gradients = [
                tf.clip_by_value(
                    gradient, -self._gradient_limit, self._gradient_limit)
                for gradient in phi_gradients
            ]
            self._phi_optimizer.apply_gradients(
                zip(phi_gradients, self.V_phi.trainable_variables))

            return td_loss, prior_loss

        # parallel_iterations=1 keeps the stateful phi updates sequential.
        td_losses, prior_losses = tf.map_fn(
            train_V_phi,
            tf.range(self._num_phi_steps),
            dtype=(tf.float32, tf.float32),
            parallel_iterations=1)

        omega_variables = self.V_omega.trainable_variables
        phi_variables = self.V_phi.trainable_variables
        # zip() would silently drop unmatched variables, so fail loudly.
        if len(omega_variables) != len(phi_variables):
            raise ValueError(
                "V_omega and V_phi must have the same number of trainable "
                "variables.")

        # With plain SGD this is omega <- omega - omega_lr * (omega - phi),
        # i.e. omega moves a fraction omega_lr of the way toward phi.
        omega_gradients = [
            omega_weight - phi_weight
            for omega_weight, phi_weight in zip(omega_variables, phi_variables)
        ]
        if omega_gradients:
            # A single call, so the optimizer counts one step per update
            # rather than one step per variable.
            self._omega_optimizer.apply_gradients(
                zip(omega_gradients, omega_variables))
        omega_losses = sum(
            (tf.reduce_mean(gradient) for gradient in omega_gradients), 0.0)

        tree.map_structure(
            lambda x: tf.debugging.check_numerics(x, 'phi'),
            self.V_phi.trainable_variables)
        tree.map_structure(
            lambda x: tf.debugging.check_numerics(x, 'omega'),
            self.V_omega.trainable_variables)

        return {
            'td_loss': tf.reduce_mean(td_losses),
            'prior_loss': tf.reduce_mean(prior_losses),
            'omega_loss': tf.reduce_mean(omega_losses),
            'phi_norm': tf.reduce_mean(tree.map_structure(
                tf.linalg.norm,
                self.V_phi.trainable_variables)),
            'omega_norm': tf.reduce_mean(tree.map_structure(
                tf.linalg.norm,
                self.V_omega.trainable_variables)),
        }
