#!/usr/bin/env python
import os
import random
import json
import pickle
import argparse
from collections import defaultdict
from json import JSONEncoder

import torch
from tqdm import tqdm

from utils import Logger, fed_args, read_config, log_config

# Parse command-line arguments, then combine them with the config file named by
# args.config (presumably config values are merged into / override the CLI
# namespace — confirm against utils.read_config).
args = fed_args()
args = read_config(args.config, args)
# Configure the shared logger once. This runs before the fed_baselines imports
# below, presumably so those modules find an already-configured logger — confirm.
# NOTE(review): `logger` is only bound inside this branch, but fed_run() uses it
# unconditionally; if Logger.logger is already set at this point, fed_run()
# would fail with NameError. Verify Logger's initialisation semantics.
if Logger.logger is None:
    L = Logger()
    # Log files are written to train_records/train_record_<save_name>.log.
    if not os.path.exists("train_records/"):
        os.makedirs("train_records/")
    L.set_log_name(os.path.join("train_records", "train_record_" + args.save_name + ".log"))
    logger = L.get_logger()
    # Record the effective configuration at the top of the log.
    log_config(args)

from fed_baselines.client_base import EnsembleClient
from fed_baselines.server_base import EnsembleServer
from fed_baselines.client_cvae import FedCVAEClient
from fed_baselines.server_cvae import FedCVAEServer
from fed_baselines.client_dense import DENSEClient
from fed_baselines.server_dense import DENSEServer
from fed_baselines.server_coboost import CoBoostServer
from fed_baselines.server_dafl import DAFLServer
from fed_baselines.server_fedd3 import FedD3Server
from fed_baselines.client_fedd3 import FedD3Client
from fed_baselines.server_adi import ADIServer
from fed_baselines.server_central import CentralServer
from fed_baselines.server_fedavg import FedAVGServer

from postprocessing.recorder import Recorder
from preprocessing.baselines_dataloader import divide_data, divide_data_with_dirichlet, divide_data_with_local_cls, load_data
from utils.models import *

# Types the stock JSONEncoder already knows how to serialize.
json_types = (list, dict, str, int, float, bool, type(None))


class PythonObjectEncoder(JSONEncoder):
    """JSON encoder that embeds non-JSON-native objects as pickled payloads.

    Any object the standard encoder cannot serialize is replaced by
    ``{'_python_object': <pickle bytes decoded as latin-1>}``. Decode such
    documents with ``as_python_object`` as the ``object_hook``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Raises:
            TypeError: for JSON-native types, via the base implementation.
        """
        if isinstance(obj, json_types):
            # Defer to the base class (which raises TypeError). The bound
            # super() call must not receive `self` a second time.
            return super().default(obj)
        # latin-1 maps each byte 0-255 to exactly one code point, so the pickle
        # bytes survive the round trip through a JSON string unchanged.
        return {'_python_object': pickle.dumps(obj).decode('latin-1')}


def as_python_object(dct):
    """``json`` object_hook that restores objects written by PythonObjectEncoder.

    A mapping carrying a ``'_python_object'`` key is treated as a latin-1
    encoded pickle and unpickled; any other mapping is returned as-is.

    WARNING: this calls ``pickle.loads``, which can execute arbitrary code.
    Only decode JSON produced by a trusted source.
    """
    if '_python_object' not in dct:
        return dct
    raw_bytes = dct['_python_object'].encode('latin-1')
    return pickle.loads(raw_bytes)


