import torch
import torch.nn as nn
import torch.nn.functional as F
from .functions.fakequantn import fakequantn
from .fakequantn_conv2d import FakeQuantNConv2d

class FakeQuantNBatchNormConv2d(FakeQuantNConv2d):
    """Fake-quantized 2-D convolution with a batch-norm folded into it.

    The module owns the batch-norm affine parameters (``gamma``, ``beta``) and
    the running statistics. On every forward pass the convolution weight is
    rescaled by ``gamma / std`` and an equivalent bias is built, so a single
    convolution produces ``BN(conv(x))``. When ``self.quantize`` is truthy the
    folded weight is passed through ``self._qfunc`` before being used.

    ``weight``, ``bias``, ``stride``, ``padding``, ``dilation``, ``groups``,
    ``quantize`` and ``_qfunc`` are read from ``self`` and are expected to be
    provided by :class:`FakeQuantNConv2d` (not visible in this file).
    """

    def __init__(self, in_channel, out_channel, kernel_size, nbits, bn_stats_freeze=False, momentum=0.1, eps=1e-5, **kwargs):
        """
        :param int in_channel: forwarded to :class:`FakeQuantNConv2d`.
        :param int out_channel: forwarded to :class:`FakeQuantNConv2d`; also the
            length of ``gamma``, ``beta`` and the running statistics.
        :param kernel_size: forwarded to :class:`FakeQuantNConv2d`.
        :param int nbits: number of quantization bits (forwarded as ``nbits=``).
        :param bool bn_stats_freeze: Freeze BatchNorm stats or continue computing moving averages.
        :param float momentum: Momentum for moving averages (weight given to the newest batch statistic).
        :param float eps: Epsilon value preventing from division by zero.
        :param dict[str, any] kwargs: extra kwargs for :class:`FakeQuantNConv2d`.
        """
        super().__init__(in_channel, out_channel, kernel_size, nbits=nbits, **kwargs)
        self._bn_stats_freeze = bn_stats_freeze
        self._momentum = momentum
        self._eps = eps
        # Batch-norm affine parameters, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(out_channel))
        self.beta = nn.Parameter(torch.zeros(out_channel))
        # Buffers (not parameters): saved with the state dict, never optimised.
        self.register_buffer('running_mean', torch.zeros(out_channel))
        self.register_buffer('running_var', torch.ones(out_channel))

    @property
    def bn_stats_freeze(self):
        """bool: when true, ``forward`` always uses the running statistics and never updates them."""
        return self._bn_stats_freeze

    @bn_stats_freeze.setter
    def bn_stats_freeze(self, val):
        self._bn_stats_freeze = val

    def _ema(self, tensor, val):
        """Blend ``val`` into ``tensor`` in place: ``tensor * (1 - momentum) + val * momentum``.

        :param torch.Tensor tensor: running statistic to update in place.
        :param torch.Tensor val: newest batch statistic.
        """
        # no_grad keeps the running-stat update out of the autograd graph even
        # when ``val`` itself carries gradient history.
        with torch.no_grad():
            tensor.copy_(tensor * (1.0 - self._momentum) + val * self._momentum)

    def forward(self, inputs):
        """Apply the convolution with batch-norm folded into weight and bias.

        In training mode with unfrozen stats, batch statistics are measured on
        an unfolded convolution of ``inputs`` and the running averages are
        updated; otherwise the running averages are used as-is.

        :param torch.Tensor inputs: input accepted by ``F.conv2d``.
        :return: output of the folded (and optionally fake-quantized) convolution.
        :rtype: torch.Tensor
        """
        if self.training and not self._bn_stats_freeze:
            # Plain convolution with the raw weight, used only to measure
            # per-output-channel statistics of the pre-BN activations.
            out = F.conv2d(inputs, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            reduce_dims = (0, 2, 3)
            mean = out.mean(dim=reduce_dims)
            # Batch-norm normalises with the biased batch variance; the
            # unbiased estimate is only used for the running average.
            var = out.var(dim=reduce_dims, unbiased=False)
            self._ema(self.running_mean, mean)
            count = out.size(0) * out.size(2) * out.size(3)
            if count > 1:
                self._ema(self.running_var, var * (count / (count - 1)))
            # With a single value per channel the unbiased variance is
            # undefined (NaN); skip the update instead of poisoning the buffer.
        else:
            mean = self.running_mean
            var = self.running_var

        std = torch.sqrt(var + self._eps)

        # Fold the BN scale into the weight, one factor per output channel.
        folded_weight = (self.weight * self.gamma.view(-1, 1, 1, 1)) / std.view(-1, 1, 1, 1)
        if self.quantize:
            folded_weight = self._qfunc(folded_weight)

        # Fold the BN shift into the bias; a missing conv bias counts as 0.
        conv_bias = self.bias if self.bias is not None else 0
        folded_bias = self.beta + (self.gamma * (conv_bias - mean)) / std

        return F.conv2d(inputs, folded_weight, folded_bias,
                        self.stride, self.padding, self.dilation, self.groups)
