from numpy import int32, asarray
import numpy as np
from omegaconf import OmegaConf
from os.path import join, dirname
import os, shutil

from math import ceil
from config_structs import Config, DataParams, ModelParams, TrainingParams, Setting, TaskConfig, TaskListConfig

# Relative location of the folder that receives the generated experiment configs.
CONFIG_DIR = '../conf/experiment'

# Per-run output directory template; ``{id}`` is replaced by the run's index.
BASE_DIR = '/tmp/{id}'

# File-name template for each generated config; ``{id}`` is the run's index.
CONFIG_NAME = 'sweep_width_{id}.yaml'



def clear_folder(folder):
    """Delete every entry inside ``folder``, leaving the folder itself in place.

    Regular files and symlinks are unlinked (a symlink to a directory is
    removed, not followed); real subdirectories are removed recursively.
    Deletion is best-effort: an entry that cannot be removed is reported and
    skipped so the remaining entries are still cleared.

    Args:
        folder: Path of the directory to empty. If it does not exist there is
            nothing to clear and the function returns without error.
    """
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Nothing to clear (e.g. first run before the folder was created).
        return

    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            # Check islink alongside isfile so that a symlink pointing at a
            # directory is unlinked rather than having its target rmtree'd.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))

if __name__ == '__main__':
    # Resolve the output folder against this script's location (not the current
    # working directory) so the folder we clear is the same one we write into.
    curr_dir = dirname(os.path.abspath(__file__))
    config_save_folder = join(curr_dir, CONFIG_DIR)
    os.makedirs(config_save_folder, exist_ok=True)
    clear_folder(config_save_folder)

    widths = [32, 64, 128, 256, 512]
    # Previously used, larger sweep:
    # width_es_map =  {512: 1, 256: 2, 128: 6, 64: 12, 32: 36, 16: 64, 8: 64, 4: 128, 2: 128}
    # ensemble_size used for every task of a given width.
    width_es_map = {512: 1, 256: 2, 128: 2, 64: 4, 32: 4}
    # Fixed seeds for the ensembled runs. One task is created per seed, so each
    # width gets len(seeds) * width_es_map[width] ensemble members in total
    # (4 for every width with the values below).
    width_es_seeds = {512: [262415, 153293, 25454, 423887],
                      256: [558377, 698766],
                      128: [719891, 24732],
                      64: [50657],
                      32: [389667]}

    data_seed = 2423
    dp = DataParams(data_seed=data_seed)

    def make_task_list(width, ensemble_size, seed):
        """Build a single-task TaskListConfig for one width / ensemble size / seed."""
        task = TaskConfig(
            training_params=TrainingParams(microbatch_size=128, use_warmup_cosine_decay=False, eta_0=6e-3),
            model_params=ModelParams(N=width, ensemble_size=ensemble_size),
            seed=seed,
        )
        return TaskListConfig(task_list=[task], data_params=dp)

    # Ensembled runs with the fixed seeds above.
    tlcs = [make_task_list(w, width_es_map[w], seed)
            for w in widths
            for seed in width_es_seeds[w]]

    # One extra non-ensembled run per width. Its seed comes from the unseeded
    # global NumPy RNG, so it differs between invocations; the drawn value is
    # stored in the TaskConfig and therefore in the written config.
    tlcs.extend(make_task_list(w, 1, int(np.random.randint(0, 10**6))) for w in widths)

    configs = [Config(setting=Setting(), hyperparams=tlc_inst, base_dir=BASE_DIR.format(id=idx))
               for idx, tlc_inst in enumerate(tlcs)]
    # Serialize everything before writing any file, so a serialization error
    # does not leave a partially written sweep behind.
    str_configs = ['# @package _global_\n' + OmegaConf.to_yaml(conf) for conf in configs]

    for idx, strc in enumerate(str_configs):
        config_path = join(config_save_folder, CONFIG_NAME.format(id=idx))
        # Mode 'x' refuses to overwrite: the folder was emptied above, so an
        # existing file signals that clearing failed and should fail loudly.
        with open(config_path, mode='x', encoding='utf-8') as fi:
            fi.write(strc)

