from multiprocessing import Value
import os
import re
import json
import warnings
import argparse
import numpy as np
import pandas as pd
from Schemes.hyperband import Hyperband
from Schemes.bohb import BOHB
from hpobench.benchmarks.ml.tabular_benchmark import TabularBenchmark
from utils import *


# Command-line interface of the experiment script.
parser = argparse.ArgumentParser(description='Script description')
parser.add_argument('--scheme', type=str, default='Hyperband', help='HPO scheme: Hyperband or BOHB')
parser.add_argument('--task', type=str, default='acc_loss', help='Description of task')
parser.add_argument('--rounds', type=int, default=1000, help='Rounds of running')
parser.add_argument('--ult_objs', type=str, nargs='+', default=['test_accuracy', 'test_losses'], help='List of strings (default: ["test_accuracy", "test_losses"])')
# The help text previously advertised a default that did not match the real
# one; %(default)s lets argparse print the actual default so they cannot drift.
parser.add_argument('--max_iters', type=int, nargs='+', default=[9, 27], help='List of integers (default: %(default)s)')
# NOTE(review): "earch" in the help text below looks like a typo ("each" or
# "early") -- confirm the intended wording before changing the string.
parser.add_argument('--eta', type=int, default=3, help='Fraction of saving for earch stopping')
args = parser.parse_args()

# Unpack the parsed options into module-level names used by the rest of the script.
max_iters = args.max_iters
ult_objs = args.ult_objs
task = args.task
rounds = args.rounds
eta = args.eta
dataset = 'Tab_NN'  # fixed label used as a path component under Records/<scheme>/
scheme = args.scheme

# Benchmark and search space shared by get_params() and try_params().
TAB_NN = TabularBenchmark('nn', task_id=31)
config_space = TAB_NN.get_configuration_space(seed=1)
# get_cta_dir comes from `utils`; its two results are iterated in lockstep
# below (criterias[i] is run with direction[i]).
criterias, direction = get_cta_dir(task)


def get_params():
    """Draw and return one configuration from the module-level ``config_space``."""
    sampled = config_space.sample_configuration()
    return sampled


def try_params(n_iteration, config, criteria='valid_accuracy'):
    """Evaluate ``config`` on the module-level ``TAB_NN`` benchmark.

    The benchmark is queried twice: once at ``round(n_iteration)`` iterations
    and once at a fixed 243 iterations. Test metrics and the ``criteria``
    value come from the first query; the ``final_*`` entries come from the
    second.

    Args:
        n_iteration: Iteration budget; rounded to an integer for the
            ``'iter'`` fidelity.
        config: Configuration passed through to ``TAB_NN.objective_function``.
        criteria: Name of the metric to report. ``'valid_accuracy'``,
            ``'train_accuracy'``, ``'valid_losses'`` and ``'train_losses'``
            are read from the per-seed info; any other name is looked up
            directly in ``result['info']``.

    Returns:
        dict with keys ``'time'``, ``'test_accuracy'``, ``'test_losses'``,
        the ``final_valid_*`` or ``final_train_*`` pair when ``criteria``
        contains ``'valid'`` or ``'train'``, and ``criteria`` itself.
    """
    print("n_iteration: ", n_iteration)
    print("criteria: ", criteria)

    budget_fidelity = {'iter': round(n_iteration)}
    # NOTE(review): 243 is presumably the largest 'iter' fidelity of this
    # benchmark -- confirm against the benchmark's fidelity space.
    full_fidelity = {'iter': 243}

    # NOTE(review): for a criteria containing 'seed' the seed becomes None and
    # the per-seed lookups below then index info[None]; the original marked
    # this path as unfinished ("deal with this later").
    data_seed = None if 'seed' in criteria else 8916

    budget_eval = TAB_NN.objective_function(
        configuration=config, fidelity=budget_fidelity, seed=data_seed)
    full_eval = TAB_NN.objective_function(
        configuration=config, fidelity=full_fidelity, seed=data_seed)

    outcome = {'time': budget_eval['cost']}
    outcome['test_accuracy'] = budget_eval['info'][data_seed]['test_scores']['acc']
    outcome['test_losses'] = budget_eval['info'][data_seed]['test_loss']

    # Final validation or training accuracy & loss; 'valid' takes precedence
    # over 'train', and at most one pair is added.
    final_fields = (
        ('valid', 'final_valid_accuracy', 'final_valid_loss', 'val_scores', 'val_loss'),
        ('train', 'final_train_accuracy', 'final_train_loss', 'train_scores', 'train_loss'),
    )
    for marker, acc_key, loss_key, scores_field, loss_field in final_fields:
        if marker in criteria:
            final_info = full_eval['info'][data_seed]
            outcome[acc_key] = final_info[scores_field]['acc']
            outcome[loss_key] = final_info[loss_field]
            break

    # Where each known criteria lives inside the per-seed info dict.
    criteria_paths = {
        'valid_accuracy': ('val_scores', 'acc'),
        'train_accuracy': ('train_scores', 'acc'),
        'valid_losses': ('val_loss',),
        'train_losses': ('train_loss',),
    }
    path = criteria_paths.get(criteria)
    if path is None:
        # Unknown criteria: read it straight from the top-level info dict.
        outcome[criteria] = budget_eval['info'][criteria]
    else:
        value = budget_eval['info'][data_seed]
        for key in path:
            value = value[key]
        outcome[criteria] = value

    return outcome



