import argparse
import glob
import importlib
import os, datetime
import pickle
import sys
import random
import tqdm
import gym
import numpy as np
import torch as th
import yaml
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.env_util import make_vec_env
import utils.import_envs  # noqa: F401 pylint: disable=unused-import
from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams
from utils.exp_manager import ExperimentManager
from utils.utils import StoreDict
from mixing_times_metrics import MixingTimeAgent


def get_args():
    """Parse the command-line options of the script.

    Two families of options are registered, in this order: the generic
    rl-baselines-zoo options (environment, algorithm, model/checkpoint
    selection, rendering, ...) followed by the options specific to the
    mixing-time experiments (seed, task switching, result paths, ...).

    :return: the ``argparse.Namespace`` produced by ``parse_args()``
    """
    cli = argparse.ArgumentParser()
    option = cli.add_argument

    def flag(name, description):
        # Boolean switch: False when absent, True when given on the command line.
        option(name, action="store_true", default=False, help=description)

    # ---- rl-baselines-zoo specific arguments ----
    option("--env", type=str, default="CartPole-v1", help="environment ID")
    option("-f", "--folder", type=str, default="rl-trained-agents", help="Log folder")
    option(
        "--algo",
        type=str,
        default="ppo",
        required=False,
        choices=list(ALGOS.keys()),
        help="RL Algorithm",
    )
    option("-n", "--n-timesteps", type=int, default=1000, help="number of timesteps")
    option(
        "--num-threads",
        type=int,
        default=-1,
        help="Number of threads for PyTorch (-1 to use default)",
    )
    option("--n-envs", type=int, default=1, help="number of environments")
    option(
        "--exp-id",
        type=int,
        default=0,
        help="Experiment ID (default: 0: latest, -1: no exp folder)",
    )
    option("--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)")
    flag("--no-render", "Do not render the environment (useful for tests)")
    flag("--deterministic", "Use deterministic actions")
    flag("--load-best", "Load best model instead of last model if available")
    option(
        "--load-checkpoint",
        type=int,
        help="Load checkpoint instead of last model if available, "
        "you must pass the number of timesteps corresponding to it",
    )
    flag("--load-last-checkpoint", "Load last checkpoint instead of last model if available")
    flag("--stochastic", "Use stochastic actions")
    flag("--norm-reward", "Normalize reward if applicable (trained with VecNormalize)")
    flag("--start-from-latest-checkpoint", "Start from the latest checkpoint")
    option("--reward-log", type=str, default="", help="Where to log reward")
    option(
        "--gym-packages",
        type=str,
        nargs="+",
        default=[],
        help="Additional external Gym environment package modules to import (e.g. gym_minigrid)",
    )
    option(
        "--env-kwargs",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional keyword argument to pass to the env constructor",
    )

    # ---- Mixing time specific arguments ----
    option("--seed", type=int, default=0, help="Random generator seed")
    option("--tau", type=int, default=0, help="Task switch time")
    option("--x", type=int, default=3, help="Task switch exponent")
    option("--example", type=str, default="example_3", help="Which example?")
    option("--n-tasks", type=int, default=7, help="How many tasks")
    option("--env-transition-type", type=str, default="random", help="Env transition type")
    option("--max-start-states", type=int, default=1_000, help="Number of start states")
    option("--save-path", type=str, default="/resultsv2/", help="Path to the results log dir")
    option(
        "--asymptotic-steps",
        type=int,
        default=1_000_000,
        help="Number of asymptotic states",
    )
    # NOTE(review): this help text is identical to --asymptotic-steps and looks
    # copy-pasted; confirm the intended wording before changing it.
    option("--reporting", type=int, default=10_000, help="Number of asymptotic states")
    option("--frequency", type=int, default=100, help="Frequency of updates")
    flag("--use-uniform-policy", "Use uniform policies")
    flag("--only-accumulate-returns", "Only run accumulate_returns")
    flag("--only-asymptotic-returns", "Only run asymptotic_returns")

    return cli.parse_args()


def _checkpoint_step_count(checkpoint_path: str) -> int:
    """Extract the step count from a ``rl_model_<steps>_steps.zip`` path."""
    # Count from the back so that any other "_" earlier in the path is ignored.
    return int(checkpoint_path.split("_")[-2])


