import sys
import os
import argparse

import copy
import math

import numpy as np
import torch

def generate_dir(path):
    """Create the directory *path*, including parents, if nothing exists there yet.

    If any filesystem entry (directory or file) already exists at *path*,
    the call is a no-op. Always returns None.
    """
    if os.path.exists(path):
        return
    # exist_ok guards against another process creating it in the meantime.
    os.makedirs(path, exist_ok=True)

# kernel for Toy-1D
def k1(x1, x2):
    """Toy-1D kernel ``k(a, b) = min(a, b)`` evaluated for every pair of points.

    Args:
        x1: 1-D tensor of n1 points, or a 0-dim tensor (treated as one point).
        x2: 1-D tensor of n2 points, or a 0-dim tensor (treated as one point).

    Returns:
        Tensor of shape (n1, n2) with entries ``min(x1[i], x2[j])``.
    """
    if x1.dim() == 0:
        x1 = x1.unsqueeze(0)
    if x2.dim() == 0:
        x2 = x2.unsqueeze(0)
    # (n1, 1) broadcast against (1, n2) gives the pairwise minimum.
    return torch.minimum(x1.unsqueeze(1), x2.unsqueeze(0))

# kernel for Toy-3D
def k3(x1, x2):
    """Toy-3D compactly supported kernel ``k(a, b) = max(0, 1 - ||a - b||)**2``.

    A 1-D argument is treated as a single point; otherwise each row is a point.

    Returns:
        The (n1, n2) Gram matrix between the rows of x1 and x2.
    """
    a = x1[None] if x1.dim() == 1 else x1
    b = x2[None] if x2.dim() == 1 else x2
    # (n1, 1, d) - (1, n2, d) -> pairwise Euclidean distances of shape (n1, n2).
    dist = (a[:, None] - b[None]).norm(dim=2)
    return torch.where(dist > 1, 0, (1 - dist) ** 2)

# central kernel ridge regression
def centralKRR(train_x, train_y, public_x, test_x, FLAGS):
    """Fit kernel ridge regression on the pooled client data and predict on test_x.

    Solves ``(K_xx + N * lamda * I) a = y`` where N is the total number of
    pooled training points, and returns ``K_tx @ a``.

    Args:
        train_x: list of per-client input tensors (concatenated along dim 0).
        train_y: list of per-client target tensors (concatenated along dim 0).
        public_x: unused; accepted so every algorithm shares one signature.
        test_x: inputs at which to predict.
        FLAGS: namespace providing ``data`` (kernel choice), ``device`` and ``lamda``.

    Returns:
        Squeezed tensor of predictions at ``test_x``.

    Raises:
        NotImplementedError: if no kernel is registered for ``FLAGS.data``.
    """
    if FLAGS.data == "toy1":
        k = k1
    elif FLAGS.data == "toy3":
        k = k3
    elif FLAGS.data == "toy3_2":
        # NOTE(review): k3_2 is not defined in the visible source; this branch
        # raises NameError unless it is provided elsewhere.
        k = k3_2
    else:
        raise NotImplementedError(f"No kernel registered for dataset {FLAGS.data!r}")

    train_x_agg = torch.cat(train_x, dim=0).to(FLAGS.device)
    train_y_agg = torch.cat(train_y, dim=0).to(FLAGS.device)
    k_tx = k(test_x.to(FLAGS.device), train_x_agg)
    k_xx = k(train_x_agg, train_x_agg)
    n = len(k_xx)
    # Ridge term scaled by the sample count (1/N-normalised empirical risk).
    ridge = n * FLAGS.lamda * torch.eye(n, device=FLAGS.device)
    coeff = torch.linalg.solve(k_xx + ridge, train_y_agg.unsqueeze(1))
    return (k_tx @ coeff).squeeze()

