import os
import json
import pickle
import torch
from torch.utils.data import DataLoader, TensorDataset
from json import JSONEncoder
from tqdm import tqdm

from fedsd2c.fedsd2c_utils import means, stds, lr_cosine_policy
from postprocessing.recorder import Recorder
from utils import Logger
from utils.fed_utils import assign_dataset
from utils.models import *
from copy import deepcopy

from fed_baselines.server_base import FedServer


class FedD3Server(FedServer):
    """Federated server for FedD3: trains the global model on client-distilled data."""

    def __init__(self, args, client_list, dataset_id, model_name):
        """
        Server in the federated learning for FedD3.

        :param args: Experiment arguments; ``sys_i_seed`` and the ``server_*`` fields are read here.
        :param client_list: Participating clients (forwarded to ``FedServer``).
        :param dataset_id: Dataset identifier, used to look up image geometry and
            normalization statistics.
        :param model_name: Name of the global model architecture.
        """
        super().__init__(args, client_list, dataset_id, model_name)
        # Server Properties
        self._id = "server"
        self._dataset_id = dataset_id
        self._model_name = model_name
        self._i_seed = args.sys_i_seed

        # Training related parameters
        self._epoch = args.server_n_epoch
        self._batch_size = args.server_bs
        self._lr = args.server_lr
        self._momentum = args.server_momentum
        self._num_workers = args.server_n_worker
        # NOTE(review): stored but not consulted by ``train``, which always builds SGD.
        self.optim_name = args.server_optimizer

        # Global test dataset
        self._test_data = None

        # Global distilled dataset (populated by ``load_distill``)
        self._distill_data = None

        # Following private parameters are defined by dataset.
        self.model = None
        _, self._image_length, self._image_channel = assign_dataset(dataset_id)
        self._image_width = self._image_length

        # Normalize distilled images with the dataset's per-channel mean/std before training.
        self.normalizer = transforms.Compose([
            transforms.Normalize(mean=means[self._dataset_id], std=stds[self._dataset_id])
        ])
        self.recorder = Recorder()

    def load_distill(self, data):
        """
        Server loads the decentralized distilled dataset.

        :param data: Distilled dataset; ``train`` reads its ``'x'`` (inputs) and
            ``'y'`` (labels) entries.
        """
        # Deep-copy so later mutation of the caller's object cannot affect training data.
        self._distill_data = deepcopy(data)

    def train(self, exp_dir, res_root='results', i_seed=0):
        """
        Server trains models on the decentralized distilled datasets from networks.

        RNGs are seeded from ``self._i_seed``. After every epoch the model is evaluated
        with ``self.test()`` and a progress line is logged.

        :param exp_dir: Experiment directory name (currently unused).
        :param res_root: Result directory root for saving the result files (currently unused).
        :param i_seed: Index of the used seed (currently unused; ``self._i_seed`` is used).
        :return: Loss of the last training mini-batch as a float, or ``None`` if no
            optimization step was performed (e.g. zero epochs or empty distilled data).
        """
        torch.manual_seed(self._i_seed)
        np.random.seed(self._i_seed)

        # Create the train loader
        with torch.no_grad():
            train_x = self.normalizer(self._distill_data['x'].type(torch.FloatTensor).squeeze())
            if len(train_x.shape) == 3:
                # Presumably single-channel (N, H, W) after squeeze(); restore the channel axis.
                train_x = train_x.unsqueeze(1)
            train_y = self._distill_data['y'].type(torch.LongTensor).squeeze()
        train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=self._batch_size, shuffle=True)

        self.model.to(self._device)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self._lr, momentum=self._momentum, weight_decay=1e-4)
        loss_func = nn.CrossEntropyLoss()
        lr_scheduler = lr_cosine_policy(self._lr, 0, self._epoch)
        # Build the logger once instead of on every epoch.
        logger = Logger().get_logger()

        # Stays None if the loader yields no batches, so the log line below cannot hit an unbound name.
        loss = None

        # Train process
        pbar = tqdm(range(self._epoch))
        for epoch in pbar:
            lr_scheduler(optimizer, epoch, epoch)
            # Re-enter training mode each epoch; the per-epoch test() below may change the mode.
            self.model.train()
            for x, y in train_loader:
                b_x = x.to(self._device)
                b_y = y.to(self._device)

                # Explicitly enable autograd in case the caller invoked train() under no_grad().
                with torch.enable_grad():
                    output = self.model(b_x)
                    loss = loss_func(output, b_y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # Test process
            acc = self.test()
            train_loss = float('nan') if loss is None else loss.item()
            max_acc = np.max(np.array(self.recorder.res['server']['iid_accuracy']))
            logger.info('Epoch: %d / %d | Train loss: %.4f | Accuracy: %.4f | Max Acc: %.4f '
                        % (epoch, self._epoch, train_loss, acc, max_acc))

        return None if loss is None else loss.item()

    def load_cls_record(self, cls_record):
        """
        Accept per-class label statistics; FedD3's server does not use them (no-op).

        :param cls_record: class number record (ignored)
        """
        pass
