from typing import Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmengine.model.weight_init import xavier_init
from mmengine.model import BaseModule
from mmengine.logging import MMLogger
from mmengine.runner.checkpoint import CheckpointLoader

from mmrazor.registry import MODELS


class SELayer(nn.Module):
    """Squeeze-and-excitation style gate that returns per-channel weights.

    Unlike a conventional SE block, ``forward`` does not rescale its input;
    it returns only the sigmoid gate so the caller decides how to apply it.

    Args:
        channel (int): Number of channels of the input feature map.
        reduction (int): Bottleneck ratio. The hidden layer has
            ``max(channel // reduction, 1)`` units. Defaults to 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Keep at least one hidden unit: with ``channel < reduction`` plain
        # integer division yields a zero-width bottleneck, and the gate then
        # degenerates to a constant 0.5 that ignores the input.
        hidden = max(channel // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Compute the channel gate.

        Args:
            x (torch.Tensor): 4-D feature map; the first two dimensions are
                read as batch and channel.

        Returns:
            torch.Tensor: Sigmoid outputs with shape (batch, channel).
        """
        b, c, _, _ = x.shape
        # Global average pooling squeezes the two trailing dims to 1x1.
        y = self.avg_pool(x).view(b, c)
        return self.fc(y)


@MODELS.register_module()
class BEVQueryGuidedDeformableMultiLayerAttentionTransferLoss(BaseModule):
    """

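    Feature-distillation loss between student and teacher feature maps,
    weighted by a spatial mask that is either generated from BEV queries,
    derived from the teacher activations, or fixed to all ones.

    Args:
        encoder (dict, optional): Config passed to
            ``build_transformer_layer_sequence`` to build the mask encoder.
            Defaults to None.
        loss_weight (float): Weight of loss. Defaults to 100.0.
        resize_stu (bool): If True, the student features are down/up-sampled
            to the spatial size of the teacher features when the two sizes
            differ; if False, the teacher features are resized to the
            student size instead. Defaults to True.
        query_dims (int): Input width of the query projection layer.
            Defaults to 256.
        embed_dims (int): Output width of the query projection layer and
            channel count used by the mask heads. Defaults to 512.
        init_cfg (dict, optional): If given, it must contain a
            ``checkpoint`` entry from which weights are loaded.
            Defaults to None.
        all_mask (bool): If True, use an all-ones spatial mask.
            Defaults to False.
        activation_mask (bool): If True (and ``all_mask`` is False), use the
            min-max normalized channel mean of the teacher features as the
            spatial mask. Defaults to False.
        loss_type (str): One of ``'L1'``, ``'L2'``, ``'SM-L1'`` or
            ``'ATTN'``. Defaults to ``'ATTN'``.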
    """

    def __init__(self, encoder=None, loss_weight=100.0, resize_stu=True,
                 query_dims=256, embed_dims=512, init_cfg=None,
                 all_mask=False, activation_mask=False, loss_type='ATTN'):
        """Build the mask-generation sub-modules and store the loss options.

        Args:
            encoder (dict, optional): Config for
                ``build_transformer_layer_sequence``. May be left as None
                only when ``all_mask`` or ``activation_mask`` is set, because
                those modes never call the encoder.
            loss_weight (float): Multiplier applied to the returned loss.
            resize_stu (bool): Resize the student (True) or the teacher
                (False) features when their spatial sizes differ.
            query_dims (int): Input width of ``query_proj``.
            embed_dims (int): Output width of ``query_proj`` and channel
                count of the SE layer and the 1x1 mask convolution.
            init_cfg (dict, optional): Forwarded to the base class and read
                by ``init_weights``.
            all_mask (bool): Use an all-ones spatial mask.
            activation_mask (bool): Use a mask derived from teacher
                activations.
            loss_type (str): One of 'L1', 'L2', 'SM-L1', 'ATTN'.
        """
        super().__init__(init_cfg=init_cfg)
        self.loss_weight = loss_weight
        self.resize_stu = resize_stu
        self.all_mask = all_mask
        self.activation_mask = activation_mask
        # loss_type can choose from L1 L2 SM-L1 ATTN
        self.loss_type = loss_type
        self.grad_already_freezed = False

        # Sub-modules are created in the original order so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.query_proj = nn.Linear(query_dims, embed_dims)
        if encoder is None and (all_mask or activation_mask):
            # The fixed / activation masks never touch the encoder, so the
            # default ``encoder=None`` should not force a (failing) build.
            self.encoder = None
        else:
            self.encoder = build_transformer_layer_sequence(encoder)
        self.channel_selayer = SELayer(channel=embed_dims)
        self.spatial_mask_conv = nn.Conv2d(in_channels=embed_dims, out_channels=1, kernel_size=1)

        # Run last so every attribute exists before weights are initialised.
        self.init_weights()

    def init_weights(self):
        """Initialise the mask heads or load them from a checkpoint.

        With ``init_cfg`` unset, ``query_proj`` and ``spatial_mask_conv`` get
        Xavier-uniform weights and zero biases. Otherwise ``init_cfg`` must
        contain a ``checkpoint`` entry; its weights are loaded non-strictly
        and any key mismatch is reported through the logger.
        """
        logger = MMLogger.get_current_instance()
        if self.init_cfg is None:
            # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
            logger.warning(f'No pre-trained weights for {self.__class__.__name__}, training from scratch')
            xavier_init(self.query_proj, distribution='uniform', bias=0.)
            nn.init.xavier_uniform_(self.spatial_mask_conv.weight)
            if self.spatial_mask_conv.bias is not None:
                nn.init.zeros_(self.spatial_mask_conv.bias)
            return

        # can load pretrained mask generation network
        assert 'checkpoint' in self.init_cfg, f'Only support specify `Pretrained` in `init_cfg` in {self.__class__.__name__}'
        # Item access works for both plain dicts and attribute-style configs.
        ckpt = CheckpointLoader.load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu'
        )

        # Checkpoints may wrap the weights under a ``state_dict`` entry.
        _state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

        # Non-strict loading tolerates mismatches; surface them instead of
        # dropping them silently.
        incompatible = self.load_state_dict(_state_dict, strict=False)
        if incompatible.missing_keys:
            logger.warning(
                f'Missing keys when loading {self.__class__.__name__}: '
                f'{incompatible.missing_keys}')
        if incompatible.unexpected_keys:
            logger.warning(
                f'Unexpected keys when loading {self.__class__.__name__}: '
                f'{incompatible.unexpected_keys}')

    
    def freeze_grad(self):
        for param in self.parameters():
            param.requires_grad = False
        self.grad_already_freezed = True


    def forward(self, preds_S: Union[torch.Tensor, Tuple],
                preds_T: Union[torch.Tensor, Tuple],
                mask_learning_stopped: bool,
                bev_queries=None) -> torch.Tensor:
        """Forward computation.

        Args:
            preds_S (torch.Tensor | Tuple[torch.Tensor]): The student model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W).
            preds_T (torch.Tensor | Tuple[torch.Tensor]): The teacher model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W).
            bev_queries的维度是(H*W, C)

        Return:
            torch.Tensor: The calculated loss value.
        """
        if isinstance(preds_S, torch.Tensor):
            preds_S = (preds_S, )
        
        if isinstance(preds_T, torch.Tensor):
            preds_T = (preds_T, )
        
        B, C, H, W = preds_T[0].shape

        if bev_queries is None:
            bev_queries = self.bev_queries.weight.to(preds_T[0].dtype).to(preds_T[0].device)
        else:
            bev_queries = bev_queries.detach()
        
        mask_channel = None
        if self.all_mask:
            mask_spatial = torch.ones(B, 1, H, W, dtype=preds_T[0].dtype, device=preds_T[0].device)
        elif self.activation_mask:
            mean_reduced = torch.mean(preds_T[0], dim=1, keepdim=True)
            # Compute the min and max values for normalization
            min_vals = torch.amin(mean_reduced, dim=(1, 2, 3), keepdim=True)
            max_vals = torch.amax(mean_reduced, dim=(1, 2, 3), keepdim=True)

            # Normalize to [0, 1]
            mask_spatial = (mean_reduced - min_vals) / (max_vals - min_vals)
        else:    
            masks = self.query_proj(bev_queries)

            masks = self.encoder(preds_T[0:1], masks, H, W, C)
            # B, H, W, C -> B, C, H, W
            masks = masks.reshape(B, H, W, C).permute(0, 3, 1, 2)
            # B, 1, H, W
            mask_spatial = self.spatial_mask_conv(masks)
            mask_spatial = torch.sigmoid(mask_spatial)
            # B, C
            mask_channel = self.channel_selayer(masks)
            mask_channel = torch.ones_like(mask_channel)

        if mask_learning_stopped:
            if not self.grad_already_freezed:
                self.freeze_grad()

        loss = torch.tensor(0.0)

        # 需要返回的几个list
        original_features = []
        for index, (pred_S, pred_T) in enumerate(zip(preds_S, preds_T)):
            if index == 0:
                pred_S = pred_S.detach()
                pred_T = pred_T.detach()
                size_S, size_T = pred_S.shape[2:], pred_T.shape[2:]
                if size_S[0] != size_T[0]:
                    print("Warning: The feature map of student doesn't match the feature map of teacher!")
                    if self.resize_stu:
                        pred_S = F.interpolate(pred_S, size_T, mode='bilinear')
                    else:
                        pred_T = F.interpolate(pred_T, size_S, mode='bilinear')
                # assert pred_S.shape == pred_T.shape

                # if apply attention transfer
                if self.loss_type != 'ATTN':
                    loss = loss + self.masked_loss(pred_S, pred_T, mask_spatial, mask_channel)
                else:
                    # other loss type
                    loss = loss + self.masked_attention_transfer(pred_S, pred_T, mask_spatial, mask_channel)

        result = [loss * self.loss_weight, preds_T[0] if len(preds_T)==1 else preds_T, mask_spatial, mask_channel]

        return result
    

    def masked_loss(self, 
                    student_featuremap: torch.Tensor,
                    teacher_featuremap: torch.Tensor,
                    spatial_mask: torch.Tensor,
                    channel_mask: torch.Tensor) -> torch.Tensor:
        B, C, H, W = teacher_featuremap.shape
        assert teacher_featuremap.shape == student_featuremap.shape, "Feature maps must have the same shape"

        # Calculate the difference or absolute difference based on the loss type
        if self.loss_type == 'L1':
            diff = torch.abs(student_featuremap - teacher_featuremap)
        elif self.loss_type == 'SM-L1':
            diff = F.smooth_l1_loss(student_featuremap, teacher_featuremap, reduction='none')
        else:  # default to l2 loss
            diff = (student_featuremap - teacher_featuremap).pow(2)

        # Apply spatial and channel masks
        spatial_mask_expanded = spatial_mask
        diff_masked = diff * spatial_mask_expanded

        # Aggregate the masked differences according to the loss type
        if self.loss_type in ['L1', 'SM-L1']:
            # For L1 and Smooth-L1, we just need the mean of the absolute differences
            final_loss = torch.mean(diff_masked)
        else:  # for l2 loss, take the mean and then the square root for RMSE
            mse_loss = torch.mean(diff_masked)
            final_loss = torch.sqrt(mse_loss)

        return final_loss


    def masked_attention_transfer(self, \
        student_featuremap, teacher_featuremap, spatial_mask, channel_mask):
        mask_student = student_featuremap * spatial_mask
        mask_teacher = teacher_featuremap * spatial_mask
        mask_student_attn = self._at(mask_student)
        mask_teacher_attn = self._at(mask_teacher)
        factor = 350.0
        return ((mask_student_attn - mask_teacher_attn)*factor).pow(2).mean()
    
    def _at(self, x):
        return F.normalize(x.pow(2).mean(1).view(x.shape[0], -1))


    def mask_featuremap(self, softmax_spatial_mask, softmax_channel_mask, featuremaps):
        softmax_spatial_mask = softmax_spatial_mask
        softmax_channel_mask = softmax_channel_mask.unsqueeze(2).unsqueeze(3)
        
        featuremaps[0] = featuremaps[0].detach()
        masked_featuremap = featuremaps[0] * softmax_spatial_mask * softmax_channel_mask

        return masked_featuremap
