import logging
import os

import torch

from .samplers import (test_iid_sampler_callback,
                       test_noniid_sampler_callback,
                       train_iid_sampler_callback,
                       train_noniid_sampler_callback)

from .attacks.alittle import ALittleIsEnoughAttack
from .attacks.signflipping import SignFlippingWorker
from .attacks.ipm import IPMWorker
from .attacks.labelflipping import LabelFlippingWorker

from .simulators.server import TorchServer
from .simulators.simulator import DistributedEvaluator


class Experiment:
    """Base configuration for a distributed-training experiment with optional Byzantine workers.

    Parses command-line arguments, builds the server model / loss / data loaders
    from a task object, and wires up a trainer (with its workers) and an evaluator.

    Subclasses must provide the following, because the methods below read them
    without ever defining them. ``trainer_type`` and ``worker_type`` are asserted
    inside ``_parse_args``, so they must exist before ``__init__`` runs (e.g. as
    class attributes):

    * ``trainer_type`` -- callable that constructs the trainer.
    * ``worker_type`` / ``worker_kwargs`` -- class and keyword arguments for
      honest workers.
    * ``agg`` -- aggregator name; compared against ``'rsa'``, ``'fltrust'``,
      ``'krum'`` and ``'cclip'`` below (read by ``_parse_args`` when
      ``fixed_point`` is set).
    * ``weight_decay`` -- read only when ``agg == 'rsa'``.
    * implementations of ``_get_agg_results_path`` and ``_get_aggregator``.
    """

    # Largest label value per supported task. The label-flipping attack maps a
    # target ``t`` to ``max_class - t``; tasks missing here are rejected.
    _MAX_CLASS_BY_TASK = {
        "emnist62": 61,
        "emnist": 25,
        "cifar100": 99,
        "mnist": 9,
    }

    def __init__(self, args, device, use_cuda, dataset_dir):
        """Parse ``args`` and store runtime settings.

        Args:
            args: parsed command-line namespace (see ``_parse_args`` for the
                attributes it must carry).
            device: torch device that models and the loss are moved to.
            use_cuda: flag forwarded to the trainer and the evaluator.
            dataset_dir: directory passed as ``data_dir`` to the task's loader.
        """
        self._parse_args(args)
        self.device = device
        self.use_cuda = use_cuda
        self.dataset_dir = dataset_dir

    def get_results_path(self, root_dir):
        """Build the results directory path from the experiment parameters.

        Empty components (e.g. no linear model, or ``n``/``f`` in debug mode)
        are dropped by ``os.path.join``. In debug mode the leaf is
        ``debug/debug``; otherwise it comes from ``_get_agg_results_path``.

        Args:
            root_dir: base directory under which ``results/...`` is placed.

        Returns:
            The joined path as a string.
        """
        results_path = os.path.join(
            root_dir,
            'results',
            self.task_name,
            str(self.epochs),
            ('fixed_point' if self.fixed_point else 'floating_point'),
            ('linear' if self.use_linear_model else ''),
            ('noniid' if self.noniid else 'iid'),
            f"{'max' if self.examples_per_worker is None else self.examples_per_worker}_examples",
            self.agg,
            f"{'None' if self.attack is None else self.attack}",
            f"{'' if self.debug else 'n' + str(self.n)}",
            f"{'' if self.debug else 'f' + str(self.f)}",
            f"{'debug/debug' if self.debug else self._get_agg_results_path()}"
        )

        return results_path

    def set_task(self, task):
        """Instantiate ``task`` and derive the server model, loss and data loader.

        Args:
            task: task class/factory. For ``'mnist'`` it is called with
                ``use_linear_model``; otherwise with no arguments. The result
                must provide ``get_model``, ``get_loss_function`` and
                ``get_dataloader``.
        """
        if self.task_name == 'mnist':
            self.task = task(self.use_linear_model)
        else:
            self.task = task()

        self.server_model = self.task.get_model().to(self.device)
        self.loss_function = self.task.get_loss_function().to(self.device)
        self.dataloader = self.task.get_dataloader()
        logging.getLogger("debug").info("dataloader: %s", self.dataloader)

    def get_trainer(self):
        """Initialize and return the trainer, with all workers already added.

        The server optimizer is plain SGD; weight decay is applied only for the
        ``'rsa'`` aggregator.
        """
        server_optimizer = torch.optim.SGD(
            params=self.server_model.parameters(),
            lr=self.lr,
            weight_decay=(0 if self.agg not in ['rsa'] else self.weight_decay)
        )
        server = TorchServer(optimizer=server_optimizer, model=self.server_model)

        trainer = self.trainer_type(
            server=server,
            aggregator=self._get_aggregator(),
            pre_batch_hooks=[],
            post_batch_hooks=[],
            max_batches_per_epoch=self.max_batches_per_epoch,
            log_interval=1,
            metrics=self._get_metrics(),
            use_cuda=self.use_cuda,
            debug=False
        )

        self._add_workers_to_trainer(trainer=trainer)

        return trainer

    def get_evaluator(self, **kwargs):
        """Build an evaluator over the test split for the server model.

        The test loader is not shuffled and keeps the last partial batch.

        Args:
            **kwargs: extra keyword arguments forwarded to the task's data loader.

        Returns:
            A ``DistributedEvaluator`` using the top-1 metric.
        """
        if self.noniid:
            sampler_callback = test_noniid_sampler_callback()
        else:
            sampler_callback = test_iid_sampler_callback()

        test_loader = self.dataloader(
            data_dir=self.dataset_dir,
            train=False,
            download=True,
            batch_size=self.test_batch_size,
            shuffle=False,
            sampler_callback=sampler_callback,
            drop_last=False,
            **kwargs
        )

        evaluator = DistributedEvaluator(
            model=self.server_model,
            data_loader=test_loader,
            loss_func=self.loss_function,
            device=self.device,
            metrics=self._get_metrics(),
            use_cuda=self.use_cuda,
            debug=False
        )

        return evaluator

    def _parse_args(self, args):
        """Validate ``args`` and copy the relevant settings onto ``self``.

        Note: ``args.p_norm`` is converted in place to ``int`` when it is
        ``'1'`` or ``'2'``.

        Raises:
            AssertionError: on any invalid or unsupported argument combination.
        """
        self.epochs = args.epochs

        self.noniid = args.noniid

        self.task_name = args.task
        if args.use_linear_model:
            # The linear model variant only exists for MNIST.
            assert args.task == 'mnist'
        self.use_linear_model = args.use_linear_model

        # 'None' is a string here (it comes from the CLI), not the None object.
        assert args.attack in ['sf', 'lf', 'ipm', 'alie', 'None']
        self.attack = args.attack

        # At least one honest worker is required.
        assert 0 <= args.f < args.n
        self.n = args.n
        self.f = args.f

        self.debug = args.debug

        assert args.examples_per_worker is None or args.examples_per_worker >= 1
        self.examples_per_worker = args.examples_per_worker

        assert args.batch_size >= 1
        self.batch_size = args.batch_size
        assert args.test_batch_size >= 1
        self.test_batch_size = args.test_batch_size

        if args.task == 'mnist':
            # MNIST runs are pinned to this lr / batches-per-epoch setting.
            assert args.lr == 0.01
            assert args.max_batches_per_epoch == 1
        self.lr = args.lr
        self.max_batches_per_epoch = args.max_batches_per_epoch

        if args.p_norm in ['1', '2']:
            args.p_norm = int(args.p_norm)
        else:
            assert args.p_norm == 'inf'
        self.p_norm = args.p_norm

        assert args.seed >= 0
        self.seed = args.seed

        # Must be supplied by the subclass (see class docstring).
        assert self.trainer_type is not None
        assert self.worker_type is not None

        if args.fixed_point:
            assert self.agg in ('krum', 'cclip')
        self.fixed_point = args.fixed_point

    def _get_agg_results_path(self):
        """Return the aggregator-specific leaf of the results path (subclass hook)."""
        raise NotImplementedError

    def _get_aggregator(self):
        """Return the aggregator passed to the trainer (subclass hook)."""
        raise NotImplementedError

    def _get_metrics(self):
        """Return the metric functions used by the trainer and the evaluator.

        Returns:
            dict mapping ``"top1"`` to ``top1_accuracy(outputs, targets) -> float``:
            the percentage of rows whose arg-max over dim 1 equals the target.
        """
        def accuracy(output, target, topk=(1,)):
            """Compute precision@k (as a percentage) for each k in ``topk``."""
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # reshape rather than view: ``correct`` derives from a transposed
                # tensor and may be non-contiguous, where view(-1) can fail.
                correct_k = correct[:k].reshape(-1).float().sum(0)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

        def top1_accuracy(outputs, targets):
            return accuracy(outputs, targets, topk=(1,))[0].item()

        return {"top1": top1_accuracy}

    def _init_worker(self, train_loader, is_byzantine, worker_model, worker_optimizer):
        """Construct one worker (honest or Byzantine) around its loader/model/optimizer.

        Byzantine workers use the attack class selected by ``self.attack``;
        honest workers use ``self.worker_type`` with ``self.worker_kwargs``.

        Raises:
            NotImplementedError: if ``self.task_name`` is not a supported task,
                or a Byzantine worker is requested with an unknown attack.
        """
        try:
            max_class = self._MAX_CLASS_BY_TASK[self.task_name]
        except KeyError:
            # NotImplementedError, not the NotImplemented constant: the latter is
            # not callable and would surface as an unrelated TypeError.
            raise NotImplementedError(self.task_name) from None

        if is_byzantine:
            if self.attack == 'sf':
                worker_type = SignFlippingWorker
                worker_kwargs = {}
            elif self.attack == 'lf':
                worker_type = LabelFlippingWorker
                worker_kwargs = {"revertible_label_transformer": lambda target: max_class - target}
            elif self.attack == 'ipm':
                worker_type = IPMWorker
                worker_kwargs = {"epsilon": 0.1}
            elif self.attack == 'alie':
                worker_type = ALittleIsEnoughAttack
                worker_kwargs = {"n": self.n, "m": self.f}
            else:
                raise NotImplementedError(self.attack)
        else:
            worker_type = self.worker_type
            worker_kwargs = self.worker_kwargs

        return worker_type(data_loader=train_loader,
                           model=worker_model,
                           optimizer=worker_optimizer,
                           loss_func=self.loss_function,
                           device=self.device,
                           **worker_kwargs)

    def _add_workers_to_trainer(self, trainer):
        """Create ``self.n`` workers and register them with ``trainer``.

        Ranks ``0 .. n-f-1`` are honest and ranks ``n-f .. n-1`` are Byzantine.
        For ``'rsa'``/``'fltrust'`` each worker gets its own model copy;
        otherwise every worker shares ``self.server_model`` (each still with
        its own SGD optimizer).
        """
        if self.agg in ['rsa', 'fltrust']:
            worker_models = [self.task.get_model().to(self.device) for _ in range(self.n)]
        else:
            worker_models = [self.server_model for _ in range(self.n)]
        worker_optimizers = [torch.optim.SGD(params=model.parameters(), lr=self.lr)
                             for model in worker_models]

        n_good = self.n - self.f
        for worker_rank in range(self.n):
            if self.noniid:
                sampler_callback = train_noniid_sampler_callback(worker_rank,
                                                                 n_good,
                                                                 self.examples_per_worker)
            else:
                sampler_callback = train_iid_sampler_callback(worker_rank,
                                                              self.n,
                                                              self.examples_per_worker)

            train_loader = self.dataloader(data_dir=self.dataset_dir,
                                           train=True,
                                           download=True,
                                           batch_size=self.batch_size,
                                           sampler_callback=sampler_callback,
                                           drop_last=True)  # exclude non-full batches

            is_byzantine = (worker_rank >= n_good)
            worker = self._init_worker(train_loader,
                                       is_byzantine,
                                       worker_models[worker_rank],
                                       worker_optimizers[worker_rank])
            # These attacks need a handle on the trainer before training starts.
            if is_byzantine and self.attack in ['alie', 'ipm']:
                worker.configure(trainer)

            trainer.add_worker(worker)
