from SyntheticClassification import *
from data.DataGenerator import *
import argparse
import torch
import pandas as pd
import os


def save_data(data, label, usefor="train", path="./data/"):
    """Save one dataset split as two ``torch.save`` files.

    Writes ``data_<usefor>.pt`` and ``lable_<usefor>.pt`` into *path*.
    The misspelling "lable" is intentional: ``load_data`` reads the label
    files under exactly that name, so it must not be "fixed" here alone.

    Args:
        data: feature array/tensor to save.
        label: label array/tensor to save.
        usefor: split name used as the file-name suffix (e.g. "train",
            "val", "test").
        path: output directory; created if it does not exist yet.
            Defaults to "./data/", the location ``load_data`` is called with
            in this script.
    """
    if path:
        # torch.save does not create parent directories.
        os.makedirs(path, exist_ok=True)
    torch.save(data, os.path.join(path, f"data_{usefor}.pt"))
    torch.save(label, os.path.join(path, f"lable_{usefor}.pt"))


def gen_data(usefor, n=2500, positive_r=.5, dim=200, n_informative=10, rho=0.8):
    """Generate, save and return a two-class Gaussian dataset.

    Positive samples are drawn with mean ``[1]*n_informative + [0]*(dim -
    n_informative)`` and negative samples with mean zero. Both classes share
    the covariance ``S[i, j] = rho ** |i - j|``. Rows are ordered positives
    first, then negatives (no shuffling). The result is also written to disk
    through ``save_data(..., usefor=usefor)``.

    The defaults for ``dim``, ``n_informative`` and ``rho`` reproduce the
    original hard-coded 200 / 10 / 0.8 setup.

    Args:
        usefor: split name forwarded to ``save_data`` (e.g. "train").
        n: total number of samples.
        positive_r: fraction of samples that are positive, in [0, 1].
        dim: number of features.
        n_informative: number of leading features whose positive-class
            mean is 1.
        rho: correlation decay of the covariance matrix.

    Returns:
        (data, label): ``data`` stacks positive then negative samples;
        ``label`` has shape (n, 1) with 1.0 for positives and 0.0 for
        negatives. Assumes ``multivariate_normal`` (star-imported) is
        ``numpy.random.multivariate_normal`` returning (size, dim) — verify.

    Raises:
        ValueError: if ``positive_r`` is outside [0, 1].
    """
    if not 0 <= positive_r <= 1:
        raise ValueError("positive_r must be in [0, 1], got %r" % (positive_r,))
    mu_negative = np.zeros(dim)
    mu_positive = np.concatenate((np.ones(n_informative), np.zeros(dim - n_informative)))
    cov = np.fromfunction(lambda i, j: np.power(rho, np.abs(i - j)), shape=(dim, dim))
    n_positive = int(n * positive_r)
    # Take the remainder instead of int(n * (1 - positive_r)). That form
    # truncates floating-point error (1 - 0.8 == 0.19999999999999996) and
    # produced n - 1 samples, e.g. 9999 for n=10000, positive_r=0.8.
    n_negative = n - n_positive
    data_positive = multivariate_normal(mu_positive, cov=cov, size=n_positive)
    data_negative = multivariate_normal(mu_negative, cov=cov, size=n_negative)
    label1 = np.ones(shape=(data_positive.shape[0], 1))
    label0 = np.zeros(shape=(data_negative.shape[0], 1))
    data = np.concatenate((data_positive, data_negative))
    label = np.concatenate((label1, label0))
    save_data(data, label, usefor=usefor)
    return data, label



def get_dict_average(dict_list):
    """Aggregate the result dicts of repeated training runs into one summary.

    Values are collected per key (the key set is taken from the first dict).
    List/ndarray values, such as per-epoch curves, are truncated to the
    shortest run so they can be stacked. Then:

    * every key except "W" and "b" is averaged across runs (along axis 0);
    * every key whose name contains "mis" also gets a "<key>_std" entry
      holding the standard deviation across runs;
    * "W" is copied from the run with the highest "acc_score_val".

    "b" is neither averaged nor returned.

    Args:
        dict_list: non-empty list of result dicts sharing the same keys;
            each must contain "W" and "acc_score_val".

    Returns:
        dict with the averaged values, the "<key>_std" entries and "W".

    Raises:
        ValueError: if ``dict_list`` is empty.
        KeyError: if a run has a key the first run lacks, or if "W" or
            "acc_score_val" is missing.
    """
    if not dict_list:
        raise ValueError("dict_list must contain at least one result dict")
    list_dict = {key: [] for key in dict_list[0]}
    for run in dict_list:
        for key, value in run.items():
            list_dict[key].append(value)
    for key, values in list_dict.items():
        if isinstance(values[0], (np.ndarray, list)):
            # Runs that stop early record shorter curves; trim every run to
            # the shortest one. The length is computed once per key.
            min_len = min(len(v) for v in values)
            list_dict[key] = [v[:min_len] for v in values]
    average_dict = {key: np.mean(np.array(values), axis=0)
                    for key, values in list_dict.items() if key not in ("W", "b")}
    std_dict = {key + "_std": np.std(np.array(values), axis=0)
                for key, values in list_dict.items() if "mis" in key}
    result_dict = {**average_dict, **std_dict}
    result_dict["W"] = list_dict["W"][np.argmax(list_dict["acc_score_val"])]
    return result_dict

