import jax
import jax.numpy as jnp
import time
from ppo_s5 import make_train as make_train_s5
from ppo_gru import make_train as make_train_gru
from ppo import make_train as make_train_mlp
from wrappers import AliasPrevAction
from popgym_cartpole import popgym_cartpole_params, get_env as get_env_cartpole
from popgym_pendulum import popgym_pendulum_params, get_env as get_env_pendulum
from popgym_repeat import popgym_repeat_params, get_env as get_env_repeat
import argparse



def _compile_and_run(label, train_fn, rngs):
    """Ahead-of-time compile ``train_fn`` for ``rngs``, execute it, and time both phases.

    Args:
        label: Name used in the progress prints (e.g. ``"og s5"``).
        train_fn: A ``jax.jit``-wrapped function that takes the batch of PRNG keys.
        rngs: Batch of PRNG keys passed to ``train_fn``.

    Returns:
        ``(compile_time, run_time, out)``: wall-clock seconds spent compiling,
        wall-clock seconds spent executing, and the compiled function's output.
    """
    t0 = time.perf_counter()
    compiled = train_fn.lower(rngs).compile()
    compile_time = time.perf_counter() - t0
    print(f"{label} compile time: {compile_time}")

    t0 = time.perf_counter()
    # JAX dispatches asynchronously; block so the timing covers the real work.
    out = jax.block_until_ready(compiled(rngs))
    run_time = time.perf_counter() - t0
    print(f"{label} time: {run_time}")
    return compile_time, run_time, out


def run(num_runs, env, difficulty, arch="all"):
    """Train PPO with one or all policy architectures and save timings/outputs.

    Training is vmapped over ``num_runs`` PRNG keys (derived from seed 42) and
    compiled ahead of time, so compile time and run time are reported separately.

    Args:
        num_runs: Number of independent runs (PRNG keys) to vmap over.
        env: Environment name: ``"cartpole"``, ``"pendulum"`` or ``"repeat"``.
        difficulty: Difficulty string forwarded to the environment's ``get_env``.
        arch: ``"all"``, or one of ``"s5"``, ``"og_s5"`` (S5 trained with
            ``NO_RESET=True`` in its config), ``"gru"``, ``"mlp"``.

    Results are saved to ``f"{env}_{difficulty}_{arch}.npy"`` as a dict keyed
    by architecture. Because it is a dict, NumPy stores it as a pickled object
    array, so reading it back needs ``allow_pickle=True``.

    Raises:
        ValueError: If ``arch`` is not recognised.
        NotImplementedError: If ``env`` is not recognised.
    """
    # Table of architectures, in execution order:
    # (info_dict key, print label, timing-key suffix, make_train factory, config overrides)
    arch_specs = (
        ("s5", "s5", "s5", make_train_s5, {}),
        ("og_s5", "og s5", "s5", make_train_s5, {"NO_RESET": True}),
        ("gru", "gru", "gru", make_train_gru, {}),
        ("mlp", "mlp", "mlp", make_train_mlp, {}),
    )
    known_archs = [spec[0] for spec in arch_specs]
    # Fail fast instead of silently training nothing and saving an empty dict.
    if arch != "all" and arch not in known_archs:
        raise ValueError(f"Unknown arch {arch!r}; expected 'all' or one of {known_archs}")

    print("*"*10)
    print(f"Running {num_runs} runs of {env} with difficulty {difficulty}")
    env_name = env

    env_factories = {
        "cartpole": get_env_cartpole,
        "pendulum": get_env_pendulum,
        "repeat": get_env_repeat,
    }
    if env_name not in env_factories:
        raise NotImplementedError(
            f"Unknown env {env_name!r}; expected one of {sorted(env_factories)}"
        )
    env_obj, env_params = env_factories[env_name](difficulty)

    config = {
        "LR": 5e-5,
        "NUM_ENVS": 64,
        "NUM_STEPS": 1024,
        "TOTAL_TIMESTEPS": 15e6,
        "UPDATE_EPOCHS": 30,
        "NUM_MINIBATCHES": 8,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 1.0,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.0,
        "VF_COEF": 1.0,
        "MAX_GRAD_NORM": 0.5,
        "ENV": AliasPrevAction(env_obj),
        "ENV_PARAMS": env_params,
        "ANNEAL_LR": False,
        "DEBUG": False,
    }

    rng = jax.random.PRNGKey(42)
    rngs = jax.random.split(rng, num_runs)
    info_dict = {}

    for key, label, suffix, make_train, overrides in arch_specs:
        if arch not in ("all", key):
            continue
        # Each architecture gets its own shallow config copy so overrides such as
        # NO_RESET (and anything make_train may write into its config) cannot
        # leak into the other architectures, which are traced lazily by jax.jit.
        train_vjit = jax.jit(jax.vmap(make_train({**config, **overrides})))
        compile_time, run_time, out = _compile_and_run(label, train_vjit, rngs)
        info_dict[key] = {
            f"compile_{suffix}_time": compile_time,
            f"run_{suffix}_time": run_time,
            # Only the second element of the training output is kept
            # (presumably the metrics — verify against make_train).
            "out": out[1],
        }

    jnp.save(f"{env_name}_{difficulty}_{arch}.npy", info_dict)

def build_parser():
    """Build the command-line parser for this benchmark script.

    Returns:
        An ``argparse.ArgumentParser`` with ``--num-runs`` (required int),
        ``--env``, ``--difficulty`` and ``--arch`` options.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark PPO with S5/GRU/MLP policies on POPGym environments."
    )
    parser.add_argument("--num-runs", type=int, required=True)
    # Choices mirror the environments and architectures that run() dispatches on.
    parser.add_argument("--env", type=str, default="repeat",
                        choices=["cartpole", "pendulum", "repeat"])
    parser.add_argument("--difficulty", type=str, default="previous_hard")
    parser.add_argument("--arch", type=str, default="all",
                        choices=["all", "s5", "og_s5", "gru", "mlp"])
    return parser


if __name__ == "__main__":
    # Parse only when executed as a script, so importing this module has no
    # side effects on sys.argv handling.
    args = build_parser().parse_args()
    run(args.num_runs, args.env, args.difficulty, args.arch)