# central kernel regression with gradient descent
def centralKRGD(train_x, train_y, public_x, test_x, FLAGS):
    """Kernel regression by T steps of gradient descent on the pooled client data.

    Uses the closed form of T functional gradient-descent steps with step size
    eta on the 1/N-normalised squared loss, started from zero:
    ``a = K^+ (I - (I - eta * K / N)^T) y``, and returns ``K_tx @ a``.

    Args:
        train_x: list of per-client input tensors (concatenated along dim 0).
        train_y: list of per-client target tensors (concatenated along dim 0).
        public_x: unused; accepted so every algorithm shares one signature.
        test_x: inputs at which to predict.
        FLAGS: namespace providing ``data``, ``device``, ``eta`` and ``T``.

    Returns:
        Squeezed tensor of predictions at ``test_x``.

    Raises:
        NotImplementedError: if no kernel is registered for ``FLAGS.data``.
    """
    if FLAGS.data == "toy1":
        k = k1
    elif FLAGS.data == "toy3":
        k = k3
    elif FLAGS.data == "toy3_2":
        # NOTE(review): k3_2 is not defined in the visible source; this branch
        # raises NameError unless it is provided elsewhere.
        k = k3_2
    else:
        raise NotImplementedError(f"No kernel registered for dataset {FLAGS.data!r}")

    train_x_agg = torch.cat(train_x, dim=0).to(FLAGS.device)
    train_y_agg = torch.cat(train_y, dim=0).to(FLAGS.device)
    k_tx = k(test_x.to(FLAGS.device), train_x_agg)
    k_xx = k(train_x_agg, train_x_agg)
    n = len(k_xx)
    eye = torch.eye(n, device=FLAGS.device)
    # Residual contraction after T steps: (I - eta * K / n)^T.
    decay = torch.linalg.matrix_power(eye - FLAGS.eta * k_xx / n, int(FLAGS.T))
    gd_operator = torch.linalg.pinv(k_xx) @ (eye - decay)
    return (k_tx @ gd_operator @ train_y_agg.unsqueeze(1)).squeeze()

# iterative ensemble distillation with de-regularization (Park et al. 2023)
def IED(train_x, train_y, public_x, test_x, FLAGS):
    """Iterative ensemble distillation with de-regularization (Park et al. 2023).

    Each client fits KRR on its private inputs plus the public inputs. The
    public inputs are labelled with the current consensus prediction, which is
    first passed through the de-regularisation operator. The new consensus is
    the data-size weighted average of the client predictions on ``public_x``.
    This is repeated for a fixed number of rounds.

    Args:
        train_x: list of per-client input tensors.
        train_y: list of per-client target tensors.
        public_x: public (unlabelled) inputs shared by all clients.
        test_x: inputs at which to predict.
        FLAGS: namespace providing ``data``, ``device``, ``lamda`` and
            ``alpha``. ``alpha`` is either ``None``, giving uniform
            regularisation ``lamda * (n_i + n_pub)``, or a value in (0, 1)
            that weights private vs public points. An optional ``ied_iters``
            (default 700) sets the number of distillation rounds.

    Returns:
        List with one squeezed prediction tensor at ``test_x`` per client.

    Raises:
        ValueError: if ``alpha`` is given but not in (0, 1).
        NotImplementedError: if no kernel is registered for ``FLAGS.data``.
    """
    if FLAGS.alpha is not None and not 0 < FLAGS.alpha < 1:
        # alpha in {0, 1} divides by zero below; outside (0, 1) one of the
        # regularisation weights becomes negative.
        raise ValueError(f"alpha must lie in (0, 1), got {FLAGS.alpha}")
    if FLAGS.data == "toy1":
        k = k1
    elif FLAGS.data == "toy3":
        k = k3
    elif FLAGS.data == "toy3_2":
        # NOTE(review): k3_2 is not defined in the visible source; this branch
        # raises NameError unless it is provided elsewhere.
        k = k3_2
    else:
        raise NotImplementedError(f"No kernel registered for dataset {FLAGS.data!r}")
    num_iters = getattr(FLAGS, "ied_iters", 700)

    n_pub = len(public_x)
    k_pp = k(public_x, public_x).to(FLAGS.device)
    # De-regularisation operator I + n_pub * lamda * K_pp^+, applied to the
    # public pseudo-labels at the start of every round.
    dr = torch.eye(len(k_pp), device=FLAGS.device) + len(k_pp) * FLAGS.lamda * torch.linalg.pinv(k_pp)

    # Per client: linear maps from the stacked labels [y_i; public labels] to
    # the client's KRR predictions on the public set and on the test set.
    k_pred_pub = []
    k_pred_test = []
    for x_i in train_x:
        n_i = len(x_i)
        x_all = torch.cat([x_i, public_x], dim=0).to(FLAGS.device)
        n_all = len(x_all)
        if FLAGS.alpha is None:
            reg = FLAGS.lamda * n_all * torch.eye(n_all, device=FLAGS.device)
        else:
            # Private points get n_i*lamda/alpha, public points n_pub*lamda/(1-alpha).
            reg_diag = torch.cat([
                torch.full((n_i,), float(n_i) * FLAGS.lamda / FLAGS.alpha),
                torch.full((n_pub,), float(n_pub) * FLAGS.lamda / (1 - FLAGS.alpha)),
            ])
            reg = torch.diag(reg_diag).to(FLAGS.device)
        k_all_inv = torch.linalg.inv(k(x_all, x_all) + reg)
        k_pred_pub.append(k(public_x.to(FLAGS.device), x_all) @ k_all_inv)
        k_pred_test.append(k(test_x.to(FLAGS.device), x_all) @ k_all_inv)

    sizes = torch.tensor([len(x_i) for x_i in train_x])
    data_ratio = sizes / sizes.sum()
    train_y_dev = [y_i.to(FLAGS.device) for y_i in train_y]  # moved once, reused every round

    pred = torch.zeros(n_pub, device=FLAGS.device)
    for _ in range(num_iters):
        pred = (dr @ pred.unsqueeze(1)).squeeze()
        new_pred = torch.zeros(n_pub, device=FLAGS.device)
        for ratio, to_public, y_i in zip(data_ratio, k_pred_pub, train_y_dev):
            pseudo_y = torch.cat([y_i, pred], dim=0)
            new_pred += ratio * (to_public @ pseudo_y.unsqueeze(1)).squeeze()
        pred = new_pred

    return [
        (to_test @ torch.cat([y_i, pred], dim=0).unsqueeze(1)).squeeze()
        for to_test, y_i in zip(k_pred_test, train_y_dev)
    ]

