import numpy as np
import pickle
import json
import datetime
import argparse
import warnings
import importlib
import torch
import os
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
import random

import warnings
import pickle
from work_utils import MiniDataset
from data.finetune.ft_loader import get_loader
warnings.filterwarnings('ignore')
import time
import datetime



# Command-line interface for the federated pre-training + fine-tuning script.
# Every option has a default, so the script can be launched without arguments.
parser = argparse.ArgumentParser()

# ---- general ----
parser.add_argument('--seed',
                    help='random seed',
                    default=564,
                    type=int)

parser.add_argument('--algo',
                    help='name of trainer;',
                    type=str,
                    default='FedJNB')

parser.add_argument('--dataset',
                    help='name of dataset;',
                    type=str,
                    default='cifar10')

parser.add_argument('--device',
                    help='selected CUDA device',
                    default='cuda:0',
                    type=str)

parser.add_argument('--num_classes',
                    help='number of classes',
                    default=10,
                    type=int)

parser.add_argument('--num_clients',
                    help='number of clients',
                    default=100,
                    type=int)

# ---- model / optimizer ----
parser.add_argument('--back_bone',
                    help='backbone',
                    default='lenet',
                    type=str)

parser.add_argument('--optim',
                    help='optmizer',
                    default='Adam',
                    type=str)

parser.add_argument('--lr',
                    help='learning rate',
                    default=0.01,
                    type=float)

parser.add_argument('--weight_decay',
                    help='weight decay',
                    default=0,
                    type=float)

# ---- federated training ----
parser.add_argument('--rounds',
                    help='number of federated rounds',
                    default=100,
                    type=int)

parser.add_argument('--local_epoch',
                    help='number of local training steps',
                    default=1,
                    type=int)

parser.add_argument('--batch_size',
                    help='batch size',
                    default=1024,
                    type=int)

parser.add_argument('--J_norm_coef',
                    help='coefficient for regularization on Jacobian norm;',
                    type=float,
                    default=0.001)

parser.add_argument('--J_ind_coef',
                    help='coefficient for regularization on Jacobian (indexwise);',
                    type=float,
                    default=0.001)

parser.add_argument('--num_participants',
                    help='number of clients participating each round',
                    default=10,
                    type=int)

# This value is used as the sample size for random.sample(), which requires an
# integer, so it must be parsed as int (it was float before).
parser.add_argument('--num_train_base',
                    help='number of clients who train a base FedAVG',
                    type=int,
                    default=10)

# ---- fine-tuning ----
# Learning rate and weight decay are real-valued; parsing them as int rejected
# any fractional value given on the command line.
parser.add_argument('--ft_lr',
                    help='finetune learning rate',
                    default=0.001,
                    type=float)

parser.add_argument('--ft_wd',
                    help='finetune weigth decay',
                    default=0,
                    type=float)

parser.add_argument('--ft_bs',
                    help='finetune batch size',
                    default=1024,
                    type=int)

parser.add_argument('--ft_epochs',
                    help='finetune epochs',
                    default=2,
                    type=int)

# ---- monitoring ----
parser.add_argument('--early_stop',
                    help='early stop',
                    default=5,
                    type=int)

parser.add_argument('--eval_every',
                    help='print cur performance',
                    default=5,
                    type=int)

parser.add_argument('--save_every',
                    help='save performance every this many rounds',
                    default=5,
                    type=int)

def set_flat_params_to(model, flat_params):
    """Write a flat parameter vector back into ``model`` in place.

    Consecutive slices of ``flat_params`` are consumed in the order given by
    ``model.parameters()``; each slice is reshaped to the parameter's size
    and copied into its ``.data``.
    """
    offset = 0
    for tensor in model.parameters():
        count = tensor.numel()
        chunk = flat_params[offset:offset + count]
        tensor.data.copy_(chunk.view(tensor.size()))
        offset += count

def get_flat_params_from(model):
    """Return all parameters of ``model`` concatenated into one 1-D tensor.

    Parameters are flattened and joined in the order of ``model.parameters()``.
    """
    return torch.cat([tensor.data.view(-1) for tensor in model.parameters()])

def aggregate(solns):
    """Return the element-wise mean of the tensors in ``solns`` (detached).

    The inputs are left untouched; the sum is accumulated in a fresh tensor
    shaped like the first entry.
    """
    running_sum = torch.zeros_like(solns[0])
    for soln in solns:
        running_sum.add_(soln)
    running_sum.div_(len(solns))
    return running_sum.detach()