def load_data(path):
    """Load the train/val/test splits from *path* and transpose each one.

    Reads ``data_<split>.pt`` and ``lable_<split>.pt`` (the "lable"
    spelling matches ``save_data``) for split in train, val, test. Each
    loaded object is returned as its ``.T``.

    Args:
        path: directory containing the ``.pt`` files; a trailing path
            separator is optional.

    Returns:
        (X_train, Y_train, X_val, Y_val, X_test, Y_test), each transposed.
    """
    import inspect

    # The splits produced by gen_data are NumPy arrays. Since torch 2.6,
    # torch.load defaults to weights_only=True and refuses to unpickle them,
    # so request full unpickling when the keyword exists (older torch lacks
    # it). These files are written locally by this script; do not point this
    # loader at untrusted files.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    def _load_transposed(file_name):
        # Load one .pt file from *path* and return its transpose.
        return torch.load(os.path.join(path, file_name), **load_kwargs).T

    return tuple(
        _load_transposed(f"{prefix}_{split}.pt")
        for split in ("train", "val", "test")
        for prefix in ("data", "lable")
    )

def train_with_1strategy(data_path="./data/",
                         max_epochs=50000,
                         k=10,
                         lr=0.1,
                         delta=0.01,
                         l2reg_c=1e6,
                         loss_type="cross_entropy",
                         aggregate_loss_type="matk",
                         smooth_method="srelu",
                         stop_early=True,
                         verbose=False,
                         batch_size=1,
                         optimizer="gd"):
    """Train a single ``Fc`` model for one hyper-parameter setting.

    Loads the saved splits from *data_path* with ``load_data``, builds an
    ``Fc`` model from them plus the given settings, and runs ``Fc.gd`` for
    at most *max_epochs* epochs. The meaning of each hyper-parameter is
    defined by ``Fc``, which is star-imported and not visible here.

    Returns:
        Whatever ``Fc.gd`` returns. The script's main loop collects these
        and passes them to ``get_dict_average`` as result dicts.
    """
    X_train, Y_train, X_val, Y_val, X_test, Y_test = load_data(path=data_path)
    splits = dict(X_train=X_train, Y_train=Y_train,
                  X_val=X_val, Y_val=Y_val,
                  X_test=X_test, Y_test=Y_test)
    hyperparams = dict(k=k,
                       loss_type=loss_type,
                       aggregate_loss_type=aggregate_loss_type,
                       smooth_method=smooth_method,
                       stop_early=stop_early,
                       verbose=verbose,
                       batch_size=batch_size,
                       lr=lr,
                       delta=delta,
                       l2reg_c=l2reg_c,
                       optimizer=optimizer)
    model = Fc(**splits, **hyperparams)
    return model.gd(max_epochs=max_epochs)