# DC-NY & DKRR-NY-CM
def DKRR_NY_CM(train_x, train_y, public_x, test_x, FLAGS):
    """Divide-and-conquer Nystrom KRR (DC-NY) with optional communication rounds (DKRR-NY-CM).

    The public inputs act as Nystrom centres. Each client i forms
    ``A_i = K_px K_xp / n_i + lamda * K_pp`` and the initial coefficients are
    ``sum_i A_i^+ K_px y_i / N``. When ``FLAGS.T != 0``, T further updates
    ``c <- (I - eta * P A) c + eta * P b`` are applied. Here
    ``A = sum_i r_i A_i``, ``P = sum_i r_i A_i^+``, ``b = sum_i K_px y_i / N``
    and ``r_i = n_i / N``.

    Args:
        train_x: list of per-client input tensors.
        train_y: list of per-client target tensors.
        public_x: public inputs used as Nystrom centres.
        test_x: inputs at which to predict.
        FLAGS: namespace providing ``data``, ``device``, ``lamda``, ``eta`` and ``T``.

    Returns:
        Squeezed tensor of predictions ``K(test_x, public_x) @ c``.

    Raises:
        NotImplementedError: if no kernel is registered for ``FLAGS.data``.
    """
    if FLAGS.data == "toy1":
        k = k1
    elif FLAGS.data == "toy3":
        k = k3
    elif FLAGS.data == "toy3_2":
        # NOTE(review): k3_2 is not defined in the visible source; this branch
        # raises NameError unless it is provided elsewhere.
        k = k3_2
    else:
        raise NotImplementedError(f"No kernel registered for dataset {FLAGS.data!r}")

    k_pp = k(public_x, public_x).to(FLAGS.device)
    k_px = [k(public_x, x_i).to(FLAGS.device) for x_i in train_x]
    k_tp = k(test_x, public_x).to(FLAGS.device)
    data_num = torch.tensor([len(x_i) for x_i in train_x])
    data_sum = data_num.sum()
    data_ratio = data_num / data_sum

    nystrom = [(k_px_i @ k_px_i.T / n_i) + (FLAGS.lamda * k_pp) for k_px_i, n_i in zip(k_px, data_num)]
    nystrom_inv = [torch.linalg.pinv(a_i) for a_i in nystrom]

    n_pub = len(public_x)
    coeff = torch.zeros(n_pub, device=FLAGS.device)
    rhs_sum = torch.zeros(n_pub, device=FLAGS.device)
    for k_px_i, y_i, inv_i in zip(k_px, train_y, nystrom_inv):
        rhs = k_px_i @ y_i.to(FLAGS.device).unsqueeze(1) / data_sum
        coeff += (inv_i @ rhs).squeeze()
        rhs_sum += rhs.squeeze()

    if FLAGS.T != 0:
        # Data-ratio weighted aggregates, needed only for the refinement rounds.
        nystrom_sum = torch.stack([r * a_i for r, a_i in zip(data_ratio, nystrom)], dim=0).sum(dim=0)
        nystrom_inv_sum = torch.stack([r * a_i for r, a_i in zip(data_ratio, nystrom_inv)], dim=0).sum(dim=0)
        coeff_A = torch.eye(n_pub, device=FLAGS.device) - FLAGS.eta * (nystrom_inv_sum @ nystrom_sum)
        coeff_b = (nystrom_inv_sum @ rhs_sum.unsqueeze(1)).squeeze()
        for _ in range(FLAGS.T):
            coeff = (coeff_A @ coeff.unsqueeze(1)).squeeze() + FLAGS.eta * coeff_b
    return (k_tp @ coeff.unsqueeze(1)).squeeze()