def _find_model_path(args, algo, env_id, log_path):
    """Return the path of the saved model to load for ``env_id``.

    The default is ``<log_path>/<env_id>.zip``. It is overridden, in this
    order, by ``--load-best``, ``--load-checkpoint`` and
    ``--load-last-checkpoint`` (the last one that is set wins).

    :raises ValueError: if ``--load-last-checkpoint`` is set and no checkpoint
        exists, or if the finally selected file does not exist.
    """
    model_path = os.path.join(log_path, f"{env_id}.zip")
    print(model_path)
    found = os.path.isfile(model_path)

    if args.load_best:
        model_path = os.path.join(log_path, "best_model.zip")
        found = os.path.isfile(model_path)

    if args.load_checkpoint is not None:
        model_path = os.path.join(
            log_path, f"rl_model_{args.load_checkpoint}_steps.zip"
        )
        found = os.path.isfile(model_path)

    if args.load_last_checkpoint:
        checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip"))
        if not checkpoints:
            raise ValueError(
                f"No checkpoint found for {algo} on {env_id}, path: {log_path}"
            )
        # Highest step count last.
        model_path = sorted(checkpoints, key=_checkpoint_step_count)[-1]
        found = True

    if not found:
        raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
    return model_path


def _load_env_kwargs(log_path, env_id, cli_env_kwargs):
    """Return env constructor kwargs: saved ``args.yml`` values updated by CLI ones."""
    env_kwargs = {}
    args_path = os.path.join(log_path, env_id, "args.yml")
    if os.path.isfile(args_path):
        with open(args_path, "r") as f:
            # NOTE(security): UnsafeLoader can execute arbitrary code; only
            # point this script at log folders you trust.
            loaded_args = yaml.load(
                f, Loader=yaml.UnsafeLoader
            )  # pytype: disable=module-attr
        # Tolerate args.yml files that have no "env_kwargs" entry at all.
        if loaded_args.get("env_kwargs") is not None:
            env_kwargs = loaded_args["env_kwargs"]
    # Command-line values take precedence over the saved ones.
    if cli_env_kwargs is not None:
        env_kwargs.update(cli_env_kwargs)
    return env_kwargs


