import random
from typing import Callable, List, Optional
from functools import partial
import os
import cvxpy as cp
from cvxpylayers.jax import CvxpyLayer
import gymnasium as gym
import numpy as np
import jax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
from flax.training import checkpoints
import orbax

import wandb as wb
from gpi.dynamics.ensemble_model_jax import ProbabilisticEnsemble
from gpi.dynamics.util_jax import ModelEnv
from gpi.rl_algorithm import RLAlgorithm
from gpi.utils.buffer import ReplayBuffer
from gpi.utils.eval import eval_mo, visualize_eval_jax
from gpi.utils.prioritized_buffer import PrioritizedReplayBuffer
from gpi.utils.utils import linearly_decaying_epsilon, random_weights


class Psi(nn.Module):
    """Single successor-feature head conditioned on an observation and a weight vector.

    The observation and the weight vector are embedded separately, merged by an
    element-wise product, passed through ``num_hidden_layers - 1`` hidden blocks
    (Dense -> optional Dropout -> optional LayerNorm -> ReLU) and finally
    projected to ``action_dim * rew_dim`` outputs (returned flat, not reshaped).
    """

    action_dim: int
    rew_dim: int
    dropout_rate: Optional[float] = 0.01
    use_layer_norm: bool = True
    num_hidden_layers: int = 4
    hidden_dim: int = 256

    @nn.compact
    def __call__(self, obs: jnp.ndarray, w: jnp.ndarray, deterministic: bool):
        # Separate embeddings for the observation and the weight vector.
        obs_embedding = nn.relu(nn.Dense(self.hidden_dim)(obs))
        w_embedding = nn.relu(nn.Dense(self.hidden_dim)(w))

        # Dropout is skipped entirely when the rate is None or non-positive.
        apply_dropout = self.dropout_rate is not None and self.dropout_rate > 0

        features = obs_embedding * w_embedding
        for _ in range(self.num_hidden_layers - 1):
            features = nn.Dense(self.hidden_dim)(features)
            if apply_dropout:
                features = nn.Dropout(rate=self.dropout_rate)(features, deterministic=deterministic)
            if self.use_layer_norm:
                features = nn.LayerNorm()(features)
            features = nn.relu(features)

        return nn.Dense(self.action_dim * self.rew_dim)(features)

