from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .blocks import BasicRes, LinearRes
from .layers import Conv2d, Flatten, Linear, Sequential, build_activation


def build_stem(input_size: int, width: int,
               act_name: str) -> Tuple[nn.Module, int]:
    """Build the stem: the first convolution plus activation on the input.

    Args:
        input_size: Side length of the input images. Only 64, 32 and 224
            are supported.
        width: Number of output channels of the stem convolution.
        act_name: Activation name forwarded to `build_activation`.

    Returns:
        A `(stem_layer, output_size)` pair. `stem_layer` is the convolution
        followed by the activation; `output_size` is the size the caller
        should use for the feature maps produced by the stem.

    Raises:
        ValueError: If `input_size` is not one of the supported values.
    """

    if input_size == 64:  # Tiny ImageNet
        conv = Conv2d(3, width, kernel_size=7, stride=4, input_size=64)
        output_size = 16
    elif input_size == 32:  # CIFAR10/100
        conv = Conv2d(3, width, kernel_size=5, stride=2, input_size=32)
        output_size = 16
    # We use the patchify stem as in ViTs for ImageNet
    elif input_size == 224:
        # Pick the patch side so that 3 * patch_size**2 is close to `width`.
        patch_size = round((width / 3)**.5)
        conv = Conv2d(3,
                      width,
                      kernel_size=patch_size,
                      stride=patch_size,
                      padding=0,
                      input_size=224)
        # Non-overlapping patches without padding: floor division gives the
        # number of patches per side.
        output_size = 224 // patch_size
    else:
        raise ValueError('Unsupported `input_size`!')

    activation = build_activation(act_name, dim=1, channels=width)
    stem_layer = Sequential(conv, activation)
    return stem_layer, output_size


def build_backbone(arch: str, depth: int, width: int, act_name: str,
                   **kwargs) -> nn.Module:
    """Stack `depth` (block, activation) pairs into the feature extractor.

    Args:
        arch: Block type; one of 'mlp', 'basic_res' or 'linear_res'.
        depth: Number of (block, activation) pairs to stack.
        width: Channel count used for every block and activation.
        act_name: Activation name forwarded to `build_activation`.
        **kwargs: Extra keyword arguments forwarded to every block. `depth`
            is added to them before the blocks are built.

    Raises:
        ValueError: If a block has to be built for an unknown `arch`.
    """
    # Every block also receives the total depth as a keyword argument.
    kwargs['depth'] = depth

    def make_block() -> nn.Module:
        # Create one building block of the requested architecture.
        if arch == 'mlp':
            return Conv2d(width, width, kernel_size=3, **kwargs)
        if arch == 'basic_res':
            return BasicRes(width, **kwargs)
        if arch == 'linear_res':
            return LinearRes(width, **kwargs)
        raise ValueError('Unsupported `arch`!')

    layers = []
    for _ in range(depth):
        block = make_block()
        layers += [block, build_activation(act_name, dim=1, channels=width)]

    return Sequential(*layers)