def _prepare_data():
    """Build the client data partition for the configured experiment.

    Reads the module-level ``args``. The ``Central`` baseline loads the full
    training set and uses an empty client partition. Every other algorithm
    splits the dataset across ``args.sys_n_client`` clients, either with
    ``divide_data_with_local_cls`` (when only ``sys_n_local_class`` is set) or
    with ``divide_data_with_dirichlet`` (whenever ``sys_dataset_dir_alpha`` is
    set, even if ``sys_n_local_class`` is set too).

    Returns:
        tuple: ``(trainset_config, testset, cls_record, trainset)``, where
        ``trainset`` is the full training set for ``Central`` and ``None``
        otherwise.

    Raises:
        NotImplementedError: if a non-``Central`` run sets neither
            ``sys_n_local_class`` nor ``sys_dataset_dir_alpha``.
    """
    trainset = None
    if args.client_instance == 'Central':
        # No clients: the server is handed the full training set instead.
        trainset_config = {'users': [],
                           'user_data': {},
                           'num_samples': []}
        trainset, testset, _ = load_data(args.sys_dataset, download=True, save_pre_data=False,
                                         aug=args.client_instance_aug)
        cls_record = {}
    elif args.sys_n_local_class is not None and args.sys_dataset_dir_alpha is None:
        logger.info('Using divide data with local class')
        trainset_config, testset, cls_record = divide_data_with_local_cls(n_clients=args.sys_n_client,
                                                                          n_local_cls=args.sys_n_local_class,
                                                                          dataset_name=args.sys_dataset,
                                                                          seed=args.sys_i_seed,
                                                                          aug=args.client_instance_aug)
    elif args.sys_dataset_dir_alpha is not None:
        logger.info('Using divide data with dirichlet')
        trainset_config, testset, cls_record = divide_data_with_dirichlet(n_clients=args.sys_n_client,
                                                                          beta=args.sys_dataset_dir_alpha,
                                                                          dataset_name=args.sys_dataset,
                                                                          seed=args.sys_i_seed,
                                                                          aug=args.client_instance_aug)
    else:
        raise NotImplementedError("sys_n_local_class and sys_dataset_dir_alpha are both None")
    return trainset_config, testset, cls_record, trainset


def _build_server(trainset_config, testset, cls_record, trainset):
    """Instantiate the server for ``args.client_instance`` and load its data.

    Raises:
        NotImplementedError: if the algorithm has no server implementation.
    """
    server_classes = {
        'FedCVAE': FedCVAEServer,
        'DENSE': DENSEServer,
        'ENSEMBLE': EnsembleServer,
        'CoBoost': CoBoostServer,
        'DAFL': DAFLServer,
        'FedD3': FedD3Server,
        'ADI': ADIServer,
        'Central': CentralServer,
        'FedAVG': FedAVGServer,
    }
    server_cls = server_classes.get(args.client_instance)
    if server_cls is None:
        raise NotImplementedError('Server error')
    server = server_cls(args, trainset_config['users'], dataset_id=args.sys_dataset, model_name=args.sys_model)
    if args.client_instance == 'Central':
        server.load_trainset(trainset)
    server.load_testset(testset)
    server.load_cls_record(cls_record)
    return server