if __name__ == "__main__":
    # Experiment driver: generate synthetic splits, train each hyper-parameter
    # setting `tolerance` times, keep the setting with the best mean
    # validation accuracy and write its results under ./records/.
    AGGREGATE_LOSS_TYPES = ["average", "maximum", "atk", "matk", "smooth_matk", "sgd_matk", "smooth_sgd_matk"]
    # Smoothing-parameter grid used for each smooth_method when the aggregate
    # loss is a "smooth_*" variant. "none" has no entry on purpose.
    SMOOTH_DELTA_RANGES = {
        "srelu": [0.05],  # previously searched: [1e-1, 1e-2, 1e-3]
        "softplus": [1],
        "leaky_relu": [0.01],
        "swish": [10],
        "elu": [1],
    }
    N_TRAIN, N_VAL, N_TEST = 10000, 2500, 2500

    parser = argparse.ArgumentParser()
    parser.add_argument("--loss_type", default="cross_entropy", choices=["cross_entropy"])
    parser.add_argument("--smooth_method", default="none", choices=["none", "srelu", "softplus", "leaky_relu", "swish", "elu"])
    parser.add_argument("--aggregate_loss_type", default="average", choices=AGGREGATE_LOSS_TYPES)
    # Numeric options need type=, otherwise values given on the command line
    # arrive as strings (e.g. 0.8 * "512" raises TypeError).
    parser.add_argument("--batchsize", default=512, type=int)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--delta", default=0.001, type=float)  # NOTE: not used; delta comes from SMOOTH_DELTA_RANGES
    parser.add_argument("--tolerance", default=50, type=int)  # number of repeated runs averaged per setting
    parser.add_argument("--positive_r", default=0.8, type=float)
    parser.add_argument("--dataset_name", default="synthetic_data", choices=["synthetic_data"])
    parser.add_argument("--stop_early", default=1, type=int)
    parser.add_argument("--verbose", default=1, type=int)
    args = parser.parse_args()

    # A smooth aggregate loss needs a concrete smoothing method; fail fast
    # instead of hitting an unbound delta_range after generating data.
    if "smooth" in args.aggregate_loss_type and args.smooth_method not in SMOOTH_DELTA_RANGES:
        parser.error("--aggregate_loss_type %s requires --smooth_method to be one of: %s"
                     % (args.aggregate_loss_type, ", ".join(SMOOTH_DELTA_RANGES)))

    os.makedirs("./data", exist_ok=True)
    gen_data(usefor="train", n=N_TRAIN, positive_r=args.positive_r)
    gen_data(usefor="val", n=N_VAL, positive_r=args.positive_r)
    gen_data(usefor="test", n=N_TEST, positive_r=args.positive_r)

    # Hyper-parameter grid.
    if "sgd" in args.aggregate_loss_type:
        # Mini-batch variants: k is a fraction of each batch.
        k_range = [0.8 * args.batchsize]
        optimizer = "adam"
    else:
        # Full-batch atk/matk variants: k is a fraction of the training set;
        # other losses do not use k.
        k_range = [0.8 * N_TRAIN] if "atk" in args.aggregate_loss_type else [0]
        optimizer = "gd"
    if "smooth" in args.aggregate_loss_type:
        delta_range = SMOOTH_DELTA_RANGES[args.smooth_method]
    else:
        delta_range = [0]
    l2reg_c_range = [1e6]  # previously searched: [1e3, 5e3, 1e4]
    lr_range = [0.001]  # previously searched: [0.02, 0.04, 0.06]

    # Train every setting and keep the best mean validation accuracy.
    best_result_dict = {"acc_score_val": 0}
    for l2reg_c in l2reg_c_range:
        for lr in lr_range:
            for k in k_range:
                for delta in delta_range:
                    print(args.aggregate_loss_type, args.smooth_method, lr, l2reg_c, k, delta)
                    result_dict_list = [
                        train_with_1strategy(data_path="./data/",
                                             max_epochs=args.max_epochs,
                                             loss_type=args.loss_type,
                                             smooth_method=args.smooth_method,
                                             aggregate_loss_type=args.aggregate_loss_type,
                                             stop_early=args.stop_early,
                                             verbose=args.verbose,
                                             lr=float(lr),
                                             delta=delta,
                                             l2reg_c=float(l2reg_c),
                                             k=int(k),
                                             batch_size=args.batchsize,
                                             optimizer=optimizer)
                        for _ in range(args.tolerance)
                    ]
                    result_dict = get_dict_average(result_dict_list)
                    if best_result_dict["acc_score_val"] <= result_dict["acc_score_val"]:
                        best_result_dict = result_dict

    # Save the best setting's results.
    os.makedirs("./records", exist_ok=True)
    for key, some_result in best_result_dict.items():
        if not isinstance(some_result, np.ndarray):
            # Single-value result: one cell of a (dataset x aggregate loss) table.
            csv_path = "./records/%s_%s_%s_%s_%s.csv" % (args.loss_type, args.smooth_method, key, args.max_epochs, args.positive_r)
            if os.path.exists(csv_path):
                k_dataframe = pd.read_csv(csv_path, index_col=0)
            else:
                k_dataframe = pd.DataFrame(index=["synthetic_data"], columns=AGGREGATE_LOSS_TYPES)
            # .loc writes into the frame itself; chained df[col][row] = v may
            # modify a temporary copy (and never writes under Copy-on-Write).
            k_dataframe.loc[args.dataset_name, args.aggregate_loss_type] = some_result
            k_dataframe.to_csv(csv_path)
        else:
            # Vector result: one text file per run configuration.
            os.makedirs("./records/%s/" % key, exist_ok=True)
            file_path = "./records/%s/%s_%s_%s_%s_%s.csv" % (key, args.loss_type, args.aggregate_loss_type, args.smooth_method, args.dataset_name, args.positive_r)
            np.savetxt(file_path, some_result, fmt="%f")