class VectorPsi(nn.Module):
    """Ensemble of ``n_critics`` independent ``Psi`` networks evaluated on the same inputs.

    The output is reshaped to ``(n_critics, -1, action_dim, rew_dim)``.
    """

    action_dim: int
    rew_dim: int
    use_layer_norm: bool = True
    dropout_rate: Optional[float] = 0.01
    n_critics: int = 2
    num_hidden_layers: int = 4
    hidden_dim: int = 256

    @nn.compact
    def __call__(self, obs: jnp.ndarray, w: jnp.ndarray, deterministic: bool):
        # Lift Psi over a leading "critic" axis: every critic gets its own
        # parameters and its own params/dropout RNG stream, while the inputs
        # are broadcast (in_axes=None) to all critics.
        ensemble_cls = nn.vmap(
            Psi,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        ensemble = ensemble_cls(
            action_dim=self.action_dim,
            rew_dim=self.rew_dim,
            dropout_rate=self.dropout_rate,
            use_layer_norm=self.use_layer_norm,
            num_hidden_layers=self.num_hidden_layers,
            hidden_dim=self.hidden_dim,
        )
        stacked = ensemble(obs, w, deterministic)
        output_shape = (self.n_critics, -1, self.action_dim, self.rew_dim)
        return stacked.reshape(output_shape)

class TrainState(TrainState):
    """Flax ``TrainState`` extended with a second parameter tree for the target network.

    NOTE(review): this definition shadows the ``TrainState`` imported from
    ``flax.training.train_state``; after this point the name refers to this subclass.
    """
    # Parameters of the target network, stored next to the trainable ``params``.
    target_params: flax.core.FrozenDict


class USFA(RLAlgorithm):
    def __init__(
        self,
        env,
        learning_rate: float = 3e-4,
        initial_epsilon: float = 0.01,
        final_epsilon: float = 0.01,
        epsilon_decay_steps: int = None,  # None == fixed epsilon
        tau: float = 1.0,
        target_net_update_freq: int = 1000,  # ignored if tau != 1.0
        buffer_size: int = int(1e6),
        net_arch: List = [256, 256],
        num_nets: int = 1,
        batch_size: int = 256,
        learning_starts: int = 100,
        gradient_updates: int = 1,
        gamma: float = 0.99,
        max_grad_norm: Optional[float] = None,
        use_gpi: bool = True,
        h_step: int = 1,
        gpi_type: str = "gpi",
        dyna: bool = False,
        per: bool = False,
        gper: bool = False,
        alpha_per: float = 0.6,
        min_priority: float = 1.0,
        drop_rate: float = 0.01,
        layer_norm: bool = True,
        dynamics_ensemble_size: int = 7,
        dynamics_num_elites: int = 5,
        dynamics_normalize_inputs: bool = False,
        dynamics_uncertainty_threshold: float = 0.5,
        dynamics_train_freq: Callable = lambda x: 1000,
        dynamics_rollout_len: int = 1,
        dynamics_rollout_starts: int = 5000,
        dynamics_rollout_freq: int = 250,
        dynamics_rollout_batch_size: int = 10000,
        dynamics_buffer_size: int = 400000,
        dynamics_net_arch: List = [200, 200, 200, 200],
        real_ratio: float = 0.05,
        seed: int = 0,
        project_name: str = "usfa",
        experiment_name: str = "usfa",
        log: bool = True,
        device = None
        ):
        """Build the psi networks, the replay buffers and (optionally) the dynamics model.

        Most arguments are stored as attributes and consumed by the other methods.

        Args:
            env: Environment; must expose ``w`` (its length defines ``phi_dim``)
                and ``observation_space``.
            learning_rate: Adam learning rate of the psi network.
            initial_epsilon, final_epsilon, epsilon_decay_steps: Epsilon-greedy
                schedule; ``epsilon_decay_steps=None`` keeps epsilon fixed.
            tau, target_net_update_freq: Target-network update settings
                (the frequency is ignored if ``tau != 1.0``).
            buffer_size: Capacity of the real-experience replay buffer.
            net_arch: Hidden layout of the psi network. Only ``len(net_arch)``
                (number of hidden layers) and ``net_arch[0]`` (width) are used.
            num_nets: Number of psi critics in the ensemble.
            batch_size, learning_starts, gradient_updates, gamma: Training settings.
            max_grad_norm: Stored and reported in the config only; not applied
                anywhere in this constructor.
            use_gpi, h_step, gpi_type: Action-selection settings used by ``eval``/``act``.
            dyna: If True, build a ``ProbabilisticEnsemble`` and a model buffer.
            per: If True, use a ``PrioritizedReplayBuffer``.
            gper: Stored as ``self.gper``.
            alpha_per, min_priority: Prioritisation exponent / priority floor.
            drop_rate, layer_norm: Dropout rate and LayerNorm switch of the psi network.
            dynamics_*: Settings of the dynamics model and its rollouts.
            real_ratio: Fraction of each batch drawn from real experience when
                model data is mixed in.
            seed: Seed of the JAX PRNG key.
            project_name, experiment_name, log: Logging settings; ``log=True``
                calls ``setup_wandb``.
            device: Forwarded to the base class.
        """
        super().__init__(env, experiment_name=experiment_name, project_name=project_name, device=device)
        self.phi_dim = len(self.env.w)
        self.learning_rate = learning_rate
        self.initial_epsilon = initial_epsilon
        self.epsilon = initial_epsilon
        self.epsilon_decay_steps = epsilon_decay_steps
        self.final_epsilon = final_epsilon
        self.tau = tau
        self.target_net_update_freq = target_net_update_freq
        self.gamma = gamma
        self.max_grad_norm = max_grad_norm
        self.use_gpi = use_gpi
        self.cgpi_layer = None
        self.min_phi = -np.ones(self.phi_dim)
        self.gpi_type = gpi_type
        self.include_w = False
        self.h_step = h_step
        self.buffer_size = buffer_size
        # Copy the architectures so the instance never aliases the shared
        # (mutable) default-argument lists of this signature.
        self.net_arch = list(net_arch)
        self.dynamics_net_arch = list(dynamics_net_arch)
        self.learning_starts = learning_starts
        self.batch_size = batch_size
        self.gradient_updates = gradient_updates
        self.num_nets = num_nets
        self.drop_rate = drop_rate
        self.layer_norm = layer_norm
        # Filled by act(); created here so act() also works before learn() is called.
        self.police_indices = []

        key = jax.random.PRNGKey(seed)
        self.key, psi_key, dropout_key = jax.random.split(key, 3)

        obs = env.observation_space.sample()
        self.psi = VectorPsi(self.action_dim, self.phi_dim, self.layer_norm, self.drop_rate, self.num_nets, num_hidden_layers=len(self.net_arch), hidden_dim=self.net_arch[0])
        # Online and target networks start from the same values, so a single
        # initialisation is enough (JAX arrays are immutable, sharing is safe).
        initial_variables = self.psi.init(
            {"params": psi_key, "dropout": dropout_key},
            obs,
            env.w,
            deterministic=False,
        )
        self.psi_state = TrainState.create(
            apply_fn=self.psi.apply,
            params=initial_variables,
            target_params=initial_variables,
            tx=optax.adam(learning_rate=self.learning_rate),
        )
        self.psi.apply = jax.jit(self.psi.apply, static_argnames=("dropout_rate", "use_layer_norm", "deterministic"))

        self.per = per
        self.gper = gper
        if self.per:
            self.replay_buffer = PrioritizedReplayBuffer(self.observation_shape, 1, rew_dim=self.phi_dim, max_size=buffer_size, action_dtype=np.uint8)
        else:
            self.replay_buffer = ReplayBuffer(self.observation_shape, 1, rew_dim=self.phi_dim, max_size=buffer_size, action_dtype=np.uint8)
        self.min_priority = min_priority
        self.alpha = alpha_per
        self.M = []

        self.dyna = dyna
        if self.dyna:
            # The model maps (obs, one-hot action) to (phi, obs-sized output).
            self.dynamics = ProbabilisticEnsemble(
                input_dim=self.observation_dim + self.action_dim,
                output_dim=self.observation_dim + self.phi_dim,
                arch=self.dynamics_net_arch,
                normalize_inputs=dynamics_normalize_inputs,
                ensemble_size=dynamics_ensemble_size,
                num_elites=dynamics_num_elites,
            )
            self.key = self.dynamics.build(self.key)
            self.dynamics_buffer = ReplayBuffer(self.observation_shape, 1, rew_dim=self.phi_dim, max_size=dynamics_buffer_size, action_dtype=np.uint8)
        self.dynamics_train_freq = dynamics_train_freq
        self.dynamics_ensemble_size = dynamics_ensemble_size
        self.dynamics_num_elites = dynamics_num_elites
        self.dynamics_rollout_len = dynamics_rollout_len
        self.dynamics_rollout_starts = dynamics_rollout_starts
        self.dynamics_rollout_freq = dynamics_rollout_freq
        self.dynamics_rollout_batch_size = dynamics_rollout_batch_size
        self.dynamics_uncertainty_threshold = dynamics_uncertainty_threshold
        self.dynamics_normalize_inputs = dynamics_normalize_inputs
        self.real_ratio = real_ratio

        self.log = log
        if log:
            self.setup_wandb(project_name, experiment_name)

    def get_config(self):
        """Return the hyper-parameters of this agent as a flat dictionary.

        Returns:
            dict mapping configuration names to the values currently stored on
            the instance (plus the id of the wrapped environment).
        """
        return {
            "env_id": self.env.unwrapped.spec.id,
            "learning_rate": self.learning_rate,
            "initial_epsilon": self.initial_epsilon,
            "final_epsilon": self.final_epsilon,
            # Key previously carried a stray trailing colon ("epsilon_decay_steps:").
            "epsilon_decay_steps": self.epsilon_decay_steps,
            "batch_size": self.batch_size,
            "per": self.per,
            "alpha_per": self.alpha,
            "min_priority": self.min_priority,
            "tau": self.tau,
            "num_nets": self.num_nets,
            # Key previously misspelled as "clip_grand_norm".
            "clip_grad_norm": self.max_grad_norm,
            "target_net_update_freq": self.target_net_update_freq,
            "gamma": self.gamma,
            "net_arch": self.net_arch,
            "model_arch": self.dynamics_net_arch,
            "gradient_updates": self.gradient_updates,
            "buffer_size": self.buffer_size,
            "learning_starts": self.learning_starts,
            "dyna": self.dyna,
            "dynamics_ensemble_size": self.dynamics_ensemble_size,
            "dynamics_num_elites": self.dynamics_num_elites,
            "dynamics_rollout_len": self.dynamics_rollout_len,
            "dynamics_uncertainty_threshold": self.dynamics_uncertainty_threshold,
            "dynamics_normalize_inputs": self.dynamics_normalize_inputs,
            "real_ratio": self.real_ratio,
            "drop_rate": self.drop_rate,
            "layer_norm": self.layer_norm,
        }

    def save(self, save_dir="weights/", filename=None):
        """Write a checkpoint with the psi train state, the GPI set and the dynamics model.

        Args:
            save_dir: Directory that receives the checkpoint; created if missing.
            filename: Name of the checkpoint inside ``save_dir``. Defaults to
                ``self.experiment_name``.
        """
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(save_dir, exist_ok=True)

        saved_params = {"psi_net_state": self.psi_state, "M": self.M}
        if self.dyna:
            saved_params.update(self.dynamics.get_params())

        filename = self.experiment_name if filename is None else filename
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        checkpoints.save_checkpoint(
            # os.path.join keeps the checkpoint inside save_dir even when
            # save_dir is given without a trailing separator.
            ckpt_dir=os.path.join(save_dir, filename),
            target=saved_params,
            step=self.num_timesteps,
            overwrite=True,
            keep=2,
            orbax_checkpointer=orbax_checkpointer,
        )

    def load(self, path, step=None):
        """Restore the psi train state, the GPI set and (if ``dyna``) the dynamics model.

        Args:
            path: Checkpoint directory passed to ``checkpoints.restore_checkpoint``.
            step: Checkpoint step to restore; forwarded unchanged (``None`` by default).
        """
        target = {"psi_net_state": self.psi_state, "M": self.M}
        if self.dyna:
            target.update(self.dynamics.get_params())

        # First pass without a target: fetch the raw entries so that they can
        # replace the corresponding entries of the restore target below.
        raw = checkpoints.restore_checkpoint(ckpt_dir=path, target=None, step=step)
        target["M"] = raw["M"]
        if self.dyna:
            for name in ("elites", "inputs_mu", "inputs_sigma"):
                target[name] = raw[name]

        # Second pass: structured restore into the prepared target.
        restored = checkpoints.restore_checkpoint(ckpt_dir=path, target=target, step=step)
        self.psi_state = restored["psi_net_state"]
        self.M = restored["M"]
        if self.dyna:
            self.dynamics.ensemble_state = restored["ensemble_state"]
            self.dynamics.elites = restored["elites"]
            # The normalisation statistics are read from self.dynamics at
            # action-selection time, so they must be restored on the model.
            self.dynamics.inputs_mu = restored["inputs_mu"]
            self.dynamics.inputs_sigma = restored["inputs_sigma"]
            # Previously these were (only) written to the agent itself; kept so
            # that code reading agent.inputs_mu / agent.inputs_sigma still works.
            self.inputs_mu = restored["inputs_mu"]
            self.inputs_sigma = restored["inputs_sigma"]
        self.cgpi_layer = None  # reset cgpi layer

    def sample_batch_experiences(self):
        """Sample one training batch, mixing real and model-generated transitions when possible.

        Returns:
            ``(obs, actions, rewards, next_obs, dones)`` and, when prioritized
            replay is enabled, a trailing element with the sampled indices of
            the real transitions.
        """
        use_model_data = (
            self.dyna
            and self.num_timesteps >= self.dynamics_rollout_starts
            and len(self.dynamics_buffer) != 0
        )
        if not use_model_data:
            # Pure real-experience batch.
            return self.replay_buffer.sample(self.batch_size, to_tensor=False, device=self.device)

        # real_ratio of the batch comes from real data, the rest from the model buffer.
        num_real_samples = int(self.batch_size * self.real_ratio)
        num_model_samples = self.batch_size - num_real_samples

        real_batch = self.replay_buffer.sample(num_real_samples, to_tensor=False, device=self.device)
        if self.per:
            s_obs, s_actions, s_rewards, s_next_obs, s_dones, idxes = real_batch
        else:
            s_obs, s_actions, s_rewards, s_next_obs, s_dones = real_batch
        m_obs, m_actions, m_rewards, m_next_obs, m_dones = self.dynamics_buffer.sample(num_model_samples, to_tensor=False, device=self.device)

        field_pairs = (
            (s_obs, m_obs),
            (s_actions, m_actions),
            (s_rewards, m_rewards),
            (s_next_obs, m_next_obs),
            (s_dones, m_dones),
        )
        # Real samples always come first in every concatenated field.
        experience_tuples = tuple(np.concatenate([real, imagined], axis=0) for real, imagined in field_pairs)

        if self.per:
            return experience_tuples + (idxes,)
        return experience_tuples

    @staticmethod
    @partial(jax.jit, static_argnames=["psi", "return_q_values"])
    def batch_gpi(psi, psi_state, obs, w, M, key, return_q_values=False):
        """Select GPI actions for a batch of observations.

        Every observation is evaluated under every weight vector in ``M``; the
        policy with the highest Q-value (w.r.t. ``w``) is picked per
        observation and its greedy action is returned.

        Returns:
            ``(actions, key)`` or, if ``return_q_values`` is True,
            ``(actions, q_values_of_those_actions, key)``. ``key`` is returned unchanged.
        """
        policy_weights = jnp.stack(M)
        # Pair each observation with each policy weight vector.
        policy_weights = jnp.repeat(
            policy_weights.reshape(1, policy_weights.shape[0], policy_weights.shape[1]),
            len(obs),
            axis=0,
        )
        tiled_obs = jnp.repeat(
            obs.reshape(obs.shape[0], 1, obs.shape[1]),
            policy_weights.shape[1],
            axis=1,
        )

        successor_features = psi.apply(psi_state.params, tiled_obs, policy_weights, deterministic=True)
        num_critics = successor_features.shape[0]
        q_per_critic = (successor_features * w).sum(axis=3).reshape(num_critics, obs.shape[0], len(M), -1)
        # Average the ensemble of critics.
        q_values = q_per_critic.mean(axis=0)

        batch_index = jnp.arange(q_values.shape[0])
        best_policy = jnp.argmax(jnp.max(q_values, axis=2), axis=1)
        best_q_values = q_values[batch_index, best_policy]
        acts = best_q_values.argmax(axis=1)

        if not return_q_values:
            return acts, key
        return acts, best_q_values[batch_index, acts], key

    def rollout_dynamics(self, w):
        """Generate imagined transitions with the learned dynamics model (Dyna planning).

        Observations are sampled from the real replay buffer, GPI actions are
        selected for them, and the model predicts the outcome. Transitions whose
        uncertainty is below ``dynamics_uncertainty_threshold`` are added to the
        model buffer.

        Args:
            w: Weight vector used for action selection; its length is passed
                to ``ModelEnv`` as the reward dimension.
        """
        max_chunk_size = 10000  # upper bound on observations rolled out at once
        num_times = int(np.ceil(self.dynamics_rollout_batch_size / max_chunk_size))
        batch_size = min(self.dynamics_rollout_batch_size, max_chunk_size)
        num_added_imagined_transitions = 0
        # Stays None when no rollout step is executed (e.g. zero rollout length),
        # in which case the uncertainty statistics below are skipped.
        uncertainties = None
        for _ in range(num_times):
            obs = self.replay_buffer.sample_obs(batch_size, to_tensor=False)
            model_env = ModelEnv(self.dynamics, self.env.unwrapped.spec.id, rew_dim=len(w))

            for _ in range(self.dynamics_rollout_len):
                actions, self.key = USFA.batch_gpi(self.psi, self.psi_state, obs, w, self.M, self.key)
                actions_one_hot = nn.one_hot(actions, num_classes=self.action_dim)

                next_obs_pred, r_pred, dones, info = model_env.step(obs, actions_one_hot, deterministic=False)
                uncertainties = info['uncertainty']
                obs, actions = jax.device_get(obs), jax.device_get(actions)

                # Keep only transitions the model is sufficiently certain about.
                for i in range(len(obs)):
                    if uncertainties[i] < self.dynamics_uncertainty_threshold:
                        self.dynamics_buffer.add(obs[i], actions[i], r_pred[i], next_obs_pred[i], dones[i])
                        num_added_imagined_transitions += 1

                # Continue the rollout only from non-terminal predicted states.
                nonterm_mask = ~dones.squeeze(-1)
                if nonterm_mask.sum() == 0:
                    break
                obs = next_obs_pred[nonterm_mask]

        if self.log:
            if uncertainties is not None:
                # Statistics of the last executed rollout step.
                self.writer.add_scalar("dynamics/uncertainty_mean", uncertainties.mean(), self.num_timesteps)
                self.writer.add_scalar("dynamics/uncertainty_max", uncertainties.max(), self.num_timesteps)
                self.writer.add_scalar("dynamics/uncertainty_min", uncertainties.min(), self.num_timesteps)
            self.writer.add_scalar("dynamics/model_buffer_size", len(self.dynamics_buffer), self.num_timesteps)
            self.writer.add_scalar("dynamics/imagined_transitions", num_added_imagined_transitions, self.num_timesteps)

    @staticmethod
    @partial(jax.jit, static_argnames=["psi", "gamma", "min_priority"])
    def update(psi, psi_state, w, obs, actions, rewards, next_obs, dones, gamma, min_priority, key):
        """Perform one gradient step on the psi network.

        The branch is chosen from ``psi.n_critics`` (``psi`` is a static jit
        argument): with two or more critics the target network is evaluated
        with dropout enabled and the per-action minimum over (at most two)
        critics is used; with a single critic a double-Q style target is used
        (online network picks the action, target network evaluates it).

        Args:
            psi: The psi module (static under jit).
            psi_state: Train state holding ``params`` and ``target_params``.
            w: Weight vectors, indexed as 2-D (``w.shape[0]``, ``w.shape[1]``).
            obs, actions, rewards, next_obs, dones: Batch of transitions.
            gamma: Discount factor (static under jit).
            min_priority: Threshold between the quadratic and the linear part
                of the loss (static under jit).
            key: PRNG key; split for critic sub-sampling and dropout.

        Returns:
            ``(new_psi_state, loss_value, td_error, key)`` where ``key`` is the
            carried-over part of the split.
        """
        key, inds_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)

        # DroQ update
        if psi.n_critics >= 2:
            psi_values_next = psi.apply(psi_state.target_params, next_obs, w, deterministic=False, rngs={"dropout": dropout_key_target})
            if psi_values_next.shape[0] > 2:
                # Use two randomly drawn critics; the indices are drawn
                # independently, so they may coincide.
                inds = jax.random.randint(inds_key, (2,), 0, psi_values_next.shape[0])
                psi_values_next = psi_values_next[inds]
            q_values_next = (psi_values_next * w.reshape(w.shape[0], 1, w.shape[1])).sum(axis=3)
            # For every (sample, action) keep the psi vector of the critic with the lowest Q-value.
            min_inds = q_values_next.argmin(axis=0)
            min_psi_values = jnp.take_along_axis(psi_values_next, min_inds[None,...,None], 0).squeeze(0)
            
            # Greedy next action w.r.t. the minimum-critic Q-values.
            max_q = (min_psi_values * w.reshape(w.shape[0], 1, w.shape[1])).sum(axis=2)
            max_acts = max_q.argmax(axis=1)
            target = min_psi_values[jnp.arange(min_psi_values.shape[0]), max_acts]

            # NOTE: ``target_psi`` is assigned after this definition (below the
            # if/else) and is resolved when the loss is actually called.
            def mse_loss(params, droptout_key):
                psi_values = psi.apply(params, obs, w, deterministic=False, rngs={"dropout": droptout_key})
                # Select the psi vector of the action that was taken, for every critic.
                psi_values = psi_values[:, jnp.arange(psi_values.shape[1]), actions.squeeze()]
                tds = psi_values - target_psi
                loss = jnp.abs(tds)
                # Huber-like loss: quadratic below min_priority, linear above.
                loss = jnp.where(loss < min_priority, 0.5 * loss ** 2, loss * min_priority)
                return loss.mean(), tds
        # DDQN update
        else:
            psi_values_next = psi.apply(psi_state.target_params, next_obs, w, deterministic=True)[0]
            psi_values_not_target = psi.apply(psi_state.params, next_obs, w, deterministic=True)
            # Action selection with the online network, evaluation with the target network.
            q_values_next = (psi_values_not_target * w.reshape(w.shape[0], 1, w.shape[1])).sum(axis=3)[0]
            max_acts = q_values_next.argmax(axis=1)
            target = psi_values_next[jnp.arange(psi_values_next.shape[0]), max_acts]

            # Same loss as above but without dropout; ``droptout_key`` is unused here.
            def mse_loss(params, droptout_key):
                psi_values = psi.apply(params, obs, w, deterministic=True)
                psi_values = psi_values[:, jnp.arange(psi_values.shape[1]), actions.squeeze()]
                tds = psi_values - target_psi
                loss = jnp.abs(tds)
                loss = jnp.where(loss < min_priority, 0.5 * loss ** 2, loss * min_priority)
                return loss.mean(), tds

        # Bootstrapped successor-feature target; no bootstrap where dones == 1.
        target_psi = rewards + (1 - dones) * gamma * target

        (loss_value, td_error), grads = jax.value_and_grad(mse_loss, has_aux=True)(psi_state.params, dropout_key_current)
        psi_state = psi_state.apply_gradients(grads=grads)

        return psi_state, loss_value, td_error, key

    def train(self, weight):
        """Run the gradient updates for the current time step and refresh the target network.

        Args:
            weight: Weight vector of the current task. When the GPI set holds
                more than one vector, the batch is duplicated and the second
                half is paired with weights drawn at random from ``self.M``.
        """
        critic_losses = []
        # Multiple gradient updates are only used once model rollouts have started.
        num_updates = self.gradient_updates if self.num_timesteps >= self.dynamics_rollout_starts else 1
        for _ in range(num_updates):
            if self.per:
                s_obs, s_actions, s_rewards, s_next_obs, s_dones, idxes = self.sample_batch_experiences()
            else:
                s_obs, s_actions, s_rewards, s_next_obs, s_dones = self.sample_batch_experiences()

            if len(self.M) > 1:
                s_obs, s_actions, s_rewards, s_next_obs, s_dones = np.vstack([s_obs]*2), np.vstack([s_actions]*2), np.vstack([s_rewards]*2), np.vstack([s_next_obs]*2), np.vstack([s_dones]*2)
                half = s_obs.shape[0] // 2
                w = np.vstack([weight for _ in range(half)] + random.choices(self.M, k=half))
            else:
                # One copy of `weight` per sample. (ndarray.repeat(n, 1) would
                # raise AxisError for a 1-D weight vector.)
                w = np.tile(weight, (s_obs.shape[0], 1))

            # Perturb the weights with Gaussian noise (std 0.1).
            self.key, w_sample = jax.random.split(self.key)
            w += jax.random.normal(w_sample, w.shape, dtype=jnp.float32) * 0.1

            self.psi_state, loss, td_error, self.key = USFA.update(self.psi, self.psi_state, w, s_obs, s_actions, s_rewards, s_next_obs, s_dones, self.gamma, self.min_priority, self.key)
            critic_losses.append(loss.item())

            if self.per:
                # Priority = largest (over critics) absolute TD error projected
                # on w, restricted to the real samples that carry buffer indices.
                td_error = jax.device_get(td_error)
                td_error = np.abs((td_error[:,: len(idxes)] * w[: len(idxes)]).sum(axis=2))
                per = np.max(td_error, axis=0)
                priority = per.clip(min=self.min_priority)**self.alpha
                self.replay_buffer.update_priorities(idxes, priority)

        if self.tau != 1:
            # Soft (Polyak) update every step with step size tau. Previously a
            # hard copy was performed here, which ignored tau altogether.
            self.psi_state = self.psi_state.replace(
                target_params=optax.incremental_update(self.psi_state.params, self.psi_state.target_params, self.tau)
            )
        elif self.num_timesteps % self.target_net_update_freq == 0:
            # Periodic hard update.
            self.psi_state = USFA.target_net_update(self.psi_state)

        if self.epsilon_decay_steps is not None:
            self.epsilon = linearly_decaying_epsilon(self.initial_epsilon, self.epsilon_decay_steps, self.num_timesteps, self.learning_starts, self.final_epsilon)

        if self.log and self.num_timesteps % 100 == 0:
            if self.per:
                self.writer.add_scalar("metrics/mean_priority", np.mean(priority), self.num_timesteps)
                self.writer.add_scalar("metrics/max_priority", np.max(priority), self.num_timesteps)
                self.writer.add_scalar("metrics/mean_td_error_w", np.mean(per), self.num_timesteps)
            self.writer.add_scalar("losses/critic_loss", np.mean(critic_losses), self.num_timesteps)
            self.writer.add_scalar("metrics/epsilon", self.epsilon, self.num_timesteps)

    @staticmethod
    @jax.jit
    def target_net_update(psi_state):
        """Return ``psi_state`` with the target parameters replaced by the online ones.

        ``optax.incremental_update`` is called with step size 1, i.e. a hard copy.
        """
        synced_target = optax.incremental_update(psi_state.params, psi_state.target_params, 1)
        return psi_state.replace(target_params=synced_target)

    @staticmethod
    @partial(jax.jit, static_argnames=["psi", "return_policy_index"])
    def gpi_action(psi, psi_state, obs, w, M, key, return_policy_index=False):
        """Select the GPI action for a single observation.

        The observation is evaluated under every weight vector in ``M``; the
        policy whose best action has the highest Q-value w.r.t. ``w`` is chosen
        and its greedy action is returned.

        Returns:
            ``(action, key)`` or ``(action, policy_index, key)`` when
            ``return_policy_index`` is True. ``key`` is returned unchanged.
        """
        policy_weights = jnp.stack(M)

        # One copy of the observation per policy in the GPI set.
        tiled_obs = jnp.repeat(obs.reshape(1, -1), policy_weights.shape[0], axis=0)
        successor_features = psi.apply(psi_state.params, tiled_obs, policy_weights, deterministic=True)

        task_weight = w.reshape(1, 1, 1, w.shape[0])
        # Q-values per critic, then averaged over the critics.
        q_values = (successor_features * task_weight).sum(axis=3).mean(axis=0)

        policy_index = q_values.max(axis=1).argmax()  # max_i max_a q(s,a,w_i)
        action = q_values[policy_index].argmax()

        if not return_policy_index:
            return action, key
        return action, policy_index, key

    @staticmethod
    @partial(jax.jit, static_argnames=["psi", "model", "normalize_inputs", "env_id", "h_step", "gamma"])
    def hstep_gpi_action(psi, psi_state, model, model_state, inputs_mu, inputs_sigma, normalize_inputs, elites, obs, w, M, env_id, h_step, gamma, key):
        """Pick an action by exhaustive ``h_step`` lookahead with the model, bootstrapped with GPI.

        All action sequences of length ``h_step`` are expanded with the mean
        prediction of the ensemble; the GPI Q-value of each leaf is added to
        the accumulated discounted return and the first action of the best
        sequence is returned. ``elites`` and ``env_id`` are accepted but not used.

        Returns:
            ``(best_action, key)``.
        """
        num_actions = model.input_dim - obs.shape[0]
        num_objectives = w.shape[0]
        states = obs.reshape(1, -1)

        returns = jnp.zeros(1)
        for step in range(h_step):
            # Expand every current state with every action.
            one_hot_actions = jnp.repeat(jnp.eye(num_actions), states.shape[0], axis=0)
            states = jnp.tile(states, (num_actions, 1))
            model_inputs = jnp.concatenate([states, one_hot_actions], axis=1)

            prediction = ProbabilisticEnsemble.forward(model, model_state, model_inputs, inputs_mu, inputs_sigma, normalize_inputs=normalize_inputs, deterministic=True, return_dist=False, key=key)
            mean_prediction = prediction.mean(axis=0)
            rewards = mean_prediction[:, :num_objectives]
            obs_deltas = mean_prediction[:, num_objectives:]

            # The model output is added to the current state to obtain the next one.
            next_states = obs_deltas + states
            returns = jnp.tile(returns, (num_actions, 1)).flatten()
            returns = returns + gamma**step * (rewards * w).sum(axis=1)
            states = next_states

        _, q_values, key = USFA.batch_gpi(psi, psi_state, states, w, M, key, return_q_values=True)

        returns = returns + gamma**h_step * q_values
        best_trajectory = returns.argmax(axis=0)
        # The first action of a trajectory is its index modulo the number of actions.
        best_action = best_trajectory % num_actions

        return best_action, key

    @staticmethod
    @partial(jax.jit, static_argnames=["model", "normalize_inputs", "env_id", "h_step", "gamma"])
    def mpc_action(model, model_state, inputs_mu, inputs_sigma, normalize_inputs, elites, obs, w, env_id, h_step, gamma, key):
        """Pick an action by exhaustive ``h_step`` lookahead with the model only.

        All action sequences of length ``h_step`` are expanded with the mean
        prediction of the ensemble and the first action of the sequence with
        the highest discounted scalarised return is returned. ``elites`` and
        ``env_id`` are accepted but not used.

        Returns:
            ``(best_action, key)``; ``key`` is returned unchanged.
        """
        num_actions = model.input_dim - obs.shape[0]
        num_objectives = w.shape[0]
        states = obs.reshape(1, -1)

        returns = jnp.zeros(1)
        for step in range(h_step):
            # Expand every current state with every action.
            one_hot_actions = jnp.repeat(jnp.eye(num_actions), states.shape[0], axis=0)
            states = jnp.tile(states, (num_actions, 1))
            model_inputs = jnp.concatenate([states, one_hot_actions], axis=1)

            prediction = ProbabilisticEnsemble.forward(model, model_state, model_inputs, inputs_mu, inputs_sigma, normalize_inputs=normalize_inputs, deterministic=True, return_dist=False, key=key)
            mean_prediction = prediction.mean(axis=0)
            rewards = mean_prediction[:, :num_objectives]
            obs_deltas = mean_prediction[:, num_objectives:]

            # The model output is added to the current state to obtain the next one.
            next_states = obs_deltas + states
            returns = jnp.tile(returns, (num_actions, 1)).flatten()
            returns = returns + gamma**step * (rewards * w).sum(axis=1)
            states = next_states

        best_trajectory = returns.argmax(axis=0)
        # The first action of a trajectory is its index modulo the number of actions.
        best_action = best_trajectory % num_actions

        return best_action, key
        
    def cgpi_action(self, obs, w):
        """Select an action with constrained GPI.

        The Q-values of ``w`` predicted by the network are clipped between a
        lower bound (the GPI value over the policies in ``self.M``) and an
        upper bound derived from the solution of a linear program that
        expresses ``w`` as a combination of the weight vectors in ``self.M``.

        Args:
            obs: Single observation.
            w: Weight vector of the task to act on.

        Returns:
            Index of the action with the highest clipped Q-value.
        """
        num_policies = len(self.M)
        # The layer is shaped by the size of the GPI set, so it must be rebuilt
        # not only when missing but also when the set has changed size since
        # the layer was created.
        if self.cgpi_layer is None or getattr(self, "_cgpi_num_policies", None) != num_policies:
            w_p = cp.Parameter(self.phi_dim)
            alpha = cp.Variable(num_policies)
            W_ = np.vstack(self.M)
            W = cp.Parameter(W_.shape)
            V = cp.Parameter(num_policies)
            objective = cp.Minimize(alpha @ V)
            constraints = [alpha @ W == w_p] #, alpha >= 0]
            problem = cp.Problem(objective, constraints)
            assert problem.is_dpp()
            self.cgpi_layer = CvxpyLayer(problem, parameters=[w_p, W, V], variables=[alpha])
            self._cgpi_num_policies = num_policies

        # Evaluate the observation under every source policy plus the target task (last row).
        M = jnp.stack(self.M + [w])
        obs_m = obs.reshape(1,-1).repeat(M.shape[0], axis=0)
        psi_values = self.psi.apply(self.psi_state.params, obs_m, M, deterministic=True)
        psi_values = psi_values.mean(axis=0)
        q_values_w = (psi_values * w.reshape(1, 1, w.shape[0])).sum(axis=2)
        q_w, qs = q_values_w[-1], q_values_w[:-1]
        lower_bound = jnp.max(qs, axis=0)

        # Q-values of every policy on its own weight vector.
        q_values = (psi_values * M.reshape(M.shape[0], 1, M.shape[1])).sum(axis=2)
        q_values_source = q_values[:-1]

        w_64 = w.astype(jnp.float64)
        source_weights_64 = M[:-1].astype(jnp.float64)
        q_values_source_64 = q_values_source.astype(jnp.float64)
        # One LP per action. The layer returns a tuple with one entry per
        # declared variable; unpack it and collect the rows in a list
        # (vstack is not given a generator).
        alpha_rows = []
        for a in range(self.action_dim):
            (alpha_a,) = self.cgpi_layer(w_64, source_weights_64, q_values_source_64[:, a])
            alpha_rows.append(alpha_a)
        alphas = jnp.vstack(alpha_rows).T

        c_w = M[:-1] @ jnp.tile(self.min_phi, (self.action_dim, 1)).T
        upper_bound = jnp.maximum(q_values_source * alphas, c_w * alphas).sum(axis=0)

        c_qs = jnp.maximum(q_w, lower_bound)
        c_qs = jnp.minimum(c_qs, upper_bound)

        action = c_qs.argmax()
        return action

    def eval(self, obs: np.ndarray, w: np.ndarray) -> int:
        """Select a greedy action for ``obs`` under the weight vector ``w``.

        With ``use_gpi`` the action selection depends on ``gpi_type``
        (``"cgpi"``, ``"hgpi"``, ``"mpc"`` or ``"gpi"``); otherwise the plain
        greedy action of the psi network is used.

        Args:
            obs: Single observation.
            w: Weight vector of the task.

        Returns:
            The selected action, transferred to the host with ``jax.device_get``.

        Raises:
            ValueError: If ``use_gpi`` is set and ``gpi_type`` is not one of
                the supported values (previously this surfaced as an
                ``UnboundLocalError``).
        """
        if self.use_gpi:
            if self.gpi_type == "cgpi":
                action = self.cgpi_action(obs, w)
            elif self.gpi_type == "hgpi":
                action, self.key = USFA.hstep_gpi_action(
                    self.psi,
                    self.psi_state,
                    self.dynamics.ensemble,
                    self.dynamics.ensemble_state,
                    self.dynamics.inputs_mu,
                    self.dynamics.inputs_sigma,
                    self.dynamics.normalize_inputs,
                    self.dynamics.elites,
                    obs,
                    w,
                    self.M,
                    self.env.spec.id,
                    self.h_step,
                    self.gamma,
                    self.key,
                )
            elif self.gpi_type == "mpc":
                action, self.key = USFA.mpc_action(
                    self.dynamics.ensemble,
                    self.dynamics.ensemble_state,
                    self.dynamics.inputs_mu,
                    self.dynamics.inputs_sigma,
                    self.dynamics.normalize_inputs,
                    self.dynamics.elites,
                    obs,
                    w,
                    self.env.spec.id,
                    self.h_step,
                    self.gamma,
                    self.key,
                )
            elif self.gpi_type == "gpi":
                # Optionally evaluate w itself as an extra policy. A new list is
                # built so self.M is never left modified if the call fails.
                gpi_set = self.M + [w] if self.include_w else self.M
                action, self.key = USFA.gpi_action(self.psi, self.psi_state, obs, w, gpi_set, self.key)
            else:
                raise ValueError(f"Unknown gpi_type: {self.gpi_type!r}")
        else:
            action, self.key = USFA.max_action(self.psi, self.psi_state, obs, w, self.key)

        action = jax.device_get(action)
        return action

    def act(self, obs, w) -> int:
        """Epsilon-greedy action selection used while collecting experience.

        With probability ``self.epsilon`` a random action is sampled from the
        environment's action space; otherwise the GPI action (recording the
        index of the selected policy) or the plain greedy action is returned.
        """
        explore = np.random.random() < self.epsilon
        if explore:
            return self.env.action_space.sample()

        if not self.use_gpi:
            greedy_action, self.key = USFA.max_action(self.psi, self.psi_state, obs, w, self.key)
            return jax.device_get(greedy_action)

        gpi_result = USFA.gpi_action(self.psi, self.psi_state, obs, w, self.M, self.key, return_policy_index=True)
        action, policy_index, self.key = gpi_result
        action = jax.device_get(action)
        policy_index = jax.device_get(policy_index)
        # Remember which policy of the GPI set was followed.
        self.police_indices.append(policy_index)
        return action

    @staticmethod
    @partial(jax.jit, static_argnames=["psi"])
    def max_action(psi, psi_state, obs, w, key) -> int:
        """Greedy action of the psi network for ``obs`` under ``w`` (no GPI).

        Returns:
            ``(action, key)``; ``key`` is returned unchanged.
            NOTE(review): the ``-> int`` annotation does not match this tuple.
        """
        successor_features = psi.apply(psi_state.params, obs, w, deterministic=True)
        per_critic_q = (successor_features * w.reshape(1, w.shape[0])).sum(axis=3)
        # Average over the critics, then drop the singleton batch axis.
        mean_q = per_critic_q.mean(axis=0).squeeze(0)
        action = jax.device_get(mean_q.argmax())
        return action, key

    def set_gpi_set(self, M: List[np.ndarray]):
        """Replace the set of weight vectors used for GPI with a shallow copy of ``M``.

        The cached constrained-GPI layer is dropped as well: it is built for a
        specific number of weight vectors and would not match a set of a
        different size.
        """
        self.M = M.copy()
        self.cgpi_layer = None

    def learn(
        self,
        total_timesteps: int,
        w: np.ndarray,
        M: List[np.ndarray],
        change_w_each_episode: bool = True,
        total_episodes: Optional[int] = None,
        reset_num_timesteps: bool = True,
        eval_env: Optional[gym.Env] = None,
        eval_freq: int = 1000,
        reset_learning_starts: bool = False,
    ):
        """Interact with the environment and train the agent.

        Args:
            total_timesteps: Maximum number of environment steps for this call.
            w: Initial weight vector; also assigned to ``self.env.w``.
            M: Set of weight vectors used for GPI; stored as ``self.M`` (not copied).
            change_w_each_episode: If True, a new ``w`` is drawn at random from
                ``M`` after every episode.
            total_episodes: Optional episode budget; the loop stops once it is reached.
            reset_num_timesteps: If True, the step and episode counters start from 0.
            eval_env: Optional environment for periodic evaluation (only used when logging).
            eval_freq: Evaluate every ``eval_freq`` time steps.
            reset_learning_starts: If True, ``learning_starts`` is set to the
                current step counter.
        """
        self.env.w = w
        self.M = M

        self.police_indices = []
        self.num_timesteps = 0 if reset_num_timesteps else self.num_timesteps
        self.num_episodes = 0 if reset_num_timesteps else self.num_episodes
        if reset_learning_starts:  # Resets epsilon-greedy exploration
            self.learning_starts = self.num_timesteps

        episode_reward = 0.0
        episode_vec_reward = np.zeros(w.shape[0])
        num_episodes = 0  # episodes finished during this call (self.num_episodes may carry over)
        (obs, info), done = self.env.reset(), False
        for _ in range(1, total_timesteps + 1):
            if total_episodes is not None and num_episodes == total_episodes:
                break
            self.num_timesteps += 1

            # Purely random actions until learning starts.
            if self.num_timesteps < self.learning_starts:
                action = self.env.action_space.sample()
            else:
                action = self.act(obs, w)

            next_obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated

            # Only `terminated` is stored as the done flag; truncation ends the
            # episode below but is not recorded in the buffer.
            self.replay_buffer.add(obs, action, info["vector_reward"], next_obs, terminated)

            if self.num_timesteps >= self.learning_starts:
                if self.dyna:
                    if self.num_timesteps % self.dynamics_train_freq(self.num_timesteps) == 0:
                        # Fit the model on (obs, one-hot action) -> (vector reward, obs delta).
                        m_obs, m_actions, m_rewards, m_next_obs, m_dones = self.replay_buffer.get_all_data(max_samples=int(2e5))
                        one_hot = np.zeros((len(m_obs), self.action_dim))
                        one_hot[np.arange(len(m_obs)), m_actions.astype(int).reshape((len(m_obs)))] = 1
                        X = np.hstack((m_obs, one_hot))
                        Y = np.hstack((m_rewards, m_next_obs - m_obs))
                        mean_loss, mean_holdout_loss = self.dynamics.fit(X, Y)
                        if self.log:
                            self.writer.add_scalar("dynamics/mean_loss", mean_loss, self.num_timesteps)
                            self.writer.add_scalar("dynamics/mean_holdout_loss", mean_holdout_loss, self.num_timesteps)

                    # Periodically refill the model buffer with imagined transitions.
                    if self.num_timesteps >= self.dynamics_rollout_starts and self.num_timesteps % self.dynamics_rollout_freq == 0:
                        self.rollout_dynamics(w)

                self.train(w)

            # Periodic evaluation; runs only when logging is enabled.
            if eval_env is not None and self.log and self.num_timesteps % eval_freq == 0:
                total_reward, discounted_return, total_vec_r, total_vec_return = eval_mo(self, eval_env, w)
                self.writer.add_scalar("eval/total_reward", total_reward, self.num_timesteps)
                self.writer.add_scalar("eval/discounted_return", discounted_return, self.num_timesteps)
                for i in range(episode_vec_reward.shape[0]):
                    self.writer.add_scalar(f"eval/total_reward_obj{i}", total_vec_r[i], self.num_timesteps)
                    self.writer.add_scalar(f"eval/return_obj{i}", total_vec_return[i], self.num_timesteps)
                if self.dyna and self.num_timesteps >= self.dynamics_rollout_starts:
                    plot = visualize_eval_jax(self, eval_env, self.dynamics, w, compound=False, horizon=1000)
                    wb.log({"dynamics/predictions": wb.Image(plot), "global_step": self.num_timesteps})
                    plot.close()

            episode_reward += reward
            episode_vec_reward += info["vector_reward"]
            if done:
                # Episode finished: reset the environment and log episode statistics.
                (obs, info), done = self.env.reset(), False
                num_episodes += 1
                self.num_episodes += 1

                if num_episodes % 100 == 0:
                    print(f"Episode: {self.num_episodes} Step: {self.num_timesteps}, Ep. Total Reward: {episode_reward}")
                if self.log:
                    wb.log({"metrics/policy_index": np.array(self.police_indices), "global_step": self.num_timesteps})
                    self.police_indices = []
                    self.writer.add_scalar("metrics/episode", self.num_episodes, self.num_timesteps)
                    self.writer.add_scalar("metrics/episode_reward", episode_reward, self.num_timesteps)
                    for i in range(episode_vec_reward.shape[0]):
                        self.writer.add_scalar(f"metrics/episode_reward_obj{i}", episode_vec_reward[i], self.num_timesteps)

                episode_reward = 0.0
                episode_vec_reward = np.zeros(w.shape[0])

                # Optionally switch to another task from the GPI set for the next episode.
                if change_w_each_episode:
                    w = random.choice(M)
                    self.env.w = w
            else:
                obs = next_obs