def _configure_clients(server, trainset_config):
    """Create one client per partition user, train or load it, and hand it to the server.

    * FedCVAE: train the client and record its reconstruction/KLD losses.
    * DENSE / ENSEMBLE / CoBoost / DAFL / ADI / FedAVG: load
      ``c<client_id>.pt`` from ``args.client_model_root`` when it exists,
      otherwise train locally (optionally saving the weights).
    * FedD3: distill each client's data and pass the pooled set to the server.

    Per-client losses are logged once all clients are configured.
    """
    losses = defaultdict(list)
    distill_dataset = {'x': [], 'y': []}
    ensemble_style = ['DENSE', 'ENSEMBLE', "CoBoost", "DAFL", "ADI", "FedAVG"]
    logger.info('--------------------- configuration stage ---------------------')
    for client_id in trainset_config['users']:
        if args.client_instance == 'FedCVAE':
            client = FedCVAEClient(args, client_id, epoch=args.client_instance_n_epoch,
                                   dataset_id=args.sys_dataset, model_name=args.sys_model)
            server.client_dict[client_id] = client
            client.load_trainset(trainset_config['user_data'][client_id])
            _, recon_loss, kld_loss = client.train(client.model)
            losses["recon_loss"].append(recon_loss)
            losses["kld_loss"].append(kld_loss)
        elif args.client_instance in ensemble_style:
            client_class = DENSEClient if args.client_instance == 'DENSE' else EnsembleClient
            client = client_class(args, client_id, epoch=args.client_instance_n_epoch,
                                  dataset_id=args.sys_dataset, model_name=args.sys_model)
            server.client_dict[client_id] = client
            client.load_trainset(trainset_config['user_data'][client_id])
            server.client_model[client_id] = client.model
            ckpt_path = None
            if args.client_model_root is not None:
                ckpt_path = os.path.join(args.client_model_root, f"c{client_id}.pt")
            if ckpt_path is not None and os.path.exists(ckpt_path):
                # torch.load unpickles the file: only point client_model_root
                # at trusted checkpoints.
                weight = torch.load(ckpt_path, map_location="cpu")
                # Accept both a bare state_dict and a {"model": state_dict} wrapper.
                if "model" in weight:
                    weight = weight["model"]
                logger.info("Load Client {} from {}".format(client_id, ckpt_path))
                client.model.load_state_dict(weight)
            else:
                c_loss = client.train(client.model)
                save_dir = os.path.join(args.sys_res_root, args.save_name)
                # makedirs(exist_ok=True): also creates a missing sys_res_root,
                # where os.mkdir would raise.
                os.makedirs(save_dir, exist_ok=True)
                if args.save_client_model and args.client_model_root is None:
                    torch.save(client.model.state_dict(), os.path.join(save_dir, f"c{client_id}.pt"))
                losses["client_loss"].append(c_loss)
        elif args.client_instance == "FedD3":
            client = FedD3Client(client_id, args.sys_dataset)
            ret_data = client.kip_distill(
                args.client_n_dd,
                num_train_steps=args.fedd3_max_n_epoch,
                seed=args.sys_i_seed,
                lr=args.client_instance_lr,
                threshold=args.fedd3_threshold,
                target_sample_size=args.fedd3_bs)
            # Each distilled point is indexed as (label, image).
            for k_data_point in ret_data:
                distill_dataset['y'].append(k_data_point[0])
                distill_dataset['x'].append(k_data_point[1])

    if args.client_instance == "FedD3":
        distill_dataset['x'] = torch.tensor(distill_dataset['x'])
        # Move the last axis to position 1 (presumably NHWC -> NCHW; confirm).
        distill_dataset['x'] = distill_dataset['x'].permute(0, 3, 1, 2)
        distill_dataset['y'] = torch.tensor(distill_dataset['y'])
        server.load_distill(distill_dataset)
        logger.info('Server gets %d images' % len(distill_dataset['y']))

    for l_name, loss_list in losses.items():
        logger.info('{}: {}'.format(l_name, [float(l) for l in loss_list]))


def fed_run():
    """
    Main function for the baselines of federated learning.

    Using the module-level ``args``, this validates the requested algorithm,
    dataset and model, then seeds ``random``, NumPy and PyTorch. It partitions
    the data, builds the server, configures the clients, and finally runs
    ``server.train()``.

    Raises:
        ValueError: if the algorithm, dataset or model is not supported.
        NotImplementedError: if no data-partition scheme is configured.
    """
    algo_list = ["FedCVAE", "DENSE", "ENSEMBLE", "CoBoost", "DAFL", "ADI", "Central", "FedAVG"]
    # NOTE(review): 'FedD3' has server/client handling in _build_server and
    # _configure_clients but is not accepted here, so those paths are
    # unreachable. Add it once FedD3 is confirmed to work end to end.
    if args.client_instance not in algo_list:
        raise ValueError("The federated learning algorithm is not supported")

    dataset_list = ['TINYIMAGENET', "CIFAR100", 'Imagenette', "openImg"]
    if args.sys_dataset not in dataset_list:
        raise ValueError("The dataset is not supported")

    model_list = ["LeNet", 'AlexCifarNet', "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "CNN", "Conv4",
                  "Conv5", "Conv6"]
    if args.sys_model not in model_list:
        raise ValueError("The model is not supported")

    # Seed every RNG and force deterministic cuDNN kernels for reproducibility.
    random.seed(args.sys_i_seed)
    np.random.seed(args.sys_i_seed)
    torch.manual_seed(args.sys_i_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(args.sys_i_seed)
    torch.set_num_threads(3)

    # NOTE(review): the Recorder instance is never used in this function; the
    # call is kept in case its constructor has side effects.
    Recorder()

    logger.info('======================Setup Clients==========================')
    trainset_config, testset, cls_record, trainset = _prepare_data()
    logger.info('Clients in Total: %d' % len(trainset_config['users']))

    server = _build_server(trainset_config, testset, cls_record, trainset)
    _configure_clients(server, trainset_config)

    server.train()


# Script entry point. Importing this module still parses args and sets up the
# logger (module top level), but only running it as a script starts training.
if __name__ == "__main__":
    fed_run()