def main(args, directory):  # noqa: C901
    """Build one test env (and model) per task and run the mixing-time agent.

    For every selected environment id the function resolves the log folder,
    optionally locates and loads the trained model (skipped when
    ``args.use_uniform_policy`` is set), creates the test environment, and
    finally hands all environments and models to ``MixingTimeAgent``.

    Side effects on ``args``: ``args.exp_id`` is replaced by the resolved run
    id when it was 0, and ``args.n_envs`` is forced to 1 for off-policy
    algorithms.

    :param args: parsed command-line namespace (see ``get_args``)
    :param directory: forwarded to ``MixingTimeAgent`` as ``directory``
    :raises AssertionError: if a resolved log folder does not exist
    :raises ValueError: if no model/checkpoint file can be found
    """
    # Import custom gym packages so that they register in the global registry
    for env_module in args.gym_packages:
        importlib.import_module(env_module)

    set_random_seed(args.seed)
    all_env_ids = [
        "BreakoutNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
    ]
    if args.example == "example_2":
        # NOTE(security): pickle can execute arbitrary code on load; the file
        # must come from a trusted source.
        with open("task_idxs.pkl", "rb") as f:
            indices = pickle.load(f)
        env_ids = [all_env_ids[task_idx] for task_idx in indices[args.seed][args.n_tasks]]
    else:
        env_ids = all_env_ids

    set_random_seed(args.seed)
    # In-place shuffle: it reorders ``env_ids`` too whenever ``env_ids`` is the
    # same list object (i.e. for every example other than "example_2").
    random.shuffle(all_env_ids)

    # ---- loop-invariant setup ----
    algo = args.algo
    folder = args.folder

    # Off-policy algorithm only support one env for now
    off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
    if algo in off_policy_algos:
        args.n_envs = 1

    if args.num_threads > 0:
        if args.verbose > 1:
            print(f"Setting torch.num_threads to {args.num_threads}")
        th.set_num_threads(args.num_threads)

    log_dir = args.reward_log if args.reward_log != "" else None

    load_kwargs = dict(seed=args.seed)
    if algo in off_policy_algos:
        # Dummy buffer size as we don't need memory to enjoy the trained agent
        load_kwargs.update(dict(buffer_size=1))

    # Under python 3.8+ the saved schedules are replaced by constant dummies
    # when loading (kept from the original script).
    custom_objects = {}
    if sys.version_info >= (3, 8):
        custom_objects = {
            "learning_rate": 0.0,
            "lr_schedule": lambda _: 0.0,
            "clip_range": lambda _: 0.0,
        }

    # Remember the user's request *before* the loop: ``args.exp_id`` is
    # overwritten below, and the "latest run" lookup must happen for every
    # environment, not only for the first one.
    use_latest_run = args.exp_id == 0

    envs = []
    models = []
    try:
        for env_id in env_ids:
            if use_latest_run:
                args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
                print(f"Loading latest experiment, id={args.exp_id}")

            # Sanity checks
            if args.exp_id > 0:
                log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}")
            else:
                log_path = os.path.join(folder, algo)

            assert os.path.isdir(log_path), f"The {log_path} folder was not found"

            model_path = None
            if not args.use_uniform_policy:
                model_path = _find_model_path(args, algo, env_id, log_path)
                print(f"Loading {model_path}")

            stats_path = os.path.join(log_path, env_id)
            hyperparams, stats_path = get_saved_hyperparams(
                stats_path, norm_reward=args.norm_reward, test_mode=True
            )

            env_kwargs = _load_env_kwargs(log_path, env_id, args.env_kwargs)

            env = create_test_env(
                env_id,
                n_envs=args.n_envs,
                stats_path=stats_path,
                seed=args.seed,
                log_dir=log_dir,
                should_render=not args.no_render,
                hyperparams=hyperparams,
                env_kwargs=env_kwargs,
            )
            envs.append(env)

            if not args.use_uniform_policy:
                model = ALGOS[algo].load(
                    model_path,
                    env=env,
                    custom_objects=dict(custom_objects),
                    **load_kwargs,
                )
                models.append(model)

        # ``models`` stays empty when --use-uniform-policy is set.
        mixing_time = MixingTimeAgent(
            args,
            envs,
            models,
            is_atari=True,
            tau=args.tau,
            asymptotic_steps=args.asymptotic_steps,
            max_start_states=args.max_start_states,
            directory=directory,
        )
        mixing_time.run()
    finally:
        # Release every environment that was created, even if loading or the
        # run itself failed part-way.
        for created_env in envs:
            created_env.close()


def build_results_directory(args):
    """Return the directory in which the results of this run are stored.

    Layout (below ``args.save_path``):

    * ``example_2/<algo>/n_<n_tasks>/seed_<seed>``
    * ``example_3/<algo>/tau_<tau>/seed_<seed>``

    :param args: namespace providing ``save_path``, ``example``, ``algo``,
        ``seed`` and, depending on the example, ``n_tasks`` or ``tau``
    :raises ValueError: if ``args.example`` is not one of the two examples
        that have a results layout
    """
    if args.example == "example_2":
        variant = "n_{}".format(args.n_tasks)
    elif args.example == "example_3":
        variant = "tau_{}".format(args.tau)
    else:
        raise ValueError(
            "Unsupported --example {!r}: expected 'example_2' or 'example_3'".format(
                args.example
            )
        )
    # os.path.join keeps working when --save-path has no trailing slash.
    return os.path.join(
        args.save_path, args.example, args.algo, variant, "seed_{}".format(args.seed)
    )


if __name__ == "__main__":
    args = get_args()
    results_directory = build_results_directory(args)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(results_directory, exist_ok=True)

    # NOTE: these two assignments deliberately override whatever was passed
    # through --folder / --n-envs on the command line.
    args.folder = "rl-trained-agents/"
    args.n_envs = 1

    logs = "Running {} experiment with Tau = {} and Tasks = {}\n".format(
        args.example, args.tau, args.n_tasks
    )
    # Append mode: one line per launched run (hence the trailing newline).
    with open(os.path.join(results_directory, "log.txt"), "a") as f:
        f.write(logs)
    main(args, results_directory)
