import pandas as pd
import numpy as np
import time
import os

from model_mis.conAR_mis import data_model as conAR
from model_mis.resgp_mis import data_model_resgp as resgp
from model_mis.ar_mis import ar as ar
 
# ---- Experiment settings ----
# Datasets to run. Alternative dataset names listed in the original comment:
# 'Poisson_mfGent_v5', 'Heat_mfGent_v5', 'Burget_mfGent_v5_15',
# 'TopOP_mfGent_v6', 'plasmonic2_MF'
data_name_list = ['plasmonic2_MF']

# Model name -> callable invoked once per (model, seed, initial sample count) run.
model_list = dict(conAR=conAR, resgp=resgp, ar=ar)
# Names of the models to run, in execution order (keys into model_list).
model_name = ['conAR', 'ar', 'resgp']

# Random seeds 0..4; each seed drives its own masks and its own result CSV.
seed = list(range(5))
# Number of fidelity levels per dataset.
fidelity_num = 5
# Per-fidelity decay factor of the kept-sample count:
# fidelity f keeps int(initial_count * mis_rate ** f) training samples.
mis_rate = 0.75

def _build_mask_matrix(train_sample_num, initial_fid_sample_num, fidelity_num, mis_rate, rng_seed):
    """Build one 0/1 training-sample mask per fidelity level.

    Fidelity ``fid`` keeps ``int(initial_fid_sample_num * mis_rate ** fid)`` of
    the ``train_sample_num`` training points (1 = kept, 0 = missing). The kept
    positions are chosen by shuffling with the global NumPy RNG seeded at
    ``rng_seed * fidelity_num + fid``, so masks are reproducible per seed.

    Returns:
        A list of ``fidelity_num`` float arrays, each of length ``train_sample_num``.
    """
    mask_matrix = []
    for fid in range(fidelity_num):
        mask = np.zeros(train_sample_num)
        ones_num = int(initial_fid_sample_num * pow(mis_rate, fid))
        mask[:ones_num] = 1
        # Deliberately reseeds the *global* RNG (rather than using a local
        # Generator) so the masks are identical to earlier runs of this script.
        np.random.seed(rng_seed * fidelity_num + fid)
        np.random.shuffle(mask)
        mask_matrix.append(mask)
    return mask_matrix


# Initial (fidelity-0) training-set sizes swept for every model and seed.
initial_fid_sample_nums = [32, 64, 96, 128]
# Column order of the per-seed result CSVs (also the metric keys recorded).
csv_columns = ['train_sample_num', 'rmse', 'nrmse', 'r2', 'nll', 'time']

for data_name in data_name_list:
    train_sample_num = 128
    for k in seed:
        # Per-model metric history for this seed. Note: 'train_sample_num'
        # records the initial fidelity-0 sample count of each run.
        recording = {
            name: {key: [] for key in csv_columns}
            for name in model_name
        }

        for _name in model_name:
            model = model_list[_name]
            for initial_fid_sample_num in initial_fid_sample_nums:
                # Masks are rebuilt immediately before every model call. This
                # also reseeds the global NumPy RNG, so each model call starts
                # from the same RNG state for a given (seed, initial count).
                mask_matrix = _build_mask_matrix(
                    train_sample_num, initial_fid_sample_num, fidelity_num, mis_rate, k
                )

                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), so it is the right clock for measuring durations.
                t_start = time.perf_counter()
                mod = model(
                    data_name,
                    mask=mask_matrix,
                    train_begin_index=0,
                    test_begin_index=0,
                    train_samples_num=train_sample_num,
                    test_samples_num=128,
                    fidelity_num=fidelity_num,
                    seed=k,
                    need_inerp=True,  # (sic) must match the model signature; verify before renaming
                )
                elapsed = time.perf_counter() - t_start

                # The model result is indexed like a mapping with these metric keys.
                history = recording[_name]
                history['train_sample_num'].append(initial_fid_sample_num)
                for metric in ('rmse', 'r2', 'nll', 'nrmse'):
                    history[metric].append(mod[metric])
                history['time'].append(elapsed)

            # One CSV per (model, dataset, missing rate, seed).
            path_csv = os.path.join('exp', str(_name), data_name, 'mis_' + str(mis_rate))
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(path_csv, exist_ok=True)

            df = pd.DataFrame({col: recording[_name][col] for col in csv_columns})
            df.to_csv(os.path.join(path_csv, 'result_' + str(k) + '.csv'), index=False)