# DCL-KR
def DCL_KR(train_x, train_y, public_x, test_x, FLAGS):
    """DCL-KR: federated kernel regression through coefficients on the public inputs.

    For each client i, the closed form of ``local_iters`` gradient-descent
    steps on its own data is folded into an affine map on the coefficients
    over ``public_x``: ``c -> (I - W_i K_px^T) c + W_i y_i`` with
    ``W_i = K_pp^+ K_px K_xx^+ (I - (I - eta * K_xx / n_i)^local_iters)``.
    Starting from zero, the coefficients are updated T times as the data-size
    weighted average of these maps. Predictions are ``K(test_x, public_x) @ c``.

    Args:
        train_x: list of per-client input tensors.
        train_y: list of per-client target tensors.
        public_x: public inputs shared by all clients.
        test_x: inputs at which to predict.
        FLAGS: namespace providing ``data``, ``device``, ``eta``,
            ``local_iters`` and ``T``.

    Returns:
        Squeezed tensor of predictions at ``test_x``.

    Raises:
        NotImplementedError: if no kernel is registered for ``FLAGS.data``.
    """
    if FLAGS.data == "toy1":
        k = k1
    elif FLAGS.data == "toy3":
        k = k3
    elif FLAGS.data == "toy3_2":
        # NOTE(review): k3_2 is not defined in the visible source; this branch
        # raises NameError unless it is provided elsewhere.
        k = k3_2
    else:
        raise NotImplementedError(f"No kernel registered for dataset {FLAGS.data!r}")

    sizes = torch.tensor([len(x_i) for x_i in train_x])
    data_ratio = sizes / sizes.sum()
    public_dev = public_x.to(FLAGS.device)
    k_pp_inv = torch.linalg.pinv(k(public_x, public_x).to(FLAGS.device))
    eye_pub = torch.eye(len(public_x), device=FLAGS.device)

    weight_coeff = []
    weight_y = []
    for x_i in train_x:
        x_dev = x_i.to(FLAGS.device)
        n_i = len(x_i)
        eye_i = torch.eye(n_i, device=FLAGS.device)
        k_xx = k(x_dev, x_dev)
        k_px = k(public_dev, x_dev)
        # I - (I - eta * K_xx / n_i)^local_iters: effect of the local GD steps.
        gd_filter = eye_i - torch.linalg.matrix_power(eye_i - FLAGS.eta * k_xx / n_i, FLAGS.local_iters)
        weight_ = k_pp_inv @ k_px @ torch.linalg.pinv(k_xx) @ gd_filter
        weight_coeff.append(eye_pub - weight_ @ k_px.T)
        weight_y.append(weight_)

    train_y_dev = [y_i.to(FLAGS.device) for y_i in train_y]
    coeff = torch.zeros(len(public_x), device=FLAGS.device)
    for _ in range(FLAGS.T):
        # Build the next iterate in a fresh tensor; the previous one is only read.
        new_coeff = torch.zeros(len(public_x), device=FLAGS.device)
        for ratio, a_i, w_i, y_i in zip(data_ratio, weight_coeff, weight_y, train_y_dev):
            new_coeff += ratio * (a_i @ coeff.unsqueeze(1) + w_i @ y_i.unsqueeze(1)).squeeze()
        coeff = new_coeff
    return (k(test_x.to(FLAGS.device), public_dev) @ coeff.unsqueeze(1)).squeeze()

def _toy1_target(x):
    """Noise-free Toy-1D target: sum_{i=1}^{2000} sqrt(2) / i^3 * sin((2i-1) * pi * x / 2)."""
    y = torch.zeros_like(x)
    for i in range(1, 2001):
        y += 1. / (i ** 3) * (2 ** 0.5) * torch.sin((2*i-1) * torch.pi * x / 2.)
    return y


