import torch
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.conAR_dec_beta import conAR_dec
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix

# Report the torch version in use. This work was developed with torch 1.11.0;
# lower versions may not work.
print(torch.__version__)
import math
import os
# Allow duplicate copies of the Intel OpenMP runtime to coexist. Presumably this
# works around the macOS "OMP: Error #15" abort seen when two libraries each load
# their own libiomp. NOTE(review): the flag is set after `import torch`, so it
# only helps if the duplicate runtime is loaded later -- confirm on macOS.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Module-level numerical constants (not used by data_model below).
JITTER = 1e-6
EPS = 1e-10
# Full-precision pi; the previous literal 3.1415 was off by roughly 9e-5.
PI = math.pi

print('testing')


# train_samples_num: number of training samples at the first fidelity level.
# dec_rate: factor by which the training-sample count shrinks at each successive fidelity level.
# fidelity_num: number of fidelity levels used for training and testing.
def data_model(data_name, 
               train_begin_index = 0, 
               test_begin_index = 0,
               train_samples_num = 16,
               test_samples_num = 128,
               dec_rate = 0.5,
               fidelity_num = 5,
               seed = 1,
               need_inerp = True):
    """Train a conAR_dec multi-fidelity model on ``data_name`` and score it.

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        train_begin_index: forwarded to ``conAR_dec`` (presumably the offset
            of the first training sample -- confirm in conAR_dec).
        test_begin_index: index of the first test sample to evaluate.
        train_samples_num: training-sample count for fidelity level 0; level
            ``i`` uses ``int(train_samples_num * dec_rate ** i)`` samples.
        test_samples_num: number of test samples evaluated, starting at
            ``test_begin_index``.
        dec_rate: per-level decay factor applied to the training-sample count.
        fidelity_num: number of fidelity levels used for training and testing.
        seed: random seed forwarded to ``data_preparation`` and ``conAR_dec``.
        need_inerp: currently unused; kept for interface compatibility.

    Returns:
        The value returned by ``calculate_metrix`` when comparing the test
        targets of fidelity index ``fidelity_num - 1`` with the model's
        predicted mean and variance.
    """
    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    # Per-fidelity training-set sizes: level i gets train_samples_num * dec_rate**i
    # samples, truncated toward zero by int().
    # NOTE(review): for large fidelity_num or small dec_rate the later levels can
    # reach 0 samples; confirm conAR_dec tolerates that.
    train_num = [int(train_samples_num * pow(dec_rate, i)) for i in range(fidelity_num)]

    # Evaluate the window [test_begin_index, test_begin_index + test_samples_num).
    # The caller's test_begin_index is honoured (it was previously forced to 0).
    test_end_index = test_begin_index + test_samples_num
    # Test inputs come from entry 0 of xte; presumably every fidelity level
    # shares the same test inputs -- TODO confirm against data_preparation.
    xte = xte[0][test_begin_index:test_end_index]

    # Train the model, forwarding the caller's train_begin_index rather than a
    # hard-coded 0.
    m_fid = conAR_dec(xtr, ytr, xte,
                      train_begin_index = train_begin_index, 
                      train_num = train_num, 
                      fidelity_num = fidelity_num,
                      niteration = 200,
                      learning_rate = 0.02,
                      seed = seed,
                      normal_y_mode = 0)
    yte_mean, yte_var = m_fid.train_mod()

    # Compare against the targets of the last fidelity index used, over the
    # same test window as the inputs.
    yte_test = yte[fidelity_num - 1][test_begin_index:test_end_index]

    cp_metrics = calculate_metrix(y_test = yte_test, y_mean_pre = yte_mean, y_var_pre = yte_var)
    print("loss of our model:", cp_metrics)

    return cp_metrics
    



