# Step 4: use the selected features in downstream tasks to evaluate how well they perform.
# Run after Random_run_with_best_params.py

import scipy.io as sio
import numpy as np

import sys
sys.path.append('../')
from Utils.model_evaluation import run_baseline,check_best
from Utils.data_processor import select_dataset
import torch


if __name__ == "__main__":
    # Evaluate the previously selected feature subsets on downstream tasks.
    #
    # For every (task, baseline model) pair and every dataset, the script
    # loads the stored feature indices, runs `n_trials` evaluation trials for
    # each feature budget in `feanums`, and saves the per-trial scores together
    # with their mean and standard deviation to a .mat file.

    # Local import: only this entry point needs it (to create the output folder).
    import os

    device = torch.device('cpu')
    n_trials = 10  # number of repeated evaluation trials per dataset
    results_dir = './Results_downstream_tasks/'

    # Tasks and baseline models are paired positionally by zip() below.
    # To run a subset, trim both lists in lockstep
    # (e.g. tasks = ['Reconstruction'] with baseline_models = ['NN']).
    tasks = ['classification', 'clustering', 'Reconstruction']
    baseline_models = ['RF', 'KMEANS', 'NN']
    datasets = ['madelon']

    # savemat does not create missing directories; make sure the target exists
    # up front so a long run does not fail only when it tries to save.
    os.makedirs(results_dir, exist_ok=True)

    for task, baseline_model in zip(tasks, baseline_models):
        print('===========>'+task+':', flush=True)
        # These tasks are scored against the labels; any other task is scored
        # against the full (unselected) feature matrices instead.
        uses_labels = task in ['clustering', 'classification']

        for fname in datasets:
            print('========> '+fname, flush=True)
            fpath = '../Data/'+fname+'.mat'

            # Feature budgets that were evaluated for this dataset.
            if fname == 'madelon':
                feanums = [5, 10, 15, 20]
            else:
                feanums = [25, 50, 75, 100, 150, 200, 300]
            print('feanums:', feanums, flush=True)

            # Load the feature indices stored by the previous pipeline step.
            # NOTE(review): assumed layout is indices[trial][budget position];
            # confirm against the script that writes this file.
            data = sio.loadmat('./Results_opt_para/AdaGraph_'+fname+'.mat')
            indices = data['indices']

            # iter_res[i] collects one score per trial for feanums[i].
            iter_res = [[] for _ in feanums]
            for trial in range(n_trials):
                # The dataset is reloaded on every trial, as in the original
                # flow (select_dataset may produce a different split each time).
                train_data, test_data, _, _, _ = select_dataset(fpath)
                X_te, y_te = test_data
                X_tr, y_tr = train_data

                # Convert the torch tensors to float32 NumPy arrays.
                X_train_ori = X_tr.to(torch.float32).detach().numpy()
                X_test_ori = X_te.to(torch.float32).detach().numpy()
                y_train = y_tr.to(torch.float32).detach().numpy()
                y_test = y_te.to(torch.float32).detach().numpy()

                for i in range(len(feanums)):
                    # Keep only the columns selected for this trial and budget.
                    selected_ind = indices[trial][i].squeeze()
                    X_train = X_train_ori[:, selected_ind]
                    X_test = X_test_ori[:, selected_ind]

                    if uses_labels:
                        res = run_baseline(X_train, X_test, y_train, y_test, baseline_model, trial)
                    else:
                        # Targets are the original, unselected feature matrices.
                        res = run_baseline(X_train, X_test, X_train_ori, X_test_ori, baseline_model, trial, device = device)
                    iter_res[i].append(res)

            # Aggregate across trials, one value per feature budget.
            total_res_mean = [np.mean(scores) for scores in iter_res]
            total_res_std = [np.std(scores) for scores in iter_res]
            sio.savemat(results_dir+baseline_model+'_'+fname+'_Our.mat',{'iter_res':iter_res,'total_res_mean':total_res_mean,'total_res_std':total_res_std})