def build_neck(width: int, input_size: int, out_dim: int,
               act_name: str) -> nn.Module:
    """Build the neck that converts feature maps into vector features.

    The neck is a 4x4, stride-4 convolution that doubles the channel count,
    an activation, a flatten, a linear projection to `out_dim`, and a final
    'MinMax' activation.
    """
    hidden = width * 2
    # Flattened size after the stride-4 convolution: channels * side**2.
    num_features = hidden * (input_size // 4)**2

    conv_act = build_activation(act_name, dim=1, channels=hidden)
    # The current HouseHolder implementation does not support 1D data, so
    # the activation after the linear layer is always 'MinMax'.
    vec_act = build_activation('MinMax', dim=1)

    downsample = Conv2d(width,
                        hidden,
                        kernel_size=4,
                        stride=4,
                        padding=0,
                        input_size=input_size)
    flatten = Flatten()
    projection = Linear(num_features, out_dim)

    return Sequential(downsample, conv_act, flatten, projection, vec_act)


class head(nn.Linear):
    """Linear output head with optional per-row weight normalization.

    Args:
        num_features: Size of each input feature vector.
        num_classes: Number of output logits.
        use_lln: If true, every row of the weight matrix is L2-normalized
            (`F.normalize(..., dim=1)`) before the linear map is applied.
    """
    def __init__(self, num_features: int, num_classes: int,
                 use_lln: bool) -> None:
        super(head, self).__init__(num_features, num_classes)
        self.use_lln = use_lln

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the (optionally row-normalized) linear layer to `x`."""
        # Use the raw weight when normalization is disabled; `weight` used
        # to be unbound on that path, raising UnboundLocalError.
        if self.use_lln:
            weight = F.normalize(self.weight, dim=1)
        else:
            weight = self.weight
        return F.linear(x, weight, self.bias)


class GloroNet(nn.Module):
    """Stem -> backbone -> neck -> head network with Lipschitz estimation.

    Besides the forward pass, the model can report the product of the
    Lipschitz constants of its stem, backbone and neck (`sub_lipschitz`).
    """
    def __init__(self,
                 arch: str,
                 depth: int = 12,
                 width: int = 128,
                 input_size: int = 32,
                 num_classes: int = 10,
                 num_lc_iter: int = 10,
                 act_name: str = 'MinMax',
                 use_lln: bool = True,
                 use_batch_lipschitz: bool = True,
                 **kwargs):
        super().__init__()

        self.stem, feat_size = build_stem(input_size, width, act_name)

        # The backbone blocks operate on the stem's output resolution.
        kwargs['input_size'] = feat_size
        self.backbone = build_backbone(arch, depth, width, act_name, **kwargs)

        if input_size == 224:
            out_dim = 2048
        else:
            out_dim = 512
        self.neck = build_neck(width, feat_size, out_dim, act_name)
        self.head = head(out_dim, num_classes, use_lln)

        self.num_lc_iter = num_lc_iter
        self.set_num_lc_iter()

        # The batched estimate is never used for the 'basic_res' backbone.
        self.use_batch_lipschitz = use_batch_lipschitz and arch != 'basic_res'
        if self.use_batch_lipschitz:
            # Persistent starting vector for the batched power iteration:
            # one group of `width` channels per backbone block.
            start = torch.randn(1, depth * width, feat_size, feat_size)
            self.register_buffer('init_layer', start)

    def set_num_lc_iter(self, num_lc_iter: Optional[int] = None) -> None:
        """Set `num_lc_iter` on this module and on every sub-module.

        When `num_lc_iter` is None the value stored on the model is used.
        """
        iters = self.num_lc_iter if num_lc_iter is None else num_lc_iter
        for module in self.modules():
            module.num_lc_iter = iters

    def forward(self,
                x: torch.Tensor,
                return_feat: bool = False) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x (torch.Tensor): input image tensors in [0, 1]
            return_feat (bool): if true, return the neck output (features)
                and skip the head.

        """
        # Shift the inputs by 0.5 before the stem.
        feats = self.neck(self.backbone(self.stem(x.sub(.5))))
        return feats if return_feat else self.head(feats)

    def sub_lipschitz(self,
                      disable_batch_lipschitz: bool = False) -> torch.Tensor:
        """Compute the lipschitz constant of the model except the head.

        Args:
            disable_batch_lipschitz (bool): if true, always use
                `self.backbone.lipschitz()`; otherwise the batched
                estimate from `batch_lipschitz` is used whenever
                `self.use_batch_lipschitz` is set.
        """
        stem_lc = self.stem.lipschitz()

        if disable_batch_lipschitz or not self.use_batch_lipschitz:
            backbone_lc = self.backbone.lipschitz()
        else:
            backbone_lc = self.batch_lipschitz()

        neck_lc = self.neck.lipschitz()
        return stem_lc * backbone_lc * neck_lc

    @staticmethod
    def _group_norms(x: torch.Tensor, groups: int) -> torch.Tensor:
        """Return the L2 norm of `x` within each of `groups` channel groups."""
        per_channel = x.pow(2).sum((0, 2, 3))
        return per_channel.reshape(groups, -1).sum(1).sqrt()

    def batch_lipschitz(self):
        """Estimate the backbone's Lipschitz constant with one grouped conv.

        The weights of all backbone modules that expose `weight` are stacked
        and applied as a single grouped convolution, so the power iteration
        of every module runs in the same call. The result is the product of
        the per-module estimates. `self.init_layer` is updated in place with
        the final iterate so the next call starts from it.
        """
        kernels = [
            layer.get_weight() for layer in self.backbone
            if hasattr(layer, 'weight')
        ]
        num_groups = len(kernels)
        stacked = torch.cat(kernels, dim=0)
        in_planes = stacked.shape[1]
        conv_args = dict(bias=None, padding=1, groups=num_groups)

        vec = self.init_layer.data
        for _ in range(self.num_lc_iter):
            # One power-iteration step per group: apply W, then W^T.
            vec = F.conv2d(vec, stacked, **conv_args)
            vec = F.conv_transpose2d(vec, stacked, **conv_args)
            # Re-normalize every group to unit norm.
            scale = self._group_norms(vec, num_groups)
            scale = scale.reshape(num_groups, -1).expand(num_groups, in_planes)
            vec = vec / scale.reshape(1, -1, 1, 1).clamp_min(1e-10)

        # Store the iterate in the buffer without tracking gradients.
        self.init_layer += (vec - self.init_layer).detach()
        vec = F.conv2d(vec, stacked, **conv_args)
        return self._group_norms(vec, num_groups).prod()