def _ensure_dir(*parts):
    """Join ``parts`` into one path, create that directory if missing, return it."""
    path = os.path.join(*parts)
    # exist_ok=True replaces the race-prone "check exists, then makedirs" pattern.
    os.makedirs(path, exist_ok=True)
    return path


# Root folder for everything this script writes for the selected scheme.
records_root = f"Records/{scheme}/"

# One result table per ultimate objective: column name -> list with one value per round.
Ult_obj_dict = {obj: {} for obj in ult_objs}

for max_iter in max_iters:
    print(f"################## max_iter = {max_iter} ######################")

    # Reset every column so each max_iter produces its own table.
    for obj in ult_objs:
        if 'accuracy' in obj:
            Ult_obj_dict[obj]['Max_test_accuracy'] = []
        if 'loss' in obj:
            Ult_obj_dict[obj]['Min_test_loss'] = []

        for criteria in criterias:
            Ult_obj_dict[obj][criteria] = []

    for r in range(rounds):
        print(f"########### round = {r} ############")

        # A fresh scheme instance per round; every criteria in this round is
        # evaluated on the same fixed set of configurations.
        if scheme == 'Hyperband':
            hb = Hyperband(get_params, try_params, max_iter=max_iter, eta=eta, skip_first=1)
        elif scheme == 'BOHB':
            hb = BOHB(config_space, get_params, try_params, max_iter=max_iter, eta=eta, skip_first=1)
        else:
            raise ValueError(f"Unkonwn HPO scheme: {scheme}")

        # Load the fixed configurations of this round if they were saved by an
        # earlier run; otherwise they are generated and saved below. Only the
        # parent of "fixed_configs" is created here.
        config_dir = os.path.join(_ensure_dir(records_root, dataset, str(hb)), "fixed_configs")
        config_file = os.path.join(config_dir, "config" + str(r) + ".json")
        config_file_exist = os.path.exists(config_file)
        if config_file_exist:
            hb.load_fixed_config_dict(config_file, config_space)
        else:
            warnings.warn(f"File {config_file} doesn't exist, generate fixed configurations.")

        for i, criteria in enumerate(criterias):
            print(f"########### criteria = {criteria} ############")
            rst = hb.run_fixed_configs(criteria=criteria, direction=direction[i])

            # Persist the raw records of this (round, criteria) run.
            record_dir = _ensure_dir(records_root, dataset, str(hb), "config_rsts", criteria)
            record_file = os.path.join(record_dir, "record" + str(r) + ".csv")
            hb.record_to_csv(rst, record_file=record_file)

            # The last record is treated as the configuration the scheme finally selected.
            rst_df = pd.DataFrame(rst)
            best_rst = rst_df.iloc[-1]
            for obj in ult_objs:
                if 'accuracy' in obj:
                    Ult_obj_dict[obj][criteria].append(best_rst['test_accuracy'])
                if 'loss' in obj:
                    Ult_obj_dict[obj][criteria].append(best_rst['test_losses'])

            if i == 0:
                # The fixed configurations only exist after the first run of
                # the scheme, so they are saved here rather than before the loop.
                if not config_file_exist:
                    fixed_config_dic = hb.get_fixed_config_dict(config_space)
                    os.makedirs(config_dir, exist_ok=True)
                    with open(config_file, "w") as json_file:
                        json.dump(fixed_config_dic, json_file)

                # Best achievable test metric over all records of this round's
                # configurations; recorded once per round, from the first criteria's run.
                for obj in ult_objs:
                    if 'accuracy' in obj:
                        Ult_obj_dict[obj]['Max_test_accuracy'].append(rst_df['test_accuracy'].max())
                    if 'loss' in obj:
                        Ult_obj_dict[obj]['Min_test_loss'].append(rst_df['test_losses'].min())

        for obj in ult_objs:
            warnings.warn(f'Ult_obj_dict[{obj}] = {Ult_obj_dict[obj]}')

    # Store the test values obtained under the different criterias, one CSV per objective.
    # NOTE: `hb` is the instance from the last round, so this requires rounds >= 1.
    for obj in ult_objs:
        cta_dir = _ensure_dir(records_root, dataset, str(hb), "cta", f"obj_{obj}")
        cta_file = os.path.join(cta_dir, f"{task}.csv")
        warnings.warn(f"file = {cta_file}")
        df = pd.DataFrame(Ult_obj_dict[obj])
        df.to_csv(cta_file, index=False)
