# Step 1: run the method over every hyper-parameter combination.

# %%
import sys

sys.path.append('../')

import scipy.io as sio

import torch
from Modules.ada_graph import *
from Utils.utils import set_seed, try_gpu
from Utils.run_experiments import deploy_args, train
from Utils.data_processor import select_dataset


if __name__ == "__main__":
    # Hyper-parameter sweep driver.
    #
    # For every dataset in `data_list`, a fresh network is trained for each
    # combination of (num_neighbours, selected feature count, learning rate,
    # epsilon).  The top-ranked feature indices of every run are collected
    # and, once per `num_neighbours` value, written to a .mat file together
    # with the train/test splits.

    # Local import: only this script entry point needs it.
    import os

    # Fix the random seed before anything else is created.
    seed_num = 42
    set_seed(seed_num)

    # Devices to deploy the network on (extend gpu_list to use more GPUs).
    gpu_list = [0]
    devices = [try_gpu(i) for i in gpu_list]

    # Create the output directory up front: without it, savemat would raise
    # only after a complete (and expensive) sweep had already finished.
    results_dir = './Results_all_params/'
    os.makedirs(results_dir, exist_ok=True)

    data_list = ['madelon']
    for fname in data_list:
        print('========> '+fname,flush=True)
        fpath = '../Data/'+fname+'.mat'
        # `c` is returned by select_dataset but not used in this script.
        train_data, test_data, n, d, c = select_dataset(fpath)

        # Set learning parameters.
        args = deploy_args(fname)
        args['train_num'] = n
        args['fea_dim'] = d

        # NumPy copies of both splits; they are stored next to the results.
        X_tr, y_tr = train_data
        X_te, y_te = test_data
        X_train = X_tr.to(torch.float32).detach().numpy()
        y_train = y_tr.to(torch.float32).detach().numpy()
        X_test = X_te.to(torch.float32).detach().numpy()
        y_test = y_te.to(torch.float32).detach().numpy()

        # Hyper-parameter grid.
        lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1]
        epsilons = [1e-3, 1e-2, 1e-1]
        if fname == 'madelon':
            feanums = [5, 10, 15, 20]
        else:
            feanums = [25, 50, 75, 100, 150, 200, 300]
        num_neighbours = [2, 3, 4, 5]
        print('feanums:',feanums)
        print('num_neighbour:', num_neighbours)

        for num_neighbour in num_neighbours:
            args['num_neighbours'] = num_neighbour

            # Results are nested as [feanum][lr][epsilon].
            ind_total = []
            # NOTE(review): S_total is accumulated but never written to the
            # .mat file below -- confirm whether it should be saved or dropped.
            S_total = []
            for feanum in feanums:
                args['selected_num'] = feanum
                ind_per_lr = []
                S_per_lr = []
                for lr in lrs:
                    args['lr'] = lr
                    ind_per_eps = []
                    S_per_eps = []
                    for epsilon in epsilons:
                        args['epsilon'] = epsilon
                        # Pretrain parameters.  Re-set before every run, as in
                        # the original flow, in case `train` modifies `args`
                        # (cannot be verified from this file).
                        args['pretrain_flag'] = False
                        args['pretrained_model'] = ""

                        # Define the network.
                        net = Ada_Graph_Fixed_Concrete_new(
                            args['train_num'], args['fea_dim'],
                            args['selected_num'], args['num_neighbours'],
                            args['epsilon'], args['num_iter'],
                            args['manual_flag'], device=devices[0])

                        importance, S = train(net, train_data, devices, args)

                        # Calculate and sort the feature importance, keeping
                        # the indices of the `selected_num` highest scores.
                        score = importance.squeeze()
                        _, selected_ind = torch.sort(score, descending=True)
                        top_ind = selected_ind.cpu().detach().numpy()[:args['selected_num']]

                        ind_per_eps.append(top_ind)
                        S_per_eps.append(S.cpu().detach().numpy())
                    ind_per_lr.append(ind_per_eps)
                    S_per_lr.append(S_per_eps)
                ind_total.append(ind_per_lr)
                S_total.append(S_per_lr)

            # One result file per num_neighbours value.
            out_path = results_dir+'AdaGraph_'+fname+'_num_neighbor'+str(num_neighbour)+'.mat'
            sio.savemat(out_path, {
                'indices': ind_total,
                'X_train': X_train,
                'X_test': X_test,
                'y_train': y_train,
                'y_test': y_test,
            })
        print('Done.',flush=True)