import fire
import gym
import mo_gymnasium as mo_gym
import numpy as np

import wandb as wb
from gpi.successor_features.gpi import GPI
from gpi.successor_features.tabular_sf import SF
from gpi.utils.wrappers import RandomAction


def best_vector(values, w):
    """Return the element of ``values`` with the largest product ``v @ w``.

    Args:
        values: Non-empty indexable, iterable collection of vectors (for
            example a list of 1-D arrays or the rows of a 2-D array).
        w: Weight vector that each candidate is multiplied with via ``@``.

    Returns:
        The first element of ``values`` that attains the maximum ``v @ w``;
        on ties the earliest element wins.

    Raises:
        IndexError: If ``values`` is empty.
    """
    best = values[0]
    # Cache the incumbent's score so it is not recomputed on every comparison.
    best_score = best @ w
    for candidate in values:
        score = candidate @ w
        # Strict comparison keeps the earliest maximiser when scores tie.
        if score > best_score:
            best, best_score = candidate, score
    return best


def run(seed: int, alpha: float = 0.1, timesteps_per_iter: int = 1000000, stochastic: bool = True):
    """Train a GPI agent on ``four-room-v0``, adding one policy per unit weight vector.

    Each iteration builds a one-hot weight vector ``w`` over the reward
    dimensions, calls ``gpi_agent.learn`` with ``new_policy=True`` for
    ``timesteps_per_iter`` steps, and saves the agent under ``weights/``.

    Args:
        seed: Used only to label the experiment name and the checkpoint
            path; this function does not seed any random number generator.
        alpha: Passed to every ``SF`` agent as ``alpha``.
        timesteps_per_iter: ``total_timesteps`` for each ``learn`` call; half
            of it is passed to ``SF`` as ``epsilon_decay_steps``.
        stochastic: If true, both the training and the evaluation environment
            are wrapped in ``RandomAction``.
    """

    def make_env():
        # Training and evaluation environments are built the same way.
        base_env = mo_gym.make("four-room-v0")
        if stochastic:
            base_env = RandomAction(base_env)
        return mo_gym.LinearReward(base_env)

    env = make_env()
    eval_env = make_env()

    max_iter = 4
    # A one-hot weight only exists for indices below env.reward_dim, so cap
    # the number of iterations instead of failing with an IndexError after
    # earlier iterations have already been trained.
    # NOTE(review): confirm whether four-room-v0 really exposes >= 4 objectives.
    num_policies = min(max_iter, env.reward_dim)
    if num_policies < max_iter:
        print(f"Training {num_policies} policies instead of {max_iter}: reward_dim={env.reward_dim}")

    def agent_constructor():
        return SF(
            env,
            alpha=alpha,
            gamma=0.95,
            initial_epsilon=1.0,
            final_epsilon=0.05,
            epsilon_decay_steps=timesteps_per_iter // 2,
            dyna=True,
            per=True,
            use_gpi=True,
            dyna_updates=5,
            dyna_deterministic_dynamics=False,
            min_priority=0.0001,
            alpha_per=0.6,
            buffer_size=500000,
            log=False,
        )

    gpi_agent = GPI(
        env,
        agent_constructor,
        log=True,
        project_name="h-GPI",
        experiment_name=f"gpi N={timesteps_per_iter} stochastic={stochastic} seed={seed}",
    )

    try:
        for iteration in range(1, num_policies + 1):
            # One-hot weight vector selecting the (iteration - 1)-th reward component.
            w = np.zeros(env.reward_dim)
            w[iteration - 1] = 1.0

            print('Next weight vector:', w)

            gpi_agent.learn(
                total_timesteps=timesteps_per_iter,
                use_gpi=True,
                w=w,
                eval_env=eval_env,
                eval_freq=1000,
                reset_num_timesteps=False,
                new_policy=True,
                reuse_value_ind=None,
                reset_learning_starts=True,
            )

            gpi_agent.save(f"weights/gpi_fourroom_stochastic={stochastic}_{seed}_iter={iteration}")
    finally:
        # Close the logging run even when training or saving raises.
        gpi_agent.close_wandb()


# Script entry point: hand `run` to python-fire so it can be invoked from the command line.
if __name__ == "__main__":
    fire.Fire(run)
