# The 3rd step: using the best parameters, randomly split the dataset several times and record the selected feature indices for downstream tasks.
# Run after Determine_param_with_ACC.py
# %%
import os
import sys

import traceback
sys.path.append('../')

import scipy.io as sio

import torch
from Modules.ada_graph import *
from Utils.utils import try_gpu
from Utils.run_experiments import deploy_args, train
from Utils.data_processor import select_dataset


if __name__ == "__main__":
    # Third pipeline step: retrain the model with the previously selected
    # hyper-parameters on repeated loads of each dataset, and store the
    # top-ranked feature indices together with the second output of train()
    # for downstream evaluation.

    # deploy network on GPUs
    gpu_list = [0]
    devices = [try_gpu(i) for i in gpu_list]

    num_repeats = 10
    result_dir = './Results_opt_para'
    # savemat does not create missing directories; create it up front so a
    # long training run is not lost when the first save is attempted.
    os.makedirs(result_dir, exist_ok=True)

    data_list = ['madelon']
    for fname in data_list:
        print('========> '+fname, flush=True)
        fpath = '../Data/'+fname+'.mat'
        result_path = result_dir+'/AdaGraph_'+fname+'.mat'

        # set learning parameters
        args = deploy_args(fname)

        # numbers of features to select; madelon uses a smaller grid
        if fname == 'madelon':
            feanums = [5, 10, 15, 20]
        else:
            feanums = [25, 50, 75, 100, 150, 200, 300]

        # load optimal parameters: row i of best_params holds
        # (num_neighbors, lr, epsilon) for feanums[i]
        params_data = sio.loadmat('./Selected_optimal_params/Best_para_RF_Our_'+fname+'_all.mat')
        best_params = params_data['best_params']
        if len(best_params) < len(feanums):
            # Fail before any training instead of with an IndexError mid-run.
            raise ValueError(
                'best_params has %d rows but %d feature counts are requested for %s'
                % (len(best_params), len(feanums), fname)
            )

        ind_total = []
        S_total = []
        for run_idx in range(num_repeats):
            print('=====> '+str(run_idx), flush=True)

            # Reload the dataset on every repetition.
            # NOTE(review): the file header says each repetition uses a new
            # random split; presumably select_dataset performs it -- confirm.
            train_data, _test_data, n, d, _c = select_dataset(fpath)
            args['train_num'] = n
            args['fea_dim'] = d

            ind_run = []
            S_run = []
            for i, feanum in enumerate(feanums):
                args['selected_num'] = feanum
                num_neighbors, args['lr'], args['epsilon'] = best_params[i]
                # The parameter is stored as a float in the .mat file; force a
                # built-in int regardless of what round() returns for the
                # NumPy scalar type.
                args['num_neighbors'] = int(round(num_neighbors))

                # define network
                net = Ada_Graph_Fixed_Concrete_new(
                    args['train_num'], args['fea_dim'], args['selected_num'],
                    args['num_neighbors'], args['epsilon'], args['num_iter'],
                    args['manual_flag'], device=devices[0])

                importance, S = train(net, train_data, devices, args)

                # calculate and sort the feature importance, keeping the
                # indices of the top `selected_num` features
                score = importance.squeeze()
                _, selected_ind = torch.sort(score, descending=True)
                res = selected_ind.cpu().detach().numpy()[:args['selected_num']]

                ind_run.append(res)
                S_run.append(S.cpu().detach().numpy())
            ind_total.append(ind_run)
            S_total.append(S_run)
            # Rewrite the result file after every repetition so partial
            # results survive an interrupted run.
            sio.savemat(result_path, {'S': S_total, 'indices': ind_total})
        print('Done.', flush=True)