import torch.nn.functional as F
import torch.nn as nn

import torch


class CompactConv2d(nn.Conv2d):
    """2-D convolution whose weight and (optional) bias share one parameter.

    The kernel of a regular ``nn.Conv2d`` is flattened to one row per output
    channel. When ``bias=True`` the bias is appended as an extra last column.
    The packed tensor is stored as ``compact_weight`` with shape
    ``(out_channels, (in_channels // groups) * kH * kW)``, plus one column when
    a bias is present. The separate ``weight`` and ``bias`` parameters are
    removed, so inherited helpers that read them (e.g. ``reset_parameters``)
    cannot be used on this module.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        """Build the parent ``nn.Conv2d``, then pack its weight/bias into ``compact_weight``.

        Arguments are forwarded positionally to ``nn.Conv2d``; ``kernel_size``
        may be an int or a ``(kH, kW)`` tuple, and ``groups`` may be > 1.
        """
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias)

        # Reuse the parent's freshly initialised values so the packed parameter
        # starts out identical to a plain nn.Conv2d. reshape(out, -1) works for
        # any kernel shape, and for grouped weights of shape (out, in // groups, kH, kW).
        columns = [self.weight.data.reshape(self.out_channels, -1)]
        if self.bias is not None:
            # Bias becomes the last column: (out_channels,) -> (out_channels, 1).
            columns.append(self.bias.data.unsqueeze(1))
        self.compact_weight = nn.Parameter(torch.cat(columns, dim=1))

        delattr(self, "weight")
        delattr(self, "bias")

    def _has_bias(self):
        """Return True when ``compact_weight`` carries an extra bias column."""
        kernel_numel = (self.in_channels // self.groups) * self.kernel_size[0] * self.kernel_size[1]
        return self.compact_weight.shape[-1] == kernel_numel + 1

    def forward(self, inputs):
        """Unpack ``compact_weight`` into kernel/bias views and apply ``F.conv2d``."""
        weight_shape = (self.out_channels, self.in_channels // self.groups, *self.kernel_size)
        if self._has_bias():
            # Views share storage with compact_weight, so gradients flow back to it.
            weight = self.compact_weight[:, :-1].view(weight_shape)
            bias = self.compact_weight[:, -1]
        else:
            weight = self.compact_weight.view(weight_shape)
            bias = None
        return F.conv2d(inputs, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def extra_repr(self):
        """Describe the layer like ``nn.Conv2d`` does, without touching the deleted ``bias``."""
        s = ("{in_channels}, {out_channels}, kernel_size={kernel_size}"
             ", stride={stride}")
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if not self._has_bias():
            s += ", bias=False"
        if self.padding_mode != "zeros":
            s += ", padding_mode={padding_mode}"
        return s.format(**self.__dict__)
