import os
import numpy as np
import torch
import cdro.utils as utils
import cdro.data_utils as data_utils
import cdro.train_classifiers as train_classifiers
import cdro.scm_datasets as scm_datasets
import datetime


def run_benchmark(seed):
    """ Run the benchmarking experiments for a single random seed.

    The function first runs the SCM-learning stage for the datasets that have
    a learnable structural causal model ('adult' and 'compas'), then trains
    one decision model per (trainer, dataset) pair through
    ``train_classifiers.train`` and stores the returned metrics as ``.npy``
    files.

    trainers includes the following:
    - ERM
    - DRO
    - AL
    - ROSS

    datasets includes the following:
    - adult
    - compas
    - lin

    inputs:
        --seed: the random seed to use for the experiment
    outputs:
        None; trains the decision models and saves the results to disk
    """
    trainers = ['ERM', 'DRO', 'AL', 'ROSS']
    datasets = ['adult', 'compas', 'lin']

    # Hyper-parameters shared by every (trainer, dataset) combination.
    lambd = 1           # lambda coefficient, equal to 1 for all models
    train_epochs = 10   # number of training epochs, equal for all datasets

    # Make sure every output directory exists (exist_ok avoids the
    # check-then-create race of a separate os.path.exists test).
    for directory in (utils.model_save_dir, utils.metrics_save_dir, utils.scms_save_dir):
        os.makedirs(directory, exist_ok=True)

    # ------------------------------------------------------------------------------------------------------------------
    #                                       LEARN THE STRUCTURAL CAUSAL MODELS
    # ------------------------------------------------------------------------------------------------------------------

    learned_scms = {'adult': scm_datasets.Learned_Adult_SCM, 'compas': scm_datasets.Learned_COMPAS_SCM}

    for dataset in datasets:
        if dataset not in learned_scms:
            # No learnable SCM is registered for this dataset (e.g. 'lin').
            continue

        print('Fitting SCM for %s...' % dataset)

        # Learn a single SCM (no need for multiple seeds)
        np.random.seed(seed)
        torch.manual_seed(seed)

        X, _, _ = data_utils.process_data(dataset)
        myscm = learned_scms[dataset](linear=False)

        scm_path = utils.scms_save_dir + dataset
        # NOTE(review): the equations are fitted only when `scm_path` ALREADY
        # exists, so nothing is fitted when it is missing even though the
        # message above was printed. This looks inverted for a "fit once"
        # cache check -- confirm the intent before changing it; the original
        # behaviour is kept here.
        if os.path.exists(scm_path):
            myscm.fit_eqs(X.to_numpy(), save=scm_path)

    # ------------------------------------------------------------------------------------------------------------------
    #                                TRAIN THE DECISION MODELS
    # ------------------------------------------------------------------------------------------------------------------
    print("Starting Simulation: ", datetime.datetime.now())
    for trainer in trainers:
        for dataset in datasets:

            # Choose model type based on dataset
            model_type = utils.get_model(dataset)

            save_dir = utils.get_model_save_dir(dataset, trainer, model_type, seed, lambd)
            # NOTE(review): the literal 0 is passed through unchanged from the
            # original call; its meaning is defined by utils.get_metrics_save_dir.
            save_name = utils.get_metrics_save_dir(dataset, trainer, lambd, model_type, 0, seed)
            print('Training... %s %s %s' % (model_type, trainer, dataset))

            mean_acc, mcc_max, uai_05, uai_01, uai_cf, uai_ar_05, uai_ar_01 = \
                train_classifiers.train(dataset, trainer, model_type, train_epochs, lambd, seed,
                                        verbose=False, save_dir=save_dir)

            print(mean_acc)

            # Save the results: one single-element array per metric, written
            # to '<save_name>_<suffix>.npy'.
            results = (
                ('accs', mean_acc),
                ('mccs', mcc_max),
                ('uai_05', uai_05),
                ('uai_01', uai_01),
                ('uai_cf', uai_cf),
                ('uai_ar_05', uai_ar_05),
                ('uai_ar_01', uai_ar_01),
            )
            for suffix, value in results:
                np.save(save_name + '_' + suffix + '.npy', np.array([value]))
            print("Finished at: ", datetime.datetime.now())


if __name__ == "__main__":
    # Script entry point: read the optional --seed flag (integer, default 0)
    # from the command line and run the benchmark with it.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', type=int, default=0)
    cli_args = cli.parse_args()

    run_benchmark(seed=cli_args.seed)