from typing import Dict, Union, List, Type, Callable
import torch
import torch.nn as nn
from antgine.core import flatten_module


class AbstractRegularizer(nn.Module):
    """
        Base class for regularization. DO NOT instantiate directly: subclass it
        and implement :meth:`forward`.

        On construction, every layer of ``model`` whose exact type is one of the
        configured layer types is visited, and the attributes described by
        ``modules_attrs`` are collected into ``self._layerparams`` (one dict per
        selected layer, in the order produced by ``flatten_module``).
    """
    def __init__(self, model: nn.Module,
                 modules_attrs: Dict[str, Union[List[Type[nn.Module]],
                                                Dict[Type[nn.Module],
                                                     Callable[[nn.Module],
                                                              nn.Parameter]]]]) -> None:
        """
        :param nn.Module model: Model.
        :param dict[str, Union[List[Type[nn.Module]], Dict[Type[nn.Module], function]]] modules_attrs: Dictionary describing how to
        access regularizer's required attributes from layers in model. Each key is an attribute name; its
        value is either a list of layer types (the attribute is read with ``getattr``) or a dict mapping
        each layer type to an accessor function. All values must cover the same set of layer types.
        :raises ValueError: If ``modules_attrs`` is empty, or its values do not all cover the same layer types.
        """
        super().__init__()
        if not modules_attrs:
            raise ValueError("modules_attrs must describe at least one attribute")
        # Assigning an nn.Module attribute registers it as a child module of
        # this regularizer (nn.Module.__setattr__).
        self._model = model
        self._attrs = list(modules_attrs)
        # NOTE: deliberately NOT named ``_modules``: nn.Module keeps its
        # child-module registry under that name, and overwriting it with a list
        # breaks ``self._model`` lookup, ``.parameters()``, ``.to()``,
        # ``state_dict()`` and ``repr()``.
        self._module_types = list(modules_attrs[self._attrs[0]])
        expected_types = set(self._module_types)
        # Every attribute must be defined for exactly the same layer types,
        # otherwise the per-layer lookup in _get_params could fail.
        mismatched = [a for a in self._attrs if set(modules_attrs[a]) != expected_types]
        if mismatched:
            raise ValueError(
                f"modules_attrs entries {mismatched} do not cover the same layer types "
                f"as {self._attrs[0]!r}"
            )
        self._flattened_model = flatten_module(model, noflat=self._module_types)
        self._modules_attrs = modules_attrs
        self._layerparams = self._get_params(self._flattened_model, self._module_types,
                                             self._attrs, self._modules_attrs)

    def _get_params(self, flattened_model, modules, attrs, modules_attrs):
        """
        Collect the requested attributes of every selected layer.

        :param flattened_model: Iterable of layers (as returned by ``flatten_module``).
        :param modules: Layer types to select. Matching is by exact type, so subclasses of a
            listed type are NOT selected; this also lets accessor dicts be indexed by ``type(layer)``.
        :param attrs: Attribute names to collect.
        :param modules_attrs: The same mapping that was passed to ``__init__``.
        :return: One ``{attr_name: value}`` dict per selected layer, in iteration order.
        :rtype: list[dict]
        :raises AttributeError: If a selected layer lacks an attribute requested through the list form.
        """
        wanted_types = frozenset(modules)
        layersparams = []
        for layer in flattened_model:
            if type(layer) not in wanted_types:
                continue
            params = {}
            for a in attrs:
                spec = modules_attrs[a]
                if isinstance(spec, dict):
                    # Custom accessor registered for this exact layer type.
                    params[a] = spec[type(layer)](layer)
                elif hasattr(layer, a):
                    params[a] = getattr(layer, a)
                else:
                    raise AttributeError(
                        f"{type(layer).__name__} layer has no attribute {a!r}"
                    )
            layersparams.append(params)
        return layersparams

    def forward(self, epoch, it):
        """
        Compute the regularization term. Must be implemented by subclasses.

        :param int epoch: Current epoch.
        :param int it: Current iteration.
        :return: Regularization value (scalar).
        :rtype: torch.Tensor
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
