from lqr_env import LQREnv, RandController, LQRController
import lqr_env
import bandit_env
import numpy as np
import os
import pickle
from IPython import embed
from collect_data import rollin, rollin_bandit, generate_histories, generate_bandit_histories, generate_bandit_histories_for_arms, generate_topk_bandit_histories
from collect_data import generate_dr_histories, generate_dr_histories_for_goals, generate_dr_stitch_histories_for_goals, generate_dr_permuted_histories_for_indices, rollin_dr





def split_and_tile(items, n_envs, train, train_frac=0.8, seed=0):
    """Deterministically split ``items`` into train/test and tile one split.

    ``items`` is copied and shuffled along its first axis with a fixed-seed
    ``np.random.RandomState``; the caller's array is not modified. The first
    ``int(train_frac * len(items))`` shuffled entries form the train split
    and the remainder form the test split. The selected split is repeated
    whole ``n_envs // len(split)`` times. The result therefore has at most
    ``n_envs`` rows, and fewer when ``n_envs`` is not a multiple of the
    split size.

    Args:
        items: Array-like of tasks (e.g. goal pairs or permutation indices).
        n_envs: Requested number of environments.
        train: If True, tile the train split; otherwise tile the test split.
        train_frac: Fraction of ``items`` assigned to the train split.
        seed: Shuffle seed. It is fixed so that separate train and eval
            invocations agree on which items belong to which split.

    Returns:
        np.ndarray: the chosen split concatenated ``n_envs // len(split)``
        times along axis 0.

    Raises:
        ValueError: if the chosen split is empty, or if ``n_envs`` is smaller
            than the split size (which would otherwise yield an empty
            dataset).
    """
    shuffled = np.array(items, copy=True)
    np.random.RandomState(seed=seed).shuffle(shuffled)
    cut = int(train_frac * len(shuffled))
    subset = shuffled[:cut] if train else shuffled[cut:]
    split_name = 'train' if train else 'test'
    if len(subset) == 0:
        raise ValueError(
            f"{split_name} split is empty ({len(shuffled)} items, train_frac={train_frac})"
        )
    reps = n_envs // len(subset)
    if reps == 0:
        raise ValueError(
            f"n_envs={n_envs} is smaller than the {split_name} split size {len(subset)}"
        )
    return np.concatenate([subset] * reps, axis=0)


if __name__ == '__main__':
    import argparse

    # Generate a trajectory dataset for the environment family chosen with
    # --env (and --mode for bandits), then pickle it under datasets/.
    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=100, help="Envs")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top k subset")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Bandit arm variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--mode", type=int, required=False, default=0,
                        help="Bandit data mode (1 uses generate_bandit_histories_special)")
    parser.add_argument('--rollin', type=str, required=False, default="uniform",
                        help="Roll-in policy passed to the darkroom_heldout generator (e.g. 'uniform' or 'expert')")
    parser.add_argument('--collect_in_train_tasks', default=False, action='store_true', help="Whether to collect eval trajs in train tasks")

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    n_envs = args['envs']
    H = args['H']
    envname = args['env']
    var = args['var']
    cov = args['cov']
    dim = args['dim']
    k = args['k']
    mode = args['mode']
    rollin = args['rollin']
    train = args['collect_in_train_tasks']

    # One if/elif chain, so exactly one generator runs and its result is
    # the one that gets saved.
    if envname == 'bandit' and mode == 1:
        # This generator is not among the module-level imports; import it
        # here so that only mode 1 depends on it.
        # NOTE(review): assumes collect_data defines it — confirm.
        # NOTE(review): this writes to the same filename as mode 0, so the
        # two datasets overwrite each other.
        from collect_data import generate_bandit_histories_special
        trajs = generate_bandit_histories_special(n_envs, 1, 1, H, dim, var=var, cov=cov)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'

    elif envname == 'bandit':
        trajs = generate_bandit_histories(n_envs, 1, 1, H, dim, var=var, cov=cov)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'

    elif envname == 'bandit_thompson':
        trajs = generate_bandit_histories(n_envs, 1, 1, H, dim, var=var, cov=cov, type='bernoulli')
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'

    elif envname == 'bandit_ood':
        # Arms in the upper half [dim // 2, dim), repeated to roughly n_envs.
        envs = list(range(dim // 2, dim)) * (n_envs // (dim // 2))
        trajs = generate_bandit_histories_for_arms(envs, 1, 1, H, dim, var=var, cov=cov)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_cov{cov}.pkl'

    elif envname == 'bandit_topk':
        trajs = generate_topk_bandit_histories(n_envs, 1, 1, H, dim, k=k, var=var)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}_var{var}_k{k}.pkl'

    elif envname == 'darkroom':
        trajs = generate_dr_histories(n_envs, 1, 1, H, dim)
        filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    elif envname == 'darkroom_heldout':
        # Every (j, i) pair with 0 <= j, i < dim as a candidate goal.
        goals = np.array([[(j, i) for i in range(dim)] for j in range(dim)]).reshape(-1, 2)
        goal_subset = split_and_tile(goals, n_envs, train)
        trajs = generate_dr_histories_for_goals(goal_subset, 1, 1, H, dim, rollin=rollin)
        traj_str = 'expert_trajs' if rollin == 'expert' else 'trajs'
        traj_str += '_train' if train else '_eval'
        filepath = f'datasets/{traj_str}_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    elif envname == 'darkroom_stitch':
        goals = [np.array([dim // 2, dim - 1]), np.array([dim - 1, dim // 2])]
        test_goals = np.repeat(goals, n_envs // len(goals), axis=0)
        trajs = generate_dr_stitch_histories_for_goals(test_goals, 1, 1, H, dim, eval=not train)
        if train:
            filepath = f'datasets/trajs_train_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'
        else:
            filepath = f'datasets/trajs_eval_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    elif envname == 'darkroom_permuted':
        indices = np.arange(120)    # 5! permutations in darkroom
        # Train uses the first 80% of the shuffled indices, eval the rest.
        index_subset = split_and_tile(indices, n_envs, train)
        trajs = generate_dr_permuted_histories_for_indices(index_subset, 1, 1, H, dim)
        split_tag = 'train' if train else 'eval'
        filepath = f'datasets/trajs_{split_tag}_{envname}_envs{n_envs}_H{H}_d{dim}.pkl'

    else:
        raise NotImplementedError(f"Unknown environment: {envname}")

    os.makedirs('datasets', exist_ok=True)
    with open(filepath, 'wb') as file:
        pickle.dump(trajs, file)

    print(f"Saved to {filepath}.")