def _toy3_target(x):
    """Noise-free Toy-3D target of r = ||row||: (1 - r)^6 (35 r^2 + 18 r + 3) for r <= 1, else 0."""
    r = torch.norm(x, dim=1)
    return torch.where(r > 1, 0, (1 - r) ** 6 * (35. * r ** 2 + 18. * r + 3.)).squeeze()


def _public_hetero_transform(u, hetero):
    """Map uniform samples u in [0, 1] to samples with density hetero + 2 (1 - hetero) x.

    This is the inverse of the CDF F(x) = hetero * x + (1 - hetero) * x^2. That
    density is non-negative on [0, 1] only for 0 < hetero < 2.
    """
    if not 0 < hetero < 2:
        raise ValueError(f"public_hetero must lie in (0, 2), got {hetero}")
    if hetero == 1:
        # The density is uniform; the closed form below would divide by zero.
        return u
    return ((hetero ** 2 + 4 * (1 - hetero) * u) ** 0.5 - hetero) / (2 * (1 - hetero))


def _split_among_clients(train_x, train_y, flag, m, num_regions=8):
    """Distribute labelled samples to m clients according to their region label `flag`.

    Each client gets 2 distinct regions; assignments are resampled until every
    region has at least one client. Each region's samples are cut into
    contiguous chunks proportional to Dirichlet(10, ..., 10) client weights.
    Returns (per-client inputs, per-client targets) as two lists of tensors.
    """
    while True:
        selected = [np.random.choice(num_regions, 2, replace=False) for _ in range(m)]
        coverage = [sum((region in s) for s in selected) for region in range(num_regions)]
        if 0 not in coverage:
            break
    weights = np.random.dirichlet(np.repeat(10., m))

    x_parts = [[] for _ in range(m)]
    y_parts = [[] for _ in range(m)]
    for region in range(num_regions):
        x_region = train_x[flag == region]
        y_region = train_y[flag == region]
        owners = np.arange(m)[[(region in s) for s in selected]]
        cum = np.cumsum(weights[owners])
        # Normalised cumulative weights -> chunk end offsets; the last equals len(x_region).
        bounds = (len(x_region) * (cum / cum[-1])).astype(np.int32)
        start = 0
        for owner, end in zip(owners, bounds):
            x_parts[owner].append(x_region[start:end])
            y_parts[owner].append(y_region[start:end])
            start = end
    return ([torch.cat(parts, dim=0) for parts in x_parts],
            [torch.cat(parts, dim=0) for parts in y_parts])


def data_gen(FLAGS):
    """Generate one synthetic federated regression problem.

    Args:
        FLAGS: namespace providing ``data`` ("toy1" or "toy3"), ``num_clients``,
            ``public_scale`` and ``public_hetero`` (0 disables the non-uniform
            public-input distribution; otherwise it must lie in (0, 2)).

    Returns:
        ``(train_x_client, train_y_client, public_x, test_x, test_y)``. The
        first two are lists with one tensor per client and 50 * num_clients
        samples in total. The test targets are noise-free.

    Raises:
        NotImplementedError: for "toy3_2", whose data generation is missing.
        ValueError: if ``public_hetero`` is non-zero and outside (0, 2).
        SystemExit: for any other unknown dataset name.
    """
    m = FLAGS.num_clients
    n = 50 * m

    # Exponents r, s that set the number of public points.
    if FLAGS.data == "toy1":
        r, s = 1, 0.5
    elif FLAGS.data == "toy3":
        r, s = 1, 0.75
    elif FLAGS.data == "toy3_2":
        r, s = 2/3, 1/2
    else:
        print("Not Implemented Dataset")
        sys.exit()
    npublic = int(FLAGS.public_scale * n ** (1/(2*r+s)) * math.log10(n) ** 3)

    # The torch RNG call order (train x, train noise, public x, test x) is kept stable.
    if FLAGS.data == "toy1":
        train_x = torch.rand(n)
        train_y = _toy1_target(train_x) + 0.44 * torch.randn(n)
        public_x = torch.rand(npublic)
        test_x = torch.rand(1000)
        test_y = _toy1_target(test_x)
        # Region label 0..7: index of the width-1/8 bin containing x.
        flag = (train_x * 8).to(torch.int)
    elif FLAGS.data == "toy3":
        train_x = torch.rand(n, 3)
        train_y = _toy3_target(train_x) + 0.44 * torch.randn(n)
        public_x = torch.rand(npublic, 3)
        test_x = torch.rand(1000, 3)
        test_y = _toy3_target(test_x)
        # Region label 0..7: octant of the unit cube (one bit per coordinate > 0.5).
        flag = ((train_x[:, 0] > 0.5).to(torch.int) * 4
                + (train_x[:, 1] > 0.5).to(torch.int) * 2
                + (train_x[:, 2] > 0.5).to(torch.int))
    else:
        raise NotImplementedError(f"Data generation for dataset {FLAGS.data!r} is not implemented")

    if FLAGS.public_hetero != 0.:
        public_x = _public_hetero_transform(public_x, FLAGS.public_hetero)

    train_x_client, train_y_client = _split_among_clients(train_x, train_y, flag, m)
    return train_x_client, train_y_client, public_x, test_x, test_y

