import sys
import os
import argparse
import datetime

import numpy as np

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import time
import copy

import utils.utils as utils
import nns

def weight_init(layer):
    """Re-initialize a single layer in place; meant for ``module.apply(weight_init)``.

    * ``Conv2d``      -- Xavier-uniform weights, bias filled with 0.01.
    * ``BatchNorm2d`` -- weight set to 1, bias set to 0.
    * ``Linear``      -- Xavier-uniform weights, bias filled with 0.01.

    Any other layer type is left untouched. Parameters that do not exist
    (``bias=False`` or ``affine=False``) are skipped explicitly instead of
    relying on a blanket ``except`` or crashing on ``None``.

    Args:
        layer: the module to initialize.
    """
    with torch.no_grad():
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
            torch.nn.init.xavier_uniform_(layer.weight)
            # Layers built with bias=False expose ``bias`` as None.
            if layer.bias is not None:
                layer.bias.fill_(0.01)
        elif isinstance(layer, torch.nn.BatchNorm2d):
            # With affine=False both parameters are None.
            if layer.weight is not None:
                layer.weight.fill_(1.0)
            if layer.bias is not None:
                layer.bias.zero_()

def HSIC(K, L):
    """Return the empirical HSIC estimate ``trace(K H L H) / (n - 1)^2``.

    ``H = I - 11^T / n`` is the centering matrix. It is created directly with
    the dtype and device of ``K`` so that kernels that are not in the default
    dtype (e.g. float64) no longer fail with a dtype mismatch in ``K @ H``.

    Args:
        K: square ``(n, n)`` kernel matrix.
        L: square ``(n, n)`` kernel matrix of the same size as ``K``.

    Returns:
        A 0-dim tensor holding the HSIC value. For ``n == 1`` the denominator
        is zero, so the result is not finite.

    Raises:
        AssertionError: if either matrix is not square or the sizes differ.
    """
    assert K.shape[0] == K.shape[1]
    assert L.shape[0] == L.shape[1]
    assert K.shape[0] == L.shape[0]
    n = K.shape[0]
    H = (torch.eye(n, dtype=K.dtype, device=K.device)
         - torch.ones(n, n, dtype=K.dtype, device=K.device) / n)
    return torch.trace(K @ H @ L @ H) / ((n - 1) ** 2)

def CKA(K, L):
    """Return the centered kernel alignment between kernels ``K`` and ``L``.

    Computed as ``HSIC(K, L) / (sqrt(HSIC(K, K)) * sqrt(HSIC(L, L)))``.

    Raises:
        AssertionError: if the two matrices do not have the same shape.
    """
    assert K.shape == L.shape
    cross_term = HSIC(K, L)
    # Square roots of the self-HSIC terms normalize the cross term.
    k_norm = HSIC(K, K) ** 0.5
    l_norm = HSIC(L, L) ** 0.5
    return cross_term / (k_norm * l_norm)

