from util import *
import os
import joblib
from method import *
from scorer import *
from compute import *
from scipy.stats import spearmanr
from main_experiment import *
from ACIC_generate import *
import copy


def generate_datasets(data_dir, nonlinear_fun_list, nonlinearity_y_list, rho_list, xi_list, mis_ratio_list, M,
                      test_ratio=0.3, val_ratio=0.3):
    """Run ``acic_generate`` once for every combination of the setting grids.

    The output for each combination is written to
    ``./data/<nonlinear_fun>_<nonlinearity_y>_<rho>_<xi>_<mis_ratio>``.
    Directories are created as needed. Combinations are visited in the same
    order as nested loops, with ``mis_ratio`` varying fastest.

    Args:
        data_dir: Source directory forwarded unchanged to ``acic_generate``.
        nonlinear_fun_list: Values for ``nonlinear_fun``. Each must be a str,
            because it is joined directly into the directory name.
        nonlinearity_y_list: Values for ``nonlinearity_y``.
        rho_list: Values for ``rho``.
        xi_list: Values for ``xi``.
        mis_ratio_list: Values for ``mis_ratio``.
        M: Forwarded to ``acic_generate`` as ``n_exp`` (presumably the number
            of replications per setting; confirm against ``acic_generate``).
        test_ratio: Forwarded to ``acic_generate``. Defaults to 0.3, the
            previously hard-coded value.
        val_ratio: Forwarded to ``acic_generate``. Defaults to 0.3, the
            previously hard-coded value.
    """
    import itertools

    # makedirs(exist_ok=True) avoids the race in the isdir-then-mkdir pattern.
    os.makedirs('./data', exist_ok=True)

    # product() snapshots each input, so generator arguments are not exhausted
    # after the first outer iteration (the nested-loop version had that issue).
    grid = itertools.product(nonlinear_fun_list, nonlinearity_y_list, rho_list, xi_list, mis_ratio_list)
    for nonlinear_fun, nonlinearity_y, rho, xi, mis_ratio in grid:
        setting_name = '_'.join([nonlinear_fun, str(nonlinearity_y), str(rho), str(xi), str(mis_ratio)])
        generation_dir = './data/' + setting_name
        os.makedirs(generation_dir, exist_ok=True)
        acic_generate(data_dir=data_dir, output_dir=generation_dir, xi=xi, rho=rho, n_exp=M,
                      test_ratio=test_ratio, val_ratio=val_ratio,
                      nonlinear_fun=nonlinear_fun, nonlinearity_y=nonlinearity_y, mis_ratio=mis_ratio)

def run_all_exps(M, nonlinear_fun_list, nonlinearity_y_list, rho_list,xi_list, mis_ratio_list,
                 train_base_model_list, train_learner_list, val_base_model_list, val_learner_list, pseudo_estimator_list, other_scorer_list):
    """Call ``run_experiment`` once for every combination of the setting grids.

    For each combination, the input directory is
    ``./data/<nonlinear_fun>_<nonlinearity_y>_<rho>_<xi>_<mis_ratio>``, which
    uses the same naming as ``generate_datasets``. Results go to the matching
    directory under ``./result``, which is created if needed. Combinations are
    visited in nested-loop order, with ``mis_ratio`` varying fastest.

    Args:
        M: Forwarded unchanged as the first argument of ``run_experiment``.
        nonlinear_fun_list: Values for ``nonlinear_fun``. Each must be a str,
            because it is joined directly into the directory name.
        nonlinearity_y_list: Values for ``nonlinearity_y``.
        rho_list: Values for ``rho``.
        xi_list: Values for ``xi``.
        mis_ratio_list: Values for ``mis_ratio``.
        train_base_model_list, train_learner_list, val_base_model_list,
        val_learner_list, pseudo_estimator_list, other_scorer_list:
            Forwarded unchanged to ``run_experiment`` for every setting.
    """
    import itertools

    # makedirs(exist_ok=True) avoids the race in the isdir-then-mkdir pattern.
    os.makedirs('./result', exist_ok=True)

    grid = itertools.product(nonlinear_fun_list, nonlinearity_y_list, rho_list, xi_list, mis_ratio_list)
    for nonlinear_fun, nonlinearity_y, rho, xi, mis_ratio in grid:
        # Build the setting name once so the data and result paths cannot drift apart.
        setting_name = '_'.join([nonlinear_fun, str(nonlinearity_y), str(rho), str(xi), str(mis_ratio)])
        data_dir = './data/' + setting_name
        target_result_dir = './result/' + setting_name
        os.makedirs(target_result_dir, exist_ok=True)
        run_experiment(M, data_dir, target_result_dir, train_base_model_list,
                       train_learner_list, val_base_model_list, val_learner_list, pseudo_estimator_list, other_scorer_list)
if __name__ == '__main__':
    # ---- Estimators handed to run_experiment --------------------------------
    # Training side: three base models crossed with eight meta-learners.
    train_base_model_list = ['lr', 'rf', 'svm']
    train_learner_list = ['S', 'PS', 'T', 'X', 'IPW', 'DR', 'R', 'RA']
    # Validation side: same meta-learners (as an independent list object),
    # single 'xgb' base model.
    val_base_model_list = ['xgb']
    val_learner_list = list(train_learner_list)

    # ---- Scorers handed to run_experiment -----------------------------------
    pseudo_estimator_list = ['DR', 'R', 'IF']
    other_scorer_list = ['random', 'fact', 'knn', 'KL']

    # ---- Data-generation grid shared by both stages -------------------------
    M = 100
    setting_grid = {
        'nonlinear_fun_list': ['power'],
        'nonlinearity_y_list': [1],
        'rho_list': [0, 0.1, 0.3],
        'xi_list': [2],
        'mis_ratio_list': [0],
    }

    # Stage 1: build every dataset under ./data from the ACIC2016 source.
    generate_datasets(data_dir=r'./ACIC2016', M=M, **setting_grid)

    # Stage 2: run the experiments on those datasets, writing under ./result.
    run_all_exps(M=M,
                 train_base_model_list=train_base_model_list,
                 train_learner_list=train_learner_list,
                 val_base_model_list=val_base_model_list,
                 val_learner_list=val_learner_list,
                 pseudo_estimator_list=pseudo_estimator_list,
                 other_scorer_list=other_scorer_list,
                 **setting_grid)