import torch
import numpy as np
from antgine.regularizer import AbstractRegularizer
from antgine.modules import ScaledModule


class L2Bin(AbstractRegularizer):
    """
        BNN+ L2 binary regularizer.

        For every registered layer, penalises the squared distance between
        the absolute value of each weight and the layer's scaling factor.
    """
    def __init__(self, model, lambda_, modules_attrs=None):
        """
            See :meth:`antgine.regularizer.AbstractRegularizer.__init__`.
        `modules_attrs` must contain the 'weight' and 'scaling_factor' keys.
        When omitted, both are read from :class:`ScaledModule` layers
        (``module.weight`` and ``scale.scaling_factor``).
        :param model: Forwarded unchanged to the base class.
        :param float lambda_: Penalty value.
        :param dict modules_attrs: Attribute getters, forwarded to the base class.
        :raises ValueError: If 'weight' or 'scaling_factor' is missing.
        """
        # Build the default mapping per instance: a dict literal in the
        # signature would be a single object shared by every L2Bin.
        if modules_attrs is None:
            modules_attrs = {
                'weight': {ScaledModule: lambda l: l.module.weight},
                'scaling_factor': {ScaledModule: lambda l: l.scale.scaling_factor},
            }
        # Explicit check instead of `assert` so it still runs under `python -O`,
        # and before the base constructor so nothing is set up on bad input.
        missing = [key for key in ('weight', 'scaling_factor') if key not in modules_attrs]
        if missing:
            raise ValueError('modules_attrs is missing required key(s): {}'.format(', '.join(missing)))
        super().__init__(model=model, modules_attrs=modules_attrs)
        self._lambda = lambda_

    def forward(self, epoch, it):
        """
            Compute the regularization term for the current epoch.

        The result is ``lambda_ * log1p(epoch) * sum((|weight| - scaling_factor) ** 2)``
        accumulated over every entry of ``self._layerparams`` (not set in this
        class; presumably populated by the base class from `modules_attrs`).
        The ``log1p(epoch)`` factor makes the penalty zero at epoch 0.
        :param epoch: Current epoch, used to scale the penalty.
        :param it: Current iteration; unused here, kept for the interface.
        :return: The penalty as a tensor, or a non-tensor zero when
            ``self._layerparams`` is empty.
        """
        reg = 0
        for params in self._layerparams:
            weight = params['weight']
            # The scaling factor is laid out along the first weight axis; add
            # one singleton axis per remaining weight axis so it broadcasts for
            # any rank (2-D -> (-1, 1), 4-D -> (-1, 1, 1, 1), 3-D, 5-D, ...).
            trailing_axes = [1] * (weight.dim() - 1)
            scaling_factor = params['scaling_factor'].view(-1, *trailing_axes)
            reg = reg + torch.sum((torch.abs(weight) - scaling_factor) ** 2)
        return reg * (self._lambda * np.log1p(epoch))