if __name__ == "__main__":
    # Script entry point. Pipeline:
    #   1. parse arguments, set up logging, devices and data;
    #   2. build one network per client;
    #   3. kernel learning: re-initialize every network and train its features
    #      to align (CKA) with a data-ratio-weighted ensemble kernel computed
    #      from the pretrained networks (skipped when --load_kl is given);
    #   4. freeze the feature layers, zero the last layer, and precompute the
    #      HSIC-based loss scale and the least-squares solver of each network;
    #   5. collaborative learning rounds on the last layer only.
    # The code runs at module scope, so every name assigned below is already a
    # module-level global and no `global` declaration is required.

    time_ = datetime.datetime.now()
    name = f"_{time_.month}-{time_.day}-{time_.hour}-{time_.minute}-{time_.second}"

    arg_command = sys.argv[1:]
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument("--cuda", type=int, default=0) # start gpu number to use
    parser.add_argument("--ngpu", type=int, default=1) # nb of gpus
    parser.add_argument("--ver", type=int, default=-1)
    parser.add_argument("--load", type=str, default=None)
    parser.add_argument("--load_kernel", type=str, default=None)
    parser.add_argument("--load_kl", type=str, default=None)
    parser.add_argument("--load_inv", type=str, default=None)
    parser.add_argument("--save_kernel", action="store_true")
    parser.add_argument("--save_inv", action="store_true")

    # dataset and settings
    parser.add_argument("--data", type=str, default='datasets/toy3_50_hetero')
    parser.add_argument("--nn_type", type=str, nargs="+", default=None)
    parser.add_argument("--nn_ratio", type=float, nargs="+", default=1.)

    # configuration for kernel learning
    parser.add_argument("--kl_lr", type=float, default=1e-4) # learning rate for kernel learning
    parser.add_argument("--kl_epochs", type=int, default=100) # epochs for kernel learning
    parser.add_argument("--bs_k", type=int, default=32) # batch size for kernel learning

    # configuration for collaborative learning
    parser.add_argument("--epochs", type=int, default=50) # communication rounds for collaborative learning
    # Only Adam and SGD are implemented below; reject anything else up front
    # instead of failing later with an undefined `optimizers` name.
    parser.add_argument("--optim", type=str, default="Adam", choices=["Adam", "SGD"]) # optimizer for collaborative learning
    parser.add_argument("--local_epochs", type=int, default=50) # local epochs for collaborative learning
    parser.add_argument("--lr", type=float, default=5e-2) # base learning rate for collaborative learning
    parser.add_argument("--br", type=int, default=1) # batch ratio for collaborative learning (local)

    FLAGS, _ = parser.parse_known_args(arg_command)

    # data type and client numbers
    # The second path component is expected to look like "<type>_<clients>_...".
    algorithm = 'DCL_NN'
    dataset_name = FLAGS.data.split('/')[1]
    data_type = dataset_name.split('_')[0]
    client_num = int(dataset_name.split('_')[1])

    # logger
    utils.generate_dir('logs')
    FLAGS.log_fn = f"logs/log{name}.txt"
    logger = utils.init_logger(FLAGS.log_fn)
    utils.log_arguments(logger, FLAGS)

    if FLAGS.load is None and FLAGS.load_kl is None:
        utils.log_msg(logger, "No load.. You must load pretrained models" + "\n")
        # This is an error path, so report a non-zero exit status.
        sys.exit(1)

    # device and version setting
    # Clients are spread round-robin over `ngpu` GPUs starting at `--cuda`.
    # Without any visible GPU everything falls back to the CPU instead of
    # dividing by zero.
    gpu_count = torch.cuda.device_count()
    gpus_per_run = max(FLAGS.ngpu, 1)
    device = []
    for i in range(client_num):
        if gpu_count == 0:
            device.append('cpu')
        else:
            d_ = int((FLAGS.cuda + i % gpus_per_run) % gpu_count)
            device.append(f'cuda:{d_}')
    if FLAGS.ver == -1:
        ver = FLAGS.cuda
    else:
        ver = FLAGS.ver

    # loss function
    loss_fn = nn.MSELoss()

    # data load
    # Directory "0" holds the public train split and the test split;
    # directories 1..client_num hold the private train split of each client.
    public_x, public_y = torch.load(FLAGS.data + '/0/train.pt')
    test_x, test_y = torch.load(FLAGS.data + '/0/test.pt')
    train_x = []
    train_y = []
    for i in range(1, client_num + 1):
        x_, y_ = torch.load(FLAGS.data + f'/{i}/train.pt')
        train_x.append(x_)
        train_y.append(y_)

    train_x_agg = torch.cat(train_x)
    train_y_agg = torch.cat(train_y)

    # Fraction of all private samples owned by each client.
    data_num_ratio = np.array([len(train_xp) for train_xp in train_x])
    data_num_ratio = data_num_ratio / np.sum(data_num_ratio)

    # network construction
    default_nn_types = {
        "toy3": ["FNN4_32", "FNN4_64", "FNN5_32", "FNN3_64"],
        "energy": ["FNN_ENERGY4_32", "FNN_ENERGY4_64", "FNN_ENERGY5_32", "FNN_ENERGY3_64"],
        "mnist": ["ResNet18_MNIST", "ResNet34_MNIST", "MobileNetv2_MNIST", "ResNet50_MNIST"],
        "utk": ["CNN1_UTK", "CNN2_UTK", "CNN3_UTK", "CNN4_UTK"],
        "imdb": ["ResNet18_IMDB", "ResNet34_IMDB", "MobileNetv2_IMDB", "ResNet50_IMDB"],
    }
    if FLAGS.nn_type is None:
        FLAGS.nn_ratio = [0.3, 0.3, 0.2, 0.2]
        if data_type not in default_nn_types:
            utils.log_msg(logger, f"No default --nn_type for data type '{data_type}'; pass --nn_type explicitly" + "\n")
            sys.exit(1)
        FLAGS.nn_type = default_nn_types[data_type]
    else:
        # The scalar default of --nn_ratio is not a list; normalize it.
        if not isinstance(FLAGS.nn_ratio, list):
            FLAGS.nn_ratio = [FLAGS.nn_ratio]
        if not isinstance(FLAGS.nn_type, list):
            FLAGS.nn_type = [FLAGS.nn_type]

    # Fixed architectures: constructor name in `nns` -> amount added to
    # `hidden_layer_num` for each instance.
    fixed_nets = {
        "CNN1_UTK": 64,
        "CNN2_UTK": 64,
        "CNN3_UTK": 64,
        "CNN4_UTK": 64,
        "ResNet18_MNIST": 512,
        "ResNet34_MNIST": 512,
        "ResNet50_MNIST": 2048,
        "MobileNetv2_MNIST": 1280,
        "ResNet18_IMDB": 512,
        "ResNet34_IMDB": 512,
        "ResNet50_IMDB": 2048,
        "MobileNetv2_IMDB": 1280,
    }
    nn_num = [int(client_num * r) for r in FLAGS.nn_ratio]
    nets = []
    hidden_layer_num = 0
    for num, net in zip(nn_num, FLAGS.nn_type):
        for _ in range(num):
            # "FNN_ENERGY" must be tested before the shorter "FNN" prefix.
            if net.startswith("FNN_ENERGY"):
                spec = net[len("FNN_ENERGY"):].split("_")
                num_layer = int(spec[0])
                hidden_units = int(spec[1])
                nets.append(nns.FNN(num_layer, hidden_units, data = "energy"))
                hidden_layer_num += hidden_units
            elif net.startswith("FNN"):
                spec = net[len("FNN"):].split("_")
                num_layer = int(spec[0])
                hidden_units = int(spec[1])
                nets.append(nns.FNN(num_layer, hidden_units))
                hidden_layer_num += hidden_units
            elif net in fixed_nets:
                nets.append(getattr(nns, net)())
                hidden_layer_num += fixed_nets[net]
            else:
                # Previously an unknown name was skipped silently, leaving
                # fewer networks than clients.
                raise ValueError(f"Unknown network type '{net}'")

    if len(nets) != client_num:
        # The zip() calls below pair networks with clients and stop at the
        # shorter sequence, so make the mismatch visible.
        utils.log_msg(logger, f"Warning: built {len(nets)} networks for {client_num} clients; check --nn_ratio / --nn_type")

    dt = TensorDataset(test_x, test_y)
    dlt = DataLoader(dt, batch_size = 1000, shuffle=False, drop_last=False)

    # Attach the per-client data ratio and target device to each network.
    for dr, net, d_ in zip(data_num_ratio, nets, device):
        net.data_ratio = dr
        net.device = d_

    if FLAGS.load_kl is None:
        # load pretrained model
        utils.log_msg(logger, "Load Pretrained Neural Networks.." + "\n")
        params = torch.load(FLAGS.load, map_location = torch.device('cpu'))
        for net, param in zip(nets, params):
            net.load_state_dict(param)
            net.to(net.device)

        # test pretrained model
        with torch.no_grad():
            for net in nets:
                net.eval()
            MSE_list = []
            for net in nets:
                pretrain_test_ = 0
                for x, y in dlt:
                    output = net(x.to(net.device))["output"].reshape(-1)
                    # Weight each batch loss by its size to get a per-sample mean.
                    pretrain_test_ += loss_fn(output, y.to(net.device)).detach().cpu().numpy() * len(x)
                pretrain_test_ = pretrain_test_ / len(test_x)
                MSE_list.append(pretrain_test_)
        utils.log_msg(logger, f"Pretrained Local Model Performance (Avg) : MSE {np.array(MSE_list).mean()} RMSE {(np.array(MSE_list) ** 0.5).mean()}")

        ### kernel learning
        utils.log_msg(logger, "Start Kernel Learning.." + "\n")
        for net in nets:
            net.eval()
        # To compute the target kernel values
        if FLAGS.load_kernel is None:
            with torch.no_grad():
                # Target kernel: data-ratio-weighted sum of every pretrained
                # network's feature Gram matrix on the public data.
                kernel_t = torch.zeros(len(public_x), len(public_x))
                dp = TensorDataset(public_x)
                dpl = DataLoader(dp, batch_size = 1000, shuffle = False, drop_last = False)
                for net in nets:
                    o = []
                    for x in dpl:
                        o_ = net(x[0].to(net.device), feature = True)["feature"].cpu().detach()
                        o.append(o_)
                    o = torch.cat(o, dim = 0)
                    o = o @ o.T
                    kernel_t += net.data_ratio * o
            if FLAGS.save_kernel:
                utils.generate_dir('kernel_save')
                kernel_dir = ('kernel_save/' + dataset_name)
                utils.generate_dir(kernel_dir)
                torch.save(kernel_t, kernel_dir + f'/{algorithm}_{data_type}_{client_num}_{ver}_kernel.pt')
        else:
            kernel_t = torch.load(FLAGS.load_kernel, map_location = torch.device('cpu'))

        # kernel learning
        # Each public sample is paired with its row index so the matching
        # sub-block of the target kernel can be sliced out per batch.
        d = TensorDataset(public_x, torch.arange(len(public_x)))
        for j, net in enumerate(nets):
            utils.log_msg(logger, f"Kernel Train for {j+1}th Client")
            # Training restarts from fresh weights; the pretrained weights were
            # only needed to build the target kernel.
            net.apply(weight_init)
            net.train()
            optimizer = torch.optim.Adam(params = net.parameters(), lr = FLAGS.kl_lr, weight_decay = 5e-4)
            dl = DataLoader(d, batch_size = FLAGS.bs_k, shuffle = True, drop_last = True)
            for _ in range(FLAGS.kl_epochs):
                for x, idx in dl:
                    optimizer.zero_grad()
                    o = net(x.to(net.device), feature = True)["feature"]
                    kernel_s = o @ o.T
                    # Maximize CKA between the batch kernel and the target block.
                    loss = - CKA(kernel_s, kernel_t[idx][:, idx].to(net.device))
                    loss.backward()
                    optimizer.step()
            net.eval()

        # save models after kernel learning
        utils.generate_dir('model_save')
        saving_dir = 'model_save/' + dataset_name
        utils.generate_dir(saving_dir)
        torch.save([copy.deepcopy(net.state_dict()) for net in nets], saving_dir + f'/{algorithm}_{data_type}_{client_num}_{ver}_kl.pt')
        utils.log_msg(logger, "End Kernel Learning.." + "\n")
    else:
        utils.log_msg(logger, "Load Neural Networks (trained by kernel learning).." + "\n")
        params = torch.load(FLAGS.load_kl, map_location=torch.device('cpu'))
        for net, param in zip(nets, params):
            net.load_state_dict(param)
            net.to(net.device)

    for net in nets:
        net.eval()

    # initialize the last layer
    # Feature layers are frozen; only the (zeroed) linear head is trained below.
    for net in nets:
        for p in net.feature_layers.parameters():
            p.requires_grad = False
        net.linear.weight.data = torch.zeros_like(net.linear.weight.data)
        net.linear.bias.data = torch.zeros_like(net.linear.bias.data)

    if FLAGS.load_inv is None:
        HSICs = []
        dp = TensorDataset(public_x)
        dpl = DataLoader(dp, batch_size = 1000, shuffle = False, drop_last = False)

        # compute HSIC of feature kernels for learning rate scaling
        with torch.no_grad():
            for net in nets:
                o = []
                for x in dpl:
                    o_ = net(x[0].to(net.device), feature = True)["feature"].cpu()
                    o.append(o_)
                o = torch.cat(o, dim = 0)
                k = o @ o.T
                HSICs.append(HSIC(k, k).cpu().numpy())
        HSIC_max = np.max(np.array(HSICs))
        # Loss scale per network: sqrt(max HSIC / own HSIC).
        for HSIC_, net in zip(HSICs, nets):
            net.HSIC = (HSIC_max / HSIC_) ** 0.5

        # compute inverse of Gram matrix for training public data
        # With F the public features plus a constant-one column, each entry is
        # pinv(F^T F) F^T, i.e. the least-squares map from targets to the
        # last-layer weights and bias.
        inv_list = []
        with torch.no_grad():
            d = TensorDataset(public_x)
            dl = DataLoader(d, batch_size = 500, shuffle = False, drop_last = False)
            for net in nets:
                k = []
                f_ = []
                for x in dl:
                    x = x[0].to(net.device)
                    feature = net(x, feature = True)['feature']
                    feature = torch.cat([feature, torch.ones(len(x), 1).to(net.device)], dim = 1)
                    k.append(feature.T @ feature)
                    f_.append(feature.T)
                inv_list.append(torch.linalg.pinv(torch.stack(k).sum(0)) @ torch.cat(f_, dim = 1))
        if FLAGS.save_inv:
            utils.generate_dir('inv_save')
            inv_dir = 'inv_save/' + dataset_name
            utils.generate_dir(inv_dir)
            torch.save([[net.HSIC for net in nets], inv_list], inv_dir + f'/{algorithm}_{data_type}_{client_num}_{ver}_inv.pt')
    else:
        # The saved file stores the final per-network scales, not raw HSIC values.
        HSICs, inv_list = torch.load(FLAGS.load_inv, map_location=torch.device('cpu'))
        for HSIC_, net in zip(HSICs, nets):
            net.HSIC = HSIC_
        for i, net in enumerate(nets):
            inv_list[i] = inv_list[i].to(net.device)

    # collaborative learning
    utils.log_msg(logger, "Start Collaborative Learning Phase.." + "\n")
    for t in range(1, FLAGS.epochs + 1):
        public_pred = torch.zeros_like(public_y)
        # Fresh optimizers every round: the last layer is overwritten by the
        # least-squares fit at the end of each round.
        if FLAGS.optim == "Adam":
            optimizers = [torch.optim.Adam(params = [net.linear.weight, net.linear.bias], lr = FLAGS.lr) for net in nets]
        else:  # "SGD" (enforced by argparse choices)
            optimizers = [torch.optim.SGD(params = [net.linear.weight, net.linear.bias], lr = FLAGS.lr, momentum = 0., weight_decay = 0.) for net in nets]

        for net, train_xp, train_yp, optimizer in zip(nets, train_x, train_y, optimizers):
            d = TensorDataset(train_xp, train_yp)
            # --br batches per local epoch; never let the batch size reach 0
            # when --br exceeds the number of local samples.
            local_bs = max(1, int(len(train_xp) / FLAGS.br))
            dl = DataLoader(d, batch_size = local_bs, shuffle = True, drop_last = True)
            net.train()
            # local training
            for _ in range(FLAGS.local_epochs):
                for x, y in dl:
                    optimizer.zero_grad()
                    x = x.to(net.device)
                    y = y.to(net.device)
                    out = net(x)["output"]
                    loss = loss_fn(out.reshape(-1), y) * net.HSIC
                    loss.backward()
                    optimizer.step()
            net.eval()

            # aggregate predictions to consensus prediction
            dp = TensorDataset(public_x)
            dpl = DataLoader(dp, batch_size = 500, shuffle = False, drop_last = False)
            with torch.no_grad():
                o = []
                for x in dpl:
                    o_ = net(x[0].to(net.device))["output"].cpu().reshape(-1)
                    o.append(o_)
                public_pred += net.data_ratio * torch.cat(o, dim = 0)

        # training public data
        # Closed-form refit of every last layer to the consensus prediction.
        for net, inv_list_ in zip(nets, inv_list):
            net.eval()
            coeff = (inv_list_ @ public_pred.unsqueeze(1).to(net.device)).T
            net.linear.weight.data = coeff[:,:-1]
            net.linear.bias.data = coeff[:,-1]

        if t % 5 == 0:
            MSE_list =  []
            for net in nets:
                net.eval()
                mse = 0
                with torch.no_grad():
                    for x, y in dlt:
                        mse += len(x) * loss_fn(net(x.to(net.device))['output'].reshape(-1).detach().cpu(), y).numpy()
                MSE_list.append(float(mse) / len(test_x))
                net.train()
            utils.log_msg(logger, f"Epoch {t} Test Loss : MSE {np.array(MSE_list).mean()} RMSE {(np.array(MSE_list) ** 0.5).mean()}")