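# Reproduction of the UniDistill losses. Because the detection head differs,
# only the feature KD loss and the BEV KD loss are reproduced.
# The two losses are implemented separately; this file implements the BEV loss.
# The pipeline mirrors the feature loss, except that a relation matrix is
# computed over the features sampled at the 9 key points of each box.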
import warnings
from typing import Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class UnidistillBEVLoss(nn.Module):
    """Relation-based BEV distillation loss (UniDistill reproduction).

    For every ground-truth box, features are bilinearly sampled from the
    student and teacher BEV maps at 9 key points: the 4 box corners, the box
    center and the midpoints of the 4 edges. The 9 sampled vectors of a box
    are L2-normalized along the channel dimension and multiplied with each
    other to form a 9x9 relation matrix (approximately pairwise cosine
    similarity). The loss is the L1 distance between the student and teacher
    relation matrices, averaged over boxes and matrix entries within a frame
    and summed over the frames of the batch (and over feature levels when
    several are given).

    Args:
        loss_weight (float): Weight of loss. Defaults to 1.0.
        resize_stu (bool): If True, we'll down/up sample the features of the
            student model to the spatial size of those of the teacher model if
            their spatial sizes are different. And vice versa. Defaults to
            True.
    """

    def __init__(self, loss_weight=1.0, resize_stu=True):
        super(UnidistillBEVLoss, self).__init__()
        self.loss_weight = loss_weight
        self.resize_stu = resize_stu

    # Both loss methods below take the gt_boxes_bev_coords argument.
    def forward(self, preds_S: Union[torch.Tensor, Tuple],
                preds_T: Union[torch.Tensor, Tuple],
                gt_boxes_bev_coords: List[torch.Tensor]) -> torch.Tensor:
        """Forward computation.

        Args:
            preds_S (torch.Tensor | Tuple[torch.Tensor]): The student model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W).
            preds_T (torch.Tensor | Tuple[torch.Tensor]): The teacher model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W).
            gt_boxes_bev_coords (List[torch.Tensor]): One tensor per frame of
                the batch holding the 4 BEV corners of every GT box, expected
                shape (num_boxes, 4, 2) in the feature map's coordinate units
                (TODO confirm against the caller). The same coordinates are
                used for every feature level.

        Return:
            torch.Tensor: The calculated loss value.
        """
        if isinstance(preds_S, torch.Tensor):
            preds_S, preds_T = (preds_S, ), (preds_T, )

        loss = 0.

        for pred_S, pred_T in zip(preds_S, preds_T):
            size_S, size_T = pred_S.shape[2:], pred_T.shape[2:]
            # Compare the full (H, W) size: checking only H would skip the
            # resize when just the widths differ and then trip the assert.
            if size_S != size_T:
                warnings.warn(
                    "Warning: The feature map of student doesn't match the "
                    "feature map of teacher!")
                if self.resize_stu:
                    pred_S = F.interpolate(pred_S, size_T, mode='bilinear')
                else:
                    pred_T = F.interpolate(pred_T, size_S, mode='bilinear')
            assert pred_S.shape == pred_T.shape

            # Each frame of the batch is processed independently and the
            # per-frame losses are summed over B (no mean over the batch),
            # matching pkd_loss.
            for idx in range(pred_S.shape[0]):
                loss = loss + self._single_apply_loss(
                    pred_S[idx], pred_T[idx], gt_boxes_bev_coords[idx])

        return loss * self.loss_weight

    def _single_apply_loss(self, pred_frame_S: torch.Tensor,
                           pred_frame_T: torch.Tensor,
                           gt_boxes_bev_coords: torch.Tensor) -> torch.Tensor:
        """Compute the relation loss of a single frame.

        Args:
            pred_frame_S (torch.Tensor): Student feature map of one frame,
                shape (C, H, W).
            pred_frame_T (torch.Tensor): Teacher feature map of the same
                frame, shape (C, H, W).
            gt_boxes_bev_coords (torch.Tensor): BEV corners of the frame's GT
                boxes, expected shape (num_boxes, 4, 2) -- TODO confirm.

        Returns:
            torch.Tensor: Scalar mean L1 distance between the student and
            teacher 9x9 relation matrices over all boxes. A zero that stays
            attached to the graph is returned when the frame has no boxes.
        """
        if gt_boxes_bev_coords.shape[0] == 0:
            # A frame without boxes contributes nothing. Returning early avoids
            # reshaping an empty tensor with -1 (which raises) and a mean over
            # zero boxes (which is NaN); ``* 0.`` keeps the result on the graph.
            return pred_frame_S.sum() * 0.

        h, w = pred_frame_S.shape[-2:]
        # (1, num_boxes, 9, 2) sampling grid shared by student and teacher;
        # the leading batch dim is required by grid_sample.
        grid = self._key_point_grid(gt_boxes_bev_coords, h, w).unsqueeze(0)

        # (num_boxes, 9, C) channel-normalized key-point features.
        feat_S = self._sample_key_points(pred_frame_S, grid)
        feat_T = self._sample_key_points(pred_frame_T, grid)

        # (num_boxes, 9, C) @ (num_boxes, C, 9) -> (num_boxes, 9, 9).
        rel_S = torch.bmm(feat_S, feat_S.transpose(1, 2))
        rel_T = torch.bmm(feat_T, feat_T.transpose(1, 2))

        # Mean L1 distance over all boxes and all 9x9 entries. Every box has
        # the same number of entries, so one global mean equals averaging the
        # entries first and then the boxes.
        # Author's note: the upstream weighting looked off -- the loss is
        # already averaged over the boxes, so it is not divided again.
        return F.l1_loss(rel_S, rel_T, reduction='mean')

    @staticmethod
    def _key_point_grid(corners: torch.Tensor, h: int,
                        w: int) -> torch.Tensor:
        """Build the normalized (num_boxes, 9, 2) grid of box key points.

        Key points, in order: the 4 corners, the mean of the corners (center)
        and the midpoints of corner pairs (0, 1), (1, 2), (2, 3), (0, 3).
        """
        center = corners.mean(dim=1, keepdim=True)
        edge_mids = corners[:, [[0, 1], [1, 2], [2, 3], [0, 3]], :].mean(dim=2)
        key_points = torch.cat((corners, center, edge_mids), dim=1)

        # Map the coordinates into grid_sample's normalized [-1, 1] range.
        # NOTE(review): column 0 is scaled by the map width and column 1 by
        # the height, and the two columns are then swapped before
        # grid_sample (which reads grid[..., 0] as the width/x axis). That is
        # only self-consistent for square maps (H == W); behavior is left
        # unchanged here -- confirm the coordinate convention.
        half_size = key_points.new_tensor([w / 2, h / 2])
        return ((key_points - half_size) / half_size)[..., [1, 0]]

    @staticmethod
    def _sample_key_points(feat: torch.Tensor,
                           grid: torch.Tensor) -> torch.Tensor:
        """Sample a (C, H, W) map at ``grid`` (1, num_boxes, 9, 2).

        Returns the channel-L2-normalized features, shape (num_boxes, 9, C).
        """
        # align_corners=False is grid_sample's default; passing it explicitly
        # keeps the same sampling and silences PyTorch's UserWarning.
        sampled = F.grid_sample(feat.unsqueeze(0), grid.float(),
                                align_corners=False)
        # (1, C, num_boxes, 9) -> (num_boxes, 9, C)
        sampled = sampled[0].permute(1, 2, 0)
        # Normalize along channels; the epsilon guards against zero vectors
        # (e.g. points outside the map, which grid_sample zero-pads).
        return sampled / (sampled.norm(dim=-1, keepdim=True) + 1e-4)