def freeze(model, k):
    """Freeze ``model``'s child modules except ``Linear`` ones among the last ``k``.

    Every parameter owned by a direct child of ``model`` first has
    ``requires_grad`` switched off. Then each child that is both within the
    last ``k`` children and an instance of ``torch.nn.Linear`` is made
    trainable again.
    """
    children = list(model.children())

    # switch off gradients for everything owned by the direct children
    for child in children:
        for tensor in child.parameters():
            tensor.requires_grad = False

    # re-enable the Linear children that sit within the last k positions
    total = len(children)
    for position, child in enumerate(children):
        distance_from_end = total - 1 - position
        if distance_from_end < k and isinstance(child, torch.nn.Linear):
            for tensor in child.parameters():
                tensor.requires_grad = True


def eval(server, data_loader):
    """Compute the sample-weighted mean loss and accuracy of ``server.net``.

    NOTE: the name shadows the builtin ``eval``; it is kept because other
    code in this file calls it under this name.

    Args:
        server: object exposing the network to evaluate as ``server.net``.
        data_loader: iterable yielding ``(x, y)`` batches, where ``y`` holds
            class indices accepted by ``F.cross_entropy``.

    Returns:
        ``(loss, accuracy)`` as Python floats, averaged over all samples seen.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.

    The network's train/eval mode is left unchanged.
    """
    # Run on whatever device the network lives on instead of assuming the
    # default CUDA device, so a model on another GPU (or the CPU) works.
    first_param = next(server.net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    eval_loss, eval_acc, eval_total = 0, 0, 0
    # Gradients are never used here; skipping graph construction saves memory.
    with torch.no_grad():
        for x, y in data_loader:
            # as_tensor avoids re-copying labels that are already tensors
            x, y = x.to(device), torch.as_tensor(y).to(device)

            logits = server.net(x)
            loss = F.cross_entropy(logits, y)
            acc = (logits.argmax(1) == y).float().mean()

            # weight the per-batch means by batch size so the final
            # averages are per-sample
            batch_size = y.size(0)
            eval_loss += loss.item() * batch_size
            eval_acc += acc.item() * batch_size
            eval_total += batch_size

    return eval_loss / eval_total, eval_acc / eval_total


def ft_train(server, args, device, train_loader, test_loader):
    """Fine-tune ``server.net`` and record train/test metrics per epoch.

    Args:
        server: object exposing the network as ``server.net`` and its
            optimizer as ``server.optim``.
        args: parsed arguments; ``ft_epochs``, ``early_stop`` and
            ``eval_every`` are read.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: iterable of ``(x, y)`` batches passed to ``eval``.

    Side effects:
        Writes the metric histories to ``FT_performance_path + ".json"``
        (module-level path) and prints progress.

    Training stops early once the training accuracy has failed to improve
    for ``args.early_stop`` consecutive epochs.
    """
    best_acc, best_loss = 0, float('inf')
    training_loss, training_acc, testing_loss, testing_acc, patience = [], [], [], [], 0

    for i_epoch in range(args.ft_epochs):
        # Reset the accumulators every epoch so the recorded values are true
        # per-epoch statistics rather than running averages over all epochs.
        cur_loss, cur_acc, cur_total = 0, 0, 0
        for x, y in train_loader:
            # honour the requested device instead of the default CUDA device
            x, y = x.to(device), torch.as_tensor(y).to(device)

            logits = server.net(x)
            loss = F.cross_entropy(logits, y)

            server.optim.zero_grad()
            loss.backward()
            server.optim.step()

            acc = (logits.argmax(1) == y).float().mean()

            cur_loss += loss.item() * y.size(0)
            cur_acc += acc.item() * y.size(0)
            cur_total += y.size(0)

        epoch_tr_loss, epoch_tr_acc = cur_loss / cur_total, cur_acc / cur_total
        training_loss.append(epoch_tr_loss)
        training_acc.append(epoch_tr_acc)

        epoch_tst_loss, epoch_tst_acc = eval(server, test_loader)
        testing_loss.append(epoch_tst_loss)
        testing_acc.append(epoch_tst_acc)

        # update best performance
        best_loss = min(best_loss, epoch_tst_loss)
        best_acc = max(best_acc, epoch_tst_acc)

        # early stop: count consecutive epochs without a training-accuracy gain
        if len(training_acc) > 1 and training_acc[-1] <= training_acc[-2]:
            patience += 1
        else:
            patience = 0

        if patience >= args.early_stop:
            print("Finetune early stop at epoch", i_epoch)
            break

        if (i_epoch + 1) % args.eval_every == 0:
            print(f"Epoch: {i_epoch + 1:03d}, Train_loss: {epoch_tr_loss:.4f}, "
                  f"Train_acc: {epoch_tr_acc:.4f}, "
                  f"Test_loss: {epoch_tst_loss:.4f}, "
                  f"Test_acc: {epoch_tst_acc:.4f}")

    # store FT performance; create the target directory first so a missing
    # folder does not discard the results of a finished run
    out_dir = os.path.dirname(FT_performance_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(FT_performance_path + ".json", "w") as outfile:
        json.dump({'test_loss': testing_loss, 'test_acc': testing_acc,
                   'train_loss': training_loss, 'train_acc': training_acc}, outfile)

    print("\n FT Performance Saved to", FT_performance_path + ".json")

    print(">>> Best performance on Target Domain: Acc", best_acc, "Loss:", best_loss)
    print(">>> Fine-tuning done!")




# load and print arguments
args = parser.parse_args()

max_length = max(len(key) for key in vars(args))
fmt_string = '\t%' + str(max_length) + 's : %s'
print('>>> Arguments:')
for keyPair in sorted(vars(args).items()):
    print(fmt_string % keyPair)



# set random seeds
# Client selection uses Python's `random` module, so it has to be seeded as
# well; otherwise runs with the same --seed pick different clients.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(10 + args.seed)
torch.cuda.manual_seed_all(100 + args.seed)

# set paths
# Only the 10-client and 100-client data splits have known file locations.
if args.num_clients == 10:
    train_path = "data_c10/cifar10/train/all_data_equal_niid.pkl"
    test_path = "data_c10/cifar10/test/all_data_equal_niid.pkl"
elif args.num_clients == 100:
    train_path = "data/cifar10/data/train/all_data_equal_niid.pkl"
    test_path = "data/cifar10/data/test/all_data_equal_niid.pkl"
else:
    raise NotImplementedError(
        f"no data split available for num_clients={args.num_clients}; expected 10 or 100")

# A timestamp keeps the outputs of different runs apart.
uid = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
FL_performance_path = f'./results/FL_performance/{uid}_'
FT_performance_path = f'./results/FT_performance/{uid}_'
FL_model_path = f'./model_results/FL/{uid}_'
FT_model_path = f'./model_results/FT/{uid}_'

# read data
start_time_read_data = time.time()

groups = []

print('>>> Read training data from:')
print(train_path)

# pickled layout: {'users': , 'user_data': , 'num_samples':}
with open(train_path, 'rb') as inf:
    cdata = pickle.load(inf)

# wrap every client's raw samples in a MiniDataset
# (a transformation is assigned there, but none is implemented yet)
train_data = {
    client_id: MiniDataset(samples['x'], samples['y'], args.dataset)
    for client_id, samples in cdata['user_data'].items()
}

print('>>> Read testing data from:')
print(test_path)
with open(test_path, 'rb') as inf:
    cdata = pickle.load(inf)

test_data = {
    client_id: MiniDataset(samples['x'], samples['y'], args.dataset)
    for client_id, samples in cdata['user_data'].items()
}

# client ids in sorted order, taken from the training split
clients = sorted(train_data)


end_time_read_data = time.time()
elapsed_read_data = datetime.timedelta(seconds=end_time_read_data - start_time_read_data)
print("#######################################################")
print("Loading data time cost: ", str(elapsed_read_data))
print("#######################################################")

start_time_FL = time.time()

print(">>> Federated Learing on Source Domain")
print("\n FL Performance will be saved to", FL_performance_path + ".json")
print(f"FL Pretrained Model (the best) will be saved at {FL_model_path}_net.pt")

# initialize global and local models
# the flag is passed to every model as `train_base`
train_base_flag = args.num_train_base > 0
jnb_model_cls = importlib.import_module('methods.JNB').Model

# the global model is built first, then one local model per client
global_server = jnb_model_cls(args=args, train_base=train_base_flag).to(args.device)
local_clients = {
    client_id: jnb_model_cls(args=args, train_base=train_base_flag).to(args.device)
    for client_id in clients
}
    
# federated training
# Per-round histories; one entry is appended per executed round.
train_loss, train_acc, test_loss, test_acc, patience = [], [], [], [], 0
train_J_norm, train_J_var = [], []
train_regJ_ind, train_regJ_norm = [], []
best_acc = 0
RegPole = None

# Create the output directories up front so neither the checkpoint nor the
# metrics file is lost because a folder is missing.
for _out_prefix in (FL_model_path, FL_performance_path):
    _out_dir = os.path.dirname(_out_prefix)
    if _out_dir:
        os.makedirs(_out_dir, exist_ok=True)


def _dump_FL_performance():
    """Write the FL metric histories collected so far to the performance JSON file."""
    with open(FL_performance_path + ".json", "w") as outfile:
        json.dump({'train_loss': train_loss,
                   'train_acc': train_acc,
                   'test_loss': test_loss,
                   'test_acc': test_acc,
                   'J_norm': train_J_norm,
                   'J_var': train_J_var}, outfile)


for i_round in range(args.rounds):
    # select clients doing train and base train
    participants = random.sample(clients, args.num_participants)
    # random.sample needs an integer sample size; the CLI value may be a float
    base_clients = random.sample(clients, int(args.num_train_base))
    if participants:
        print(f"===== Selected Client {participants} to train model.")
    else:
        raise NotImplementedError
    print(f"===== Selected Client {base_clients} to train BASE model." if base_clients else f"NO clients selected for BASE train.")

    cur_train_loss, cur_train_acc, cur_train_regJ_ind, cur_train_regJ_norm = [], [], [], []
    cur_net_solns, cur_base_net_solns = [], []
    # the -inf sentinel keeps max() defined when no base norm is collected
    cur_J_JNB, cur_J_norm, cur_J_base_norm = [], [], [-float('inf')]

    for i_client in participants:
        base_train_flag = i_client in base_clients
        local_model = local_clients[i_client]

        if i_round > 0:
            # initialize the local model by the global model
            set_flat_params_to(local_model.net, get_flat_params_from(global_server.net))
            # initialize local base model
            if base_train_flag:
                set_flat_params_to(local_model.base_net, get_flat_params_from(global_server.base_net))

        # local train
        train_loader = DataLoader(train_data[i_client], batch_size=args.batch_size, shuffle=True)
        cur_stat = local_model.train_client(loader=train_loader, steps=args.local_epoch,
                                            RegPole=RegPole, base_train=base_train_flag)
        # local test
        test_loader = DataLoader(test_data[i_client], batch_size=args.batch_size, shuffle=False)
        cur_loss, cur_acc = eval(local_model, test_loader)
        print("Client", i_client, "local loss and acc:", cur_loss, cur_acc)

        cur_net_solns.append(get_flat_params_from(local_model.net))
        if base_train_flag:
            cur_base_net_solns.append(get_flat_params_from(local_model.base_net))

        # record training stats
        cur_train_loss.append(cur_stat['loss'])
        cur_train_acc.append(cur_stat['acc'])
        if RegPole is not None:
            cur_train_regJ_ind.append(cur_stat['regJ_ind'])
            cur_train_regJ_norm.append(cur_stat['regJ_norm'])

        cur_J_JNB.append(cur_stat['J_JNB'])
        cur_J_norm.append(torch.norm(cur_stat['J_JNB']))
        if base_train_flag:
            cur_J_base_norm.append(cur_stat['norm_J_base'])

    # update the global model by local models
    global_net_soln = aggregate(cur_net_solns)
    set_flat_params_to(global_server.net, global_net_soln)
    # update global base model
    # Base clients are drawn independently of the participants, so a round can
    # end with no base solutions at all; in that case the global base model is
    # left unchanged instead of crashing inside aggregate().
    if train_base_flag and cur_base_net_solns:
        global_base_net_soln = aggregate(cur_base_net_solns)
        set_flat_params_to(global_server.base_net, global_base_net_soln)

    # Update RegPole for next round
    # calculate avg grad of JNB and base
    global_J_JNB = torch.mean(torch.stack(cur_J_JNB, dim=0), dim=0)
    norm_global_J = torch.norm(global_J_JNB)
    # rescale the mean direction to the largest norm seen this round
    RegPole = global_J_JNB / norm_global_J * max(max(cur_J_norm), max(cur_J_base_norm))

    var_local_J = sum([(i ** 2) for i in cur_J_norm]) / len(cur_J_norm) - norm_global_J ** 2

    # evaluate
    # evaluate on train
    train_loss.append((sum(cur_train_loss) / len(cur_train_loss)).item())
    train_acc.append((sum(cur_train_acc) / len(cur_train_acc)).item())
    # regularization stats only exist once RegPole has been set (from round 1 on)
    if cur_train_regJ_ind:
        train_regJ_ind.append((sum(cur_train_regJ_ind) / len(cur_train_regJ_ind)).item())
        train_regJ_norm.append((sum(cur_train_regJ_norm) / len(cur_train_regJ_norm)).item())

    # evaluate on test
    eval_loss, eval_acc, eval_loss2, eval_acc2 = 0, 0, 0, 0
    for i_client in clients:
        test_loader = DataLoader(test_data[i_client], batch_size=args.batch_size, shuffle=False)
        test_loader2 = DataLoader(train_data[i_client], batch_size=args.batch_size, shuffle=False)

        cur_loss, cur_acc = eval(global_server, test_loader)
        cur_loss2, cur_acc2 = eval(global_server, test_loader2)
        eval_loss += cur_loss
        eval_acc += cur_acc
        eval_loss2 += cur_loss2
        eval_acc2 += cur_acc2

    # average over the clients actually evaluated rather than the CLI value,
    # which may differ from the number of clients in the data file
    num_eval_clients = len(clients)
    print("Round", i_round, "test performance on trainset", eval_loss2 / num_eval_clients, eval_acc2 / num_eval_clients)
    print("Round", i_round, "test performance on testset", eval_loss / num_eval_clients, eval_acc / num_eval_clients)

    test_loss.append(eval_loss / num_eval_clients)
    test_acc.append(eval_acc / num_eval_clients)
    train_J_norm.append(norm_global_J.item())
    train_J_var.append(var_local_J.item())

    # check and save the best model
    if test_acc[-1] > best_acc:
        best_acc = test_acc[-1]
        torch.save(global_server.net.state_dict(), FL_model_path + "_net.pt")

    # early stop
    if i_round != 0:
        if test_acc[-1] <= test_acc[-2]:
            patience += 1
        else:
            patience = 0

    if patience >= args.early_stop:
        print('federated training early stop at round', i_round)
        break

    # print performance
    if (i_round + 1) % args.eval_every == 0:
        print(f"Epoch: {i_round + 1:03d},"
              f"Train_loss: {train_loss[-1]:.4f}, "
              f"Train_acc: {train_acc[-1]:.4f}, "
              f"Test_loss: {test_loss[-1]:.4f}, "
              f"Test_acc: {test_acc[-1]:.4f}")

    # periodically save trn & tst performance
    if (i_round + 1) % args.save_every == 0:
        _dump_FL_performance()

# Final save: covers early stopping and round counts that are not a multiple
# of save_every, both of which would otherwise drop the latest metrics.
_dump_FL_performance()

end_time_FL = time.time()
print("#######################################################")
# test_acc gets one entry per executed round, so its length is the number of
# rounds that actually ran (early stopping can end training before args.rounds).
print("FL for", len(test_acc), "rounds, time cost: ", str(datetime.timedelta(seconds=end_time_FL - start_time_FL)))
print("#######################################################")
print("\n FL Performance Saved to", FL_performance_path + ".json")
print(f"FL Pretrained Model (the best) saved at {FL_model_path}_net.pt")
print(f"Best test acc is {best_acc}")

# restore the best checkpoint written during federated training
global_server.net.load_state_dict(torch.load(FL_model_path + "_net.pt"))

# freeze the whole feature extracter
# (k=1: only a Linear module in the last child position stays trainable)
freeze(global_server.net, 1)

# load finetune dataset
print(">>> Reading Target Data...")
ft_train_loader, ft_test_loader = get_loader(image_dir='data/svhn', dataset_name='svhn', batch_size=args.ft_bs, num_workers=16)
print(">>> Reading Finished.")


# fine-tuning
print('>>> Training last layer on target data')
start_time_FT = time.time()
ft_train(global_server, args, args.device, ft_train_loader, ft_test_loader)
end_time_FT = time.time()
print("#######################################################")
print("Finetune on target domain time cost: ", str(datetime.timedelta(seconds=end_time_FT - start_time_FT)))
print("#######################################################")


