import fire
import gym
import mo_gymnasium as mo_gym
import numpy as np
from jax import config
config.update("jax_enable_x64", True)
import wandb as wb
from gpi.successor_features.usfa_jax import USFA
from gpi.utils.eval import policy_evaluation_mo
from gpi.utils.utils import equally_spaced_weights
import envs.mo_push
from envs.mo_push.mo_push import FetchDiscreteAction, FetchObservationWrapper
from rliable import library as rly
from rliable import metrics
from copy import deepcopy
import itertools


def run(num_seeds: int = 20, reps: int = 1, equally_spaced_tasks: bool = False, weights_path: str = "."):
    """Evaluate saved USFA checkpoints on FetchPush under several GPI variants.

    For every seed in ``1..num_seeds`` the agent weights are loaded from
    ``{weights_path}/weights/usfa-push-{seed}/`` and each variant (GPI-ST,
    GPI-S, CGPI, 1..5-step h-GPI and 5-step MPC) is evaluated on every test
    task. After each seed, the scores accumulated so far are min-max
    normalized per task, aggregated with rliable (median, IQM, mean) and
    logged to wandb, and the raw scores are written to ``./results`` as
    ``.npy`` files (one file per variant, overwritten after every seed).

    Args:
        num_seeds: Number of checkpoints (seeds) to evaluate.
        reps: Forwarded as ``rep`` to ``policy_evaluation_mo``.
        equally_spaced_tasks: If True, use 64 weights produced by
            ``equally_spaced_weights``; otherwise use all 16 sign vectors
            in ``{1, -1}^4``.
        weights_path: Directory that contains the ``weights`` folder.
    """
    import os  # local import: only needed to create the output directory

    def make_env():
        """Create the FetchPush env with the wrappers used for evaluation."""
        env = mo_gym.make("mo-fetch-push-dense-v2")
        env = FetchDiscreteAction(env)
        env = FetchObservationWrapper(env)
        return mo_gym.LinearReward(env)

    eval_env = make_env()

    def model_training_schedule(timestep):
        """Constant schedule; kept as a callable since USFA receives it as one."""
        return 250

    agent = USFA(
        eval_env,
        num_nets=1,
        max_grad_norm=None,
        learning_rate=3e-4,
        gamma=0.9,
        batch_size=128,
        net_arch=[256, 256, 256, 256],
        buffer_size=int(2e5),
        initial_epsilon=1,
        final_epsilon=0.05,
        epsilon_decay_steps=40000,
        learning_starts=100,
        alpha_per=0.6,
        min_priority=0.01,
        per=True,
        drop_rate=0.01,
        layer_norm=True,
        use_gpi=True,
        h_step=1,
        gpi_type='gpi',
        gradient_updates=10,
        target_net_update_freq=200,
        tau=1,
        dyna=True,
        dynamics_normalize_inputs=False,
        dynamics_uncertainty_threshold=1.5,
        dynamics_net_arch=[400, 400, 400, 400],
        dynamics_buffer_size=int(1e5),
        dynamics_rollout_batch_size=25000,
        dynamics_train_freq=model_training_schedule,
        dynamics_rollout_freq=250,
        dynamics_rollout_starts=5000,
        dynamics_rollout_len=1,
        real_ratio=0.5,
        log=True,
        project_name="FetchPush - GPI",
        experiment_name="USFA - evaluation",
    )

    if equally_spaced_tasks:
        test_tasks = equally_spaced_weights(dim=eval_env.reward_dim, num_weights=64)
    else:
        test_tasks = [np.array(x, dtype=np.float32) for x in itertools.product([1, -1], repeat=4)]

    # Maps variant name -> array with one row per evaluated seed and one
    # column per test task.
    scores = {}

    def add_score(scores, score, name):
        """Append ``score`` as a new row of ``scores[name]``."""
        if name not in scores:
            scores[name] = score.reshape(1, -1)
        else:
            scores[name] = np.vstack((scores[name], score))

    def normalize_scores(score_dict):
        """Return a per-task min-max normalized copy of ``score_dict``.

        The min and max are taken per task (column) over all variants and
        seeds. A task on which every entry is identical has zero range;
        dividing by it would yield NaN and poison every aggregate metric,
        so such tasks are mapped to 0 instead.
        """
        all_scores = np.vstack(list(score_dict.values()))
        min_score = np.min(all_scores, axis=0)
        score_range = np.max(all_scores, axis=0) - min_score
        safe_range = np.where(score_range > 0, score_range, 1.0)
        return {name: (values - min_score) / safe_range for name, values in score_dict.items()}

    def evaluation_variants():
        """Return the ordered (name, agent attribute overrides) pairs.

        Each variant only overrides the listed attributes; anything else
        keeps the value left by the previously evaluated variant.
        """
        min_phi = np.array([-0.7693, -0.8736, -0.7693, -0.8736], dtype=np.float32) / (1.0 - agent.gamma)
        variants = [
            ("GPI-ST", {"gpi_type": 'gpi', "include_w": True}),
            ("GPI-S", {"gpi_type": 'gpi', "include_w": False}),
            ("CGPI", {"gpi_type": 'cgpi', "min_phi": min_phi}),
        ]
        variants.extend((f"{h}-GPI", {"gpi_type": 'hgpi', "h_step": h}) for h in range(1, 6))
        variants.append(("5-MPC", {"gpi_type": 'mpc', "h_step": 5}))
        return variants

    def evaluate(overrides):
        """Apply ``overrides`` to the agent and score it on every test task."""
        for attr, value in overrides.items():
            setattr(agent, attr, value)
        return np.array([
            policy_evaluation_mo(agent, eval_env, w, rep=reps, return_scalarized_value=True)
            for w in test_tasks
        ])

    def aggregate_func(x):
        """Stack the three rliable aggregates in the order they are logged."""
        return np.array([
            metrics.aggregate_median(x),
            metrics.aggregate_iqm(x),
            metrics.aggregate_mean(x),
        ])

    metric_names = ("Median", "IQM", "Mean")

    # np.save does not create missing parent directories.
    os.makedirs("./results", exist_ok=True)

    for seed in range(1, num_seeds+1):
        agent.load(f"{weights_path}/weights/usfa-push-{seed}/")

        for name, overrides in evaluation_variants():
            add_score(scores, evaluate(overrides), name)

        # Normalization uses every seed seen so far, so it is recomputed
        # from the raw scores each time; `scores` itself stays unnormalized.
        score_dict = normalize_scores(scores)
        aggregate_scores, _ = rly.get_interval_estimates(score_dict, aggregate_func, reps=50000)
        for k in score_dict:
            # One wb.log call per metric, matching the original logging cadence.
            for metric_name, value in zip(metric_names, aggregate_scores[k]):
                wb.log({f"{k} - {metric_name}": value, "seeds": seed})

        # Saved after every seed so partial results survive an interruption.
        for k, v in scores.items():
            np.save(f"./results/usfa-equally_spaced_tasks={equally_spaced_tasks}-fetchpush-{k}", v)


if __name__ == "__main__":
    # Hand `run` to python-fire, which turns it into the command-line
    # entry point of this script.
    fire.Fire(run)
