from typing import Type, List, Dict, Tuple
import logging
import torch
import torch.nn as nn
from antgine.core import flatten_module
from antgine.callback import Callback
from antgine.modules.pruning import Prunable


class NetworkSlimmingPruningCallback(Callback):
    """Prune BatchNorm channels once during training and keep them at zero.

    At the end of epoch ``prune_at_epoch`` the scale factors (``weight``) of
    every ``Prunable``-wrapped ``BatchNorm1d``/``BatchNorm2d`` found in the
    model are ranked globally by magnitude, and the smallest
    ``global_pruning_ratio`` fraction is set to zero (together with the
    matching bias entries). In later epochs the gradients and the parameters
    of the pruned channels are masked after every backward pass and every
    optimizer step so the channels stay at zero.
    """

    def __init__(self, model: nn.Module, global_pruning_ratio: float,
                 prune_at_epoch: int):
        """
        :param torch.nn.Module model: Model.
        :param float global_pruning_ratio: Pruning ratio with respect to all prunable filters via BatchNorm.
            Must lie in ``[0, 1]``.
        :param int prune_at_epoch: At which epoch pruning happens.
        :raises ValueError: If ``global_pruning_ratio`` is outside ``[0, 1]``.
        """
        # Fail early: a ratio above 1 would otherwise only surface as an
        # IndexError when the pruning epoch is reached.
        if not 0. <= global_pruning_ratio <= 1.:
            raise ValueError('global_pruning_ratio must be in [0, 1], got %r'
                             % (global_pruning_ratio,))

        self._model = model
        self._global_pruning_ratio = global_pruning_ratio
        self._prune_at_epoch = prune_at_epoch

        # One (batchnorm module, keep-mask) pair per prunable BatchNorm layer.
        # A mask entry of 1 keeps the channel, 0 marks it as pruned.
        # Layers built with affine=False have no weight and cannot be ranked,
        # so they are skipped.
        self._batchnorm_prunable_modules: List[Tuple[nn.Module, torch.Tensor]] = [
            (m.module, torch.ones(m.module.weight.size(0),
                                  dtype=m.module.weight.dtype,
                                  device=m.module.weight.device))
            for m in flatten_module(self._model, noflat=[Prunable])
            if isinstance(m, Prunable)
            and isinstance(m.module, (nn.BatchNorm1d, nn.BatchNorm2d))
            and m.module.weight is not None
        ]

    @staticmethod
    def _aligned_mask(mask: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Return ``mask`` on the device and with the dtype of ``reference``.

        ``Tensor.to`` returns the tensor itself when nothing has to change, so
        this is free in the common case where the model was not moved.
        """
        return mask.to(device=reference.device, dtype=reference.dtype)

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]):
        """Prune the lowest-magnitude BatchNorm channels at the pruning epoch.

        :param int epoch: Epoch that just finished.
        :param dict metrics: Metrics of the finished epoch (unused).
        """
        if epoch != self._prune_at_epoch:
            return

        # The model may have been moved (e.g. to GPU) after this callback was
        # constructed; re-align the stored masks before writing into them.
        self._batchnorm_prunable_modules = [
            (bn, self._aligned_mask(mask, bn.weight))
            for bn, mask in self._batchnorm_prunable_modules
        ]

        # (module index, channel index, gamma value) for every prunable channel.
        # tolist() copies each layer to the host once instead of synchronising
        # per channel.
        gamma_ranking: List[Tuple[int, int, float]] = []
        for i, (bn, _) in enumerate(self._batchnorm_prunable_modules):
            gamma_ranking.extend(
                (i, k, v) for k, v in enumerate(bn.weight.detach().tolist()))

        # Rank by magnitude: a large negative gamma still scales its channel
        # strongly, so only near-zero gammas are unimportant.
        gamma_ranking.sort(key=lambda e: abs(e[2]))

        num_to_prune = int(len(gamma_ranking) * self._global_pruning_ratio)
        for i, k, v in gamma_ranking[:num_to_prune]:
            logging.info('Pruning bn %d, filter %d, val %f', i, k, v)
            bn, mask = self._batchnorm_prunable_modules[i]
            mask[k] = 0.  # Put mask index to 0 to mask gradients
            bn.weight.data[k] = 0.
            if bn.bias is not None:
                bn.bias.data[k] = 0.
        logging.info('Pruned %d out of %d prunable filter', num_to_prune, len(gamma_ranking))

    def on_backward_end(self, epoch: int, i: int, xs: torch.Tensor,
                        ys: torch.Tensor, outputs: torch.Tensor, loss: torch.Tensor):
        """Zero the gradients of pruned channels after the pruning epoch.

        Only ``epoch`` is used; the remaining arguments are accepted to match
        the callback hook signature.
        """
        if epoch <= self._prune_at_epoch:
            return
        for bn, mask in self._batchnorm_prunable_modules:
            for param in (bn.weight, bn.bias):
                # Parameters that are absent, frozen, or did not take part in
                # the backward pass have no gradient to mask.
                if param is None or param.grad is None:
                    continue
                grad = param.grad
                grad *= self._aligned_mask(mask, grad)

    def on_optimizer_step_end(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                              outputs: torch.Tensor, loss: torch.Tensor):
        """Re-zero pruned channels after each optimizer step past the pruning epoch.

        The parameters are masked again after the step, in addition to the
        gradient masking done in ``on_backward_end``. Only ``epoch`` is used;
        the remaining arguments are accepted to match the callback hook
        signature.
        """
        if epoch <= self._prune_at_epoch:
            return
        for bn, mask in self._batchnorm_prunable_modules:
            bn.weight.data *= self._aligned_mask(mask, bn.weight)
            if bn.bias is not None:
                bn.bias.data *= self._aligned_mask(mask, bn.bias)