if __name__ == "__main__":
    # Command-line experiment driver: run one algorithm on repeated synthetic
    # problems and report the mean/std of the test MSE and RMSE.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=0)  # -1 selects the CPU
    parser.add_argument("--data", type=str, default='toy1')
    parser.add_argument("--algorithm", type=str, default='centralKRR')
    parser.add_argument("--local_iters", type=int, default=5)  # local iters for FL
    parser.add_argument("--eta", type=float, default=1.)  # learning rate for kernel learning
    parser.add_argument("--alpha", type=float, default=None)
    parser.add_argument("--T", type=int, default=None)  # epochs (communication rounds) for FL
    parser.add_argument("--num_clients", type=int, default=10)
    parser.add_argument("--lamda", type=float, default=None)
    parser.add_argument("--public_scale", type=float, default=1.)
    parser.add_argument("--public_hetero", type=float, default=0., help='0 < parameter < 2')
    parser.add_argument("--num_trials", type=int, default=500)  # independent repetitions
    FLAGS, _ = parser.parse_known_args(sys.argv[1:])

    FLAGS.device = f'cuda:{FLAGS.cuda}' if FLAGS.cuda != -1 else "cpu"
    device = FLAGS.device

    # Exponents r, s per dataset (must agree with data_gen); they set the
    # default T and lamda below.
    exponents = {"toy1": (1, 0.5), "toy3": (1, 3/4), "toy3_2": (2/3, 1/2)}
    if FLAGS.data not in exponents:
        print("Not Implemented Dataset")
        sys.exit()
    r, s = exponents[FLAGS.data]

    # default hyperparameters
    n_total = FLAGS.num_clients * 50
    if FLAGS.T is None:
        FLAGS.T = int(n_total ** (1/(2 * r + s)))
    if FLAGS.lamda is None:
        FLAGS.lamda = 1 / (n_total ** (1/(2 * r + s)))

    loss_fn = torch.nn.MSELoss().to(device)

    # algorithm name -> (function, hyperparameters shown in the report)
    algorithms = {
        "centralKRR": (centralKRR, ("lamda",)),
        "centralKRGD": (centralKRGD, ("T", "eta")),
        "IED": (IED, ("lamda",)),
        "DCL_KR": (DCL_KR, ("T", "eta", "local_iters")),
        "DKRR_NY_CM": (DKRR_NY_CM, ("T", "lamda")),
    }
    if FLAGS.algorithm not in algorithms:
        print("Not Implemented Algorithm")
        sys.exit()
    algorithm, hyp_names = algorithms[FLAGS.algorithm]

    MSE = []
    with torch.no_grad():
        for _ in range(FLAGS.num_trials):
            train_x, train_y, public_x, test_x, test_y = data_gen(FLAGS)
            output = algorithm(train_x, train_y, public_x, test_x, FLAGS)
            target = test_y.to(device)
            if isinstance(output, list):
                # One prediction per client: report the average client MSE.
                u = sum(loss_fn(output_, target).detach().cpu().numpy() for output_ in output)
                MSE.append(u / len(output))
            else:
                MSE.append(loss_fn(output, target).detach().cpu().numpy())

    hyp_print = " ".join(f"{name} {getattr(FLAGS, name)}" for name in hyp_names)
    mse = np.array(MSE)
    rmse = mse ** 0.5
    print(f"Algorithm {FLAGS.algorithm} Data {FLAGS.data} #Clients {FLAGS.num_clients} / Hyperparameters : {hyp_print}...")
    print(f"Performance Mean : MSE {mse.mean():.5f}, {mse.std():.5f} :: RMSE {rmse.mean():.6f}, {rmse.std():.6f}")