from BinaryClassification import *
import argparse
import torch
import pandas as pd
import os


def get_dict_average(dict_list):
    """Average the values stored under each key across a list of dicts.

    The output key set comes from the first dict. Every value found under
    such a key (in any of the dicts) is gathered and reduced with
    ``np.mean(..., axis=0)`` over an object array. Scalars therefore average
    to a scalar, and equal-length arrays average element-wise.

    Raises IndexError if ``dict_list`` is empty, and KeyError if a later dict
    carries a key the first one does not have.
    """
    gathered = {}
    for key in dict_list[0]:
        gathered[key] = []
    for record in dict_list:
        for key in record:
            gathered[key].append(record[key])
    averaged = {}
    for key, values in gathered.items():
        averaged[key] = np.mean(np.array(values, dtype=object), axis=0)
    return averaged


def load_data(dataset_name, path="./datasets/benchmark_atk/data/"):
    """Load the saved feature and label tensors for ``dataset_name``.

    Files are read from ``<path>/ins_<dataset_name>.pt`` and
    ``<path>/lable_<dataset_name>.pt``. "lable" is the spelling used in the
    file names; changing it requires renaming the files on disk.

    Args:
        dataset_name: Name used in the two file names.
        path: Directory holding the ``.pt`` files, with or without a
            trailing separator.

    Returns:
        ``(ins, label)``, the objects returned by ``torch.load`` for the two
        files.

    Note:
        ``torch.load`` may unpickle data; only point ``path`` at trusted files.
    """
    # os.path.join works whether or not ``path`` ends with a separator;
    # plain concatenation silently produced a wrong name when it did not.
    ins = torch.load(os.path.join(path, "ins_%s.pt" % dataset_name))
    label = torch.load(os.path.join(path, "lable_%s.pt" % dataset_name))
    return ins, label


def train_with_1strategy(X, Y, max_iterations=50000, k=10, loss_type="logit", aggregate_loss_type="stk", stop_early=True, verbose=False, lr=0.1, delta=0.001, l2reg_c=1e1, hidden_dim=10):
    """Train one ``Fc2`` model with a single hyperparameter setting.

    All arguments except ``max_iterations`` are forwarded unchanged to the
    ``Fc2`` constructor. ``max_iterations`` is forwarded to ``Fc2.gd``.
    ``hidden_dim`` defaults to 10, the value that was previously hard-coded.

    Returns:
        Whatever ``Fc2.gd`` returns. NOTE(review): presumably a dict of
        metrics (the script reads its ``"acc_score_test"`` key); confirm
        against ``Fc2``.
    """
    model = Fc2(X=X, Y=Y, hidden_dim=hidden_dim, k=k, loss_type=loss_type,
                aggregate_loss_type=aggregate_loss_type, stop_early=stop_early,
                verbose=verbose, lr=lr, delta=delta, l2reg_c=l2reg_c)
    return model.gd(max_iterations=max_iterations)




# Benchmark datasets and aggregate-loss strategies. They define both the
# accepted CLI choices and the row/column layout of every result table, so
# they live in one place instead of being repeated.
_DATASET_NAMES = ["appendicitis", "wisconsin", "australian", "german", "titanic",
                  "phoneme", "spambase", "segment0", "page-blocks0"]
_AGGREGATE_LOSS_TYPES = ["average", "maximum", "atk", "matk", "stk"]


def _parse_args():
    """Parse the command-line options for a single benchmark run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--loss_type", default="logit", choices=["logit", "hinge"])
    parser.add_argument("--aggregate_loss_type", default="stk", choices=_AGGREGATE_LOSS_TYPES)
    # NOTE(review): --batchsize and --delta are parsed but never used below;
    # they are kept so existing command lines keep working.
    parser.add_argument("--batchsize", type=int, default=200)
    parser.add_argument("--max_iterations", type=int, default=500)
    parser.add_argument("--delta", type=float, default=0.001)
    # Number of repeated training runs averaged per hyperparameter setting.
    parser.add_argument("--tolerance", type=int, default=20)
    parser.add_argument("--dataset_name", default="german", choices=_DATASET_NAMES)
    parser.add_argument("--stop_early", type=int, default=1)
    parser.add_argument("--verbose", type=int, default=0)
    return parser.parse_args()


def _load_result_table(csv_path):
    """Read a dataset x aggregate-loss-type table from ``csv_path``.

    Returns an empty (all-NaN) table with the standard rows and columns when
    the file does not exist yet.
    """
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path, index_col=0)
    return pd.DataFrame(index=_DATASET_NAMES, columns=_AGGREGATE_LOSS_TYPES)


def _standardize(ins):
    """Z-score every column of ``ins`` using the mean/std taken over dim 0.

    Columns with zero standard deviation are only centered. Dividing them by
    zero would fill them with NaN and poison training.
    """
    std = torch.std(ins, dim=0)
    std = torch.where(std == 0, torch.ones_like(std), std)
    return (ins - torch.mean(ins, dim=0)) / std


def _search_best_result(X, Y, args):
    """Grid-search hyperparameters and return the best averaged result dict.

    Each setting is trained ``args.tolerance`` times. The per-run result dicts
    are averaged with ``get_dict_average``, and the setting with the highest
    averaged ``"acc_score_test"`` wins. If no setting beats 0, the
    placeholder ``{"acc_score_test": 0}`` is returned.
    """
    best_result_dict = {"acc_score_test": 0}
    # Only "atk"-named strategies sweep k ("atk" also matches "matk"). k is a
    # fraction of Y.shape[-1] (presumably the sample count — confirm).
    # NOTE(review): "stk" therefore trains with k=0; confirm that is intended.
    if "atk" in args.aggregate_loss_type:
        k_range = np.linspace(0.01, 0.81, 9) * Y.shape[-1]
    else:
        k_range = [0]
    # NOTE(review): "smoothing" is not an accepted --aggregate_loss_type
    # choice, so delta_range is always [0] here; confirm whether intended.
    if "smoothing" in args.aggregate_loss_type:
        delta_range = [1e-1, 1e-2, 1e-3, 1e-4]
    else:
        delta_range = [0]
    l2reg_c_range = [1e2, 1e3, 1e4]
    lr_range = [1e-3, 1e-1, 5e-2, 1e-2, 5e-3]
    for l2reg_c in l2reg_c_range:
        for lr in lr_range:
            print(lr, l2reg_c)  # progress indicator
            for k in k_range:
                for delta in delta_range:
                    result_dict_list = [
                        train_with_1strategy(X=X, Y=Y,
                                             max_iterations=args.max_iterations,
                                             loss_type=args.loss_type,
                                             aggregate_loss_type=args.aggregate_loss_type,
                                             stop_early=args.stop_early,
                                             verbose=args.verbose,
                                             lr=float(lr),
                                             delta=delta,
                                             l2reg_c=float(l2reg_c),
                                             k=int(k))
                        for _ in range(args.tolerance)
                    ]
                    result_dict = get_dict_average(result_dict_list)
                    if best_result_dict["acc_score_test"] < result_dict["acc_score_test"]:
                        best_result_dict = result_dict
    return best_result_dict


def _save_results(best_result_dict, args):
    """Persist every entry of ``best_result_dict`` to disk.

    Non-ndarray values go into the (dataset, aggregate_loss_type) cell of
    ``./<loss_type>_<key>_<max_iterations>.csv``. ndarray values are written
    with ``np.savetxt`` to
    ``./records/<key>/<loss_type>_<aggregate_loss_type>_<dataset>.csv``.
    """
    for key, some_result in best_result_dict.items():
        if isinstance(some_result, np.ndarray):
            record_dir = "./records/%s/" % key
            # makedirs also creates ./records itself and tolerates an
            # already-existing directory.
            os.makedirs(record_dir, exist_ok=True)
            file_path = record_dir + "%s_%s_%s.csv" % (args.loss_type, args.aggregate_loss_type, args.dataset_name)
            np.savetxt(file_path, some_result, fmt="%f")
        else:
            csv_path = "./%s_%s_%s.csv" % (args.loss_type, key, args.max_iterations)
            table = _load_result_table(csv_path)
            # Single .loc write: chained table[col][row] = ... may write into
            # a temporary copy (always, under pandas Copy-on-Write) and be lost.
            table.loc[args.dataset_name, args.aggregate_loss_type] = some_result
            table.to_csv(csv_path)


if __name__ == "__main__":
    args = _parse_args()

    # NOTE(review): this skip-check reads "<loss>_acc_score_<iters>.csv", which
    # only gets written if the averaged result dict has an "acc_score" key,
    # while best-run selection uses "acc_score_test"; confirm the intended key.
    acc_score_dataframe = _load_result_table("./%s_acc_score_%s.csv" % (args.loss_type, args.max_iterations))

    # Only run when no score has been recorded yet for this dataset/strategy.
    if pd.isna(acc_score_dataframe.loc[args.dataset_name, args.aggregate_loss_type]):
        ins, label = load_data(dataset_name=args.dataset_name)
        X, Y = _standardize(ins).numpy().T, label.numpy().T
        best_result_dict = _search_best_result(X, Y, args)
        _save_results(best_result_dict, args)