# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import logging
import fvcore.nn.weight_init as weight_init
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import time
from detectron2.config import configurable
from detectron2.layers import Conv2d
import pdb
from owviscaptor_video.modeling.box_ops import *
from owviscaptor.modeling.transformer_decoder.maskformer_transformer_decoder import (
    TRANSFORMER_DECODER_REGISTRY,
)
import time
from .position_encoding import PositionEmbeddingSine3D
from transformers import DetrForObjectDetection
from torchvision.utils import save_image
from transformers import SamModel, SamProcessor


class SelfAttentionLayer(nn.Module):
    """Multi-head self-attention wrapped in a residual connection and a LayerNorm.

    ``normalize_before`` selects the variant: pre-norm (normalise, attend,
    add) or post-norm (attend, add, normalise).
    """

    def __init__(
        self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Resolved and stored; no method of this class reads it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm path: attend, add the residual, then normalise."""
        # Queries and keys carry the positional embedding; values do not.
        qk = self.with_pos_embed(tgt, query_pos)
        attended, _ = self.self_attn(
            qk,
            qk,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm path: normalise, attend, then add the un-normalised residual."""
        normed = self.norm(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attended, _ = self.self_attn(
            qk,
            qk,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        return tgt + self.dropout(attended)

    def forward(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, tgt_mask, tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention wrapped in a residual connection and a LayerNorm.

    Queries come from ``tgt`` and keys/values from ``memory``.
    ``normalize_before`` selects the pre-norm or post-norm variant.
    """

    def __init__(
        self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False
    ):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Resolved and stored; no method of this class reads it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm path: attend to ``memory``, add the residual, then normalise."""
        # Positional embeddings go on queries and keys only; values stay raw.
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(
        self,
        tgt,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm path: normalise ``tgt``, attend, add the un-normalised residual."""
        normed = self.norm(tgt)
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return tgt + self.dropout(attended)

    def forward(
        self,
        tgt,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):
    """Two-layer feed-forward block with a residual connection and a LayerNorm.

    ``normalize_before`` selects the pre-norm or post-norm variant.
    """

    def __init__(
        self,
        d_model,
        dim_feedforward=2048,
        dropout=0.0,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        # Expand to dim_feedforward, then project back to d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _feed_forward(self, inp):
        """linear1 -> activation -> dropout -> linear2."""
        hidden = self.activation(self.linear1(inp))
        return self.linear2(self.dropout(hidden))

    def forward_post(self, tgt):
        """Post-norm path: feed-forward, add the residual, then normalise."""
        return self.norm(tgt + self.dropout(self._feed_forward(tgt)))

    def forward_pre(self, tgt):
        """Pre-norm path: normalise, feed-forward, add the un-normalised residual."""
        return tgt + self.dropout(self._feed_forward(self.norm(tgt)))

    def forward(self, tgt):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class MLP(nn.Module):
    """Plain multi-layer perceptron (FFN) with ReLU between layers.

    When ``box_head`` is true an additional ``MLP_box`` is attached as
    ``self.box_layers``; ``forward`` itself never calls it.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_layers,
        box_head=False,
        feature_dim=10000,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_box_layers = num_layers
        self.box_head = box_head

        # Layer widths from input to output; consecutive pairs become Linear layers.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        if self.box_head:
            self.box_layers = MLP_box(feature_dim, hidden_dim, 4, self.num_box_layers)

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        final_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < final_index:
                x = F.relu(x)
        return x


class MLP_box(nn.Module):
    """Fixed three-layer perceptron with ReLU between layers.

    The widths are ``input_dim -> input_dim // 20 -> hidden_dim -> output_dim``.
    The ``num_layers`` argument is accepted but not used; ``self.num_layers``
    is always the actual layer count (3).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        widths = (input_dim, input_dim // 20, hidden_dim, output_dim)
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.num_layers = len(self.layers)

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        final_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < final_index:
                x = F.relu(x)
        return x


class Box_network(nn.Module):
    """Cross-attention followed by a small MLP that emits 4 values per query and frame.

    Each query is replicated once per frame, attends to that frame's
    flattened features, and the attended vector is mapped to 4 numbers by
    ``MLP_box``.
    """

    def __init__(self, hidden_dim, nheads, pre_norm):
        super().__init__()
        self.cross_attention = CrossAttentionLayer(
            d_model=hidden_dim,
            nhead=nheads,
            dropout=0.0,
            normalize_before=pre_norm,
        )
        self.mlp = MLP_box(hidden_dim, hidden_dim, 4, 3)

    def forward(self, query, features, attn_mask=None, pos=None, query_pos=None):
        """Predict 4 values for every (batch, query, frame) triple.

        Shapes below follow from the reshapes in this method; the axis
        meanings are assumed from the visible caller — TODO confirm.

        Args:
            query: 3-D tensor, assumed (B, Q, C).
            features: 5-D tensor, assumed (B, T, C, H, W).
            attn_mask: optional 5-D mask, assumed (B, heads, Q, T, H*W); it is
                rearranged to (B*T*heads, Q, H*W) for the attention call.
            pos: optional positional embedding with the same layout as
                ``features``.
            query_pos: optional 3-D query positional embedding, assumed
                (Q, B, C).

        Returns:
            Tensor of shape (B, Q, T, 4).
        """
        num_frames = features.shape[1]

        # (B, Q, C) -> (B, Q, T, C) -> (Q, B*T, C): one copy of each query per frame.
        tgt = (
            query.unsqueeze(2)
            .repeat(1, 1, num_frames, 1)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        # (B, T, C, H, W) -> (H*W, B*T, C): each frame becomes its own batch entry.
        memory = features.flatten(-2).flatten(0, 1).permute(2, 0, 1)

        # Every optional input is rearranged only when given, so the declared
        # None defaults no longer raise AttributeError; the cross-attention
        # layer treats a None embedding as "no positional term".
        memory_mask = None
        if attn_mask is not None:
            memory_mask = attn_mask.permute(0, 3, 1, 2, 4).flatten(0, 2)
        memory_pos = None
        if pos is not None:
            memory_pos = pos.flatten(-2).flatten(0, 1).permute(2, 0, 1)
        tgt_pos = None
        if query_pos is not None:
            tgt_pos = query_pos.unsqueeze(2).repeat(1, 1, num_frames, 1).flatten(1, 2)

        attended = self.cross_attention(
            tgt,
            memory,
            memory_mask=memory_mask,
            memory_key_padding_mask=None,  # here we do not apply masking on padded region
            pos=memory_pos,
            query_pos=tgt_pos,
        )
        # (Q, B*T, C) -> (Q, B, T, C) -> (B, Q, T, C)
        attended = attended.reshape(
            attended.shape[0], query.shape[0], -1, attended.shape[-1]
        ).permute(1, 0, 2, 3)

        return self.mlp(attended)


@TRANSFORMER_DECODER_REGISTRY.register()
class VideoMultiScaleMaskedTransformerDecoder(nn.Module):

    _version = 2

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Upgrade old checkpoints in place, then run the standard module loading.

        Checkpoints saved without a version, or with a version below 2, may
        store the learnable query features under ``static_query``; those keys
        are renamed to ``query_feat`` inside ``state_dict`` and a warning is
        logged if any rename happened.

        The arguments are the ones ``nn.Module.load_state_dict`` passes to
        every submodule and are forwarded unchanged to the base class.
        """
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            # Snapshot the keys: the dict is mutated while renaming.
            for k in list(state_dict.keys()):
                if "static_query" not in k:
                    continue
                newk = k.replace("static_query", "query_feat")
                state_dict[newk] = state_dict.pop(k)
                scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

        # Without this call the override replaced the base implementation, so
        # load pre-hooks and unexpected-key reporting for this module's own
        # prefix were silently skipped.
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        # video related
        num_frames,
        box_head_in_mask,
        queries_2types=True,
        mask_sigmoid=False,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
                (must be true; anything else fails the assertion below)
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
            num_frames: frames per clip; ``forward`` divides the leading
                dimension of its inputs by this to recover the batch size
            box_head_in_mask: when true, the mask MLP also carries a
                ``box_layers`` head that regresses 4 values from mask logits
                resized to ``self.mask_resize_dim``
            queries_2types: when truthy, load the prompt encoder of the
                pretrained ``facebook/sam-vit-base`` model; ``forward`` then
                prepends point-prompt embeddings to the learnable queries
            mask_sigmoid: when true, ``mask_features`` go through a sigmoid
                before being combined with the mask embeddings
        """
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        self.num_frames = num_frames
        self.box_head_in_mask = box_head_in_mask
        self.mask_sigmoid = mask_sigmoid

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine3D(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        # One (self-attention, cross-attention, FFN) triple per decoder layer.
        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                # Empty Sequential acts as an identity projection.
                self.input_proj.append(nn.Sequential())

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        # Resolution the mask logits are resized to before the box head.
        self.mask_resize_dim = [128, 128]
        if self.box_head_in_mask:
            self.mask_embed = MLP(
                hidden_dim,
                hidden_dim,
                mask_dim,
                3,
                box_head=True,
                feature_dim=self.mask_resize_dim[0] * self.mask_resize_dim[1],
            )
        else:
            self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        self.queries_2types = queries_2types
        # Plain truthiness, matching the `if self.queries_2types:` check in
        # forward(); `== True` skipped the encoder for truthy non-bool values
        # that forward() would still act on.
        if self.queries_2types:
            # NOTE(review): the prompt embeddings are concatenated with the
            # query features in forward(), so hidden_dim presumably has to
            # match the SAM prompt embedding width — confirm.
            self.sam_prompt_encoder = SamModel.from_pretrained(
                "facebook/sam-vit-base"
            ).prompt_encoder

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        """Translate a config node into the keyword arguments of ``__init__``."""
        model_cfg = cfg.MODEL

        kwargs = dict(
            in_channels=in_channels,
            mask_classification=mask_classification,
            num_classes=model_cfg.SEM_SEG_HEAD.NUM_CLASSES,
            hidden_dim=model_cfg.MASK_FORMER.HIDDEN_DIM,
            num_queries=model_cfg.MASK_FORMER.NUM_OBJECT_QUERIES,
            # Transformer parameters:
            nheads=model_cfg.MASK_FORMER.NHEADS,
            dim_feedforward=model_cfg.MASK_FORMER.DIM_FEEDFORWARD,
        )

        # NOTE: the learnable query features are supervised too, so the
        # prediction made before any decoder layer already counts as one
        # output. One is subtracted from the configured layer count so that
        # the number of outputs (and hence losses) matches DEC_LAYERS.
        assert model_cfg.MASK_FORMER.DEC_LAYERS >= 1
        kwargs.update(
            dec_layers=model_cfg.MASK_FORMER.DEC_LAYERS - 1,
            pre_norm=model_cfg.MASK_FORMER.PRE_NORM,
            enforce_input_project=model_cfg.MASK_FORMER.ENFORCE_INPUT_PROJ,
            mask_dim=model_cfg.SEM_SEG_HEAD.MASK_DIM,
            queries_2types=cfg.QUERIES_2TYPES,
            mask_sigmoid=model_cfg.MASK_FORMER.MASK_SIGMOID,
            num_frames=cfg.INPUT.SAMPLING_FRAME_NUM,
            box_head_in_mask=model_cfg.MASK_FORMER.BOX_HEAD_IN_MASK,
        )
        return kwargs

    def forward(self, x, mask_features, mask=None, prev_output=None):
        """Decode object queries against multi-scale video features.

        Args:
            x: list of ``self.num_feature_levels`` feature maps whose leading
                dimension is batch * frames.
            mask_features: 4-D tensor (batch * frames, C, H, W); it is viewed
                as (batch, frames, C, H, W) using ``self.num_frames``.
            mask: only read when ``prev_output`` is supplied; it is
                thresholded at 0.5 after the third decoder layer to rebuild
                the full-resolution attention mask.
            prev_output: optional replacement for the decoder queries after
                the third decoder layer (index 2). ``None`` or an empty
                list/tuple means "not supplied".

        Returns:
            dict with ``obj_queries``, ``pred_logits``, ``pred_masks`` and
            ``mask_embed`` taken from the last decoder layer, ``aux_outputs``
            for the earlier layers, and ``pred_boxes`` when boxes are
            predicted. ``mask_embed`` is ``None`` when the module has a
            ``box_embed`` attribute, because that path produces no mask
            embedding.
        """
        # None is the default; an empty list/tuple is still accepted for
        # callers written against the former ``prev_output=[]`` default.
        has_prev_output = not (
            prev_output is None
            or (isinstance(prev_output, (list, tuple)) and len(prev_output) == 0)
        )
        # Same answer as `"box_embed" in dir(self)` without building the whole
        # attribute listing for every decoder layer.
        has_box_embed = hasattr(self, "box_embed")
        predicts_boxes = has_box_embed or self.box_head_in_mask

        # Timestamp tag; only forwarded to the prediction heads as save_name.
        name = str(int(time.time()))
        bt, c_m, h_m, w_m = mask_features.shape
        bs = bt // self.num_frames  # if self.training else 1
        t = bt // bs
        mask_features = mask_features.view(bs, t, c_m, h_m, w_m)
        mask_features_pos = self.pe_layer(mask_features, None)

        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(
                self.pe_layer(
                    x[i].view(bs, t, -1, size_list[-1][0], size_list[-1][1]), None
                ).flatten(3)
            )
            src.append(
                self.input_proj[i](x[i]).flatten(2)
                + self.level_embed.weight[i][None, :, None]
            )

            # NTxCxHW => NxTxCxHW => (TxHW)xNxC
            _, c, hw = src[-1].shape
            pos[-1] = pos[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
            src[-1] = src[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        if self.queries_2types:
            # Regular n x n grid of point prompts over the mask-feature plane,
            # with n*n <= num_queries / 2.
            n = int((query_embed.shape[0] / 2) ** (1 / 2))
            step_h = h_m / n
            step_w = w_m / n
            point_h = torch.tensor([i * step_h + step_h / 2 for i in range(n)])
            point_w = torch.tensor([i * step_w + step_w / 2 for i in range(n)])
            # "ij" is the historical default; stating it silences the
            # deprecation warning without changing the result.
            grid_h, grid_w = torch.meshgrid(point_h, point_w, indexing="ij")
            # NOTE(review): each point is stored as (h-coordinate,
            # w-coordinate) in mask-feature units — confirm this is the order
            # and scale the prompt encoder expects.
            input_points = (
                torch.stack([grid_h.flatten(), grid_w.flatten()], dim=-1)
                .repeat(bs, 1, 1)
                .to(query_embed.device)
                .unsqueeze(0)
            )  # (1, bs, n*n, 2)

            input_labels = torch.ones_like(
                input_points[:, :, :, 0], dtype=torch.int, device=input_points.device
            )

            prompts = self.sam_prompt_encoder(
                input_boxes=None,
                input_labels=input_labels,
                input_masks=None,
                input_points=input_points,
            )
            # Drop the last embedding of every prompt set and move the point
            # axis first so the result lines up with the QxNxC queries.
            sparse_prompts = prompts[0][:, :, :-1].squeeze(0).transpose(1, 0)

            output = torch.cat([sparse_prompts, output])
            # Prompt queries get an all-zero positional embedding.
            query_embed = torch.cat(
                [torch.zeros_like(query_embed)[: sparse_prompts.shape[0]], query_embed]
            )

        predictions_class = []
        predictions_mask = []
        predictions_boxes = []
        obj_queries = []
        mask_embeds = []

        # prediction heads on learnable query features
        (
            outputs_class,
            outputs_mask,
            attn_mask,
            outputs_boxes,
            mask_embed,
            attn_mask_full,
        ) = self._run_prediction_heads(
            output,
            mask_features,
            size_list[0],
            has_box_embed=has_box_embed,
            save_name="_0_" + name,
            attn_mask_full=None,
            feature_pos=mask_features_pos,
            query_pos=query_embed,
        )
        if predicts_boxes:
            predictions_boxes.append(outputs_boxes)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)
        obj_queries.append(output.transpose(0, 1).clone())
        # The box_embed path yields no mask embedding; keep the list aligned
        # with the other per-layer lists instead of failing on an unbound name.
        mask_embeds.append(None if mask_embed is None else mask_embed.clone())

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # A row that masks every key is reset so that query attends to
            # all positions instead of none.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index],
                query_pos=query_embed,
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed
            )

            # FFN
            output = self.transformer_ffn_layers[i](output)

            if i == 2 and has_prev_output:
                # After the third layer the queries are replaced by the
                # caller-supplied ones and the full mask is rebuilt from `mask`.
                attn_mask_full = (
                    mask.flatten(3).unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
                    < 0.5
                ).bool()
                output = prev_output

            if has_box_embed:
                attn_mask_full[
                    torch.where(attn_mask_full.sum(-1) == attn_mask_full.shape[-1])
                ] = False

            (
                outputs_class,
                outputs_mask,
                attn_mask,
                outputs_boxes,
                mask_embed,
                attn_mask_full,
            ) = self._run_prediction_heads(
                output,
                mask_features,
                size_list[(i + 1) % self.num_feature_levels],
                has_box_embed=has_box_embed,
                save_name="_" + str(i) + "_" + name,
                attn_mask_full=attn_mask_full,
                feature_pos=mask_features_pos,
                query_pos=query_embed,
            )
            if predicts_boxes:
                predictions_boxes.append(outputs_boxes)
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)
            obj_queries.append(output.transpose(0, 1).clone())
            mask_embeds.append(None if mask_embed is None else mask_embed.clone())

        assert len(predictions_class) == self.num_layers + 1
        out = {
            "obj_queries": obj_queries[-1],
            "pred_logits": predictions_class[-1],
            "pred_masks": predictions_mask[-1],
            "mask_embed": mask_embeds[-1],
        }

        if predicts_boxes:
            out["pred_boxes"] = predictions_boxes[-1]
        out["aux_outputs"] = self._set_aux_loss(
            predictions_class if self.mask_classification else None,
            predictions_mask,
            predictions_boxes if predicts_boxes else None,
            obj_queries,
            mask_embeds,
        )
        return out

    def _run_prediction_heads(
        self,
        output,
        mask_features,
        attn_mask_target_size,
        *,
        has_box_embed,
        save_name,
        attn_mask_full,
        feature_pos,
        query_pos,
    ):
        """Call ``forward_prediction_heads`` and normalise its result to a 6-tuple.

        Returns ``(outputs_class, outputs_mask, attn_mask, outputs_boxes,
        mask_embed, attn_mask_full)``. Entries the active configuration does
        not produce are ``None``; ``attn_mask_full`` is passed through
        unchanged when the heads do not return one.
        """
        if has_box_embed:
            # attn_mask_full necessary for box head
            outputs_class, outputs_mask, attn_mask, outputs_boxes, attn_mask_full = (
                self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=attn_mask_target_size,
                    save_name=save_name,
                    attn_mask=attn_mask_full,
                    feature_pos=feature_pos,
                    query_pos=query_pos,
                )
            )
            return (
                outputs_class,
                outputs_mask,
                attn_mask,
                outputs_boxes,
                None,
                attn_mask_full,
            )

        if self.box_head_in_mask:
            outputs_class, outputs_mask, attn_mask, outputs_boxes, mask_embed = (
                self.forward_prediction_heads(
                    output, mask_features, attn_mask_target_size=attn_mask_target_size
                )
            )
            return (
                outputs_class,
                outputs_mask,
                attn_mask,
                outputs_boxes,
                mask_embed,
                attn_mask_full,
            )

        outputs_class, outputs_mask, attn_mask, mask_embed = (
            self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=attn_mask_target_size
            )
        )
        return outputs_class, outputs_mask, attn_mask, None, mask_embed, attn_mask_full

    def forward_prediction_heads(
        self,
        output,
        mask_features,
        attn_mask_target_size,
        save_name="",
        attn_mask=None,
        feature_pos=None,
        query_pos=None,
    ):
        """Turn decoder queries into class logits, masks and the next attention mask.

        Args:
            output: decoder queries; normalised and then transposed on dims
                (0, 1) before the heads are applied.
            mask_features: 5-D tensor unpacked as (b, t, c, h, w).
            attn_mask_target_size: (height, width) the predicted masks are
                bilinearly resized to for the next cross-attention layer.
            save_name: not used by this method.
            attn_mask, feature_pos, query_pos: only forwarded to
                ``self.box_embed`` when the module has that attribute.

        Returns:
            * module has a ``box_embed`` attribute:
              ``(outputs_class, outputs_mask, attn_mask, outputs_boxes,
              attn_mask_full)``
            * else, with ``self.box_head_in_mask`` true:
              ``(outputs_class, outputs_mask, attn_mask, outputs_boxes,
              mask_embed)``
            * otherwise:
              ``(outputs_class, outputs_mask, attn_mask, mask_embed)``

            ``attn_mask`` is a detached bool tensor of shape
            (b * num_heads, q, t * H * W) where ``True`` marks positions that
            must not be attended to.
        """
        # Same answer as `"box_embed" in dir(self)`, evaluated once instead
        # of enumerating every attribute four times per call.
        has_box_embed = hasattr(self, "box_embed")

        decoder_output = self.decoder_norm(output).transpose(0, 1)
        outputs_class = self.class_embed(decoder_output)
        _, _, _, mask_h, mask_w = mask_features.shape

        if has_box_embed:
            # Scale factors that map normalised xyxy boxes to mask-feature pixels.
            size_xyxy = torch.tensor(
                [mask_w, mask_h, mask_w, mask_h], device=decoder_output.device
            )
            outputs_boxes = self.box_embed(
                decoder_output,
                mask_features,
                attn_mask=attn_mask,
                pos=feature_pos,
                query_pos=query_pos,
            ).sigmoid()
            temp_boxes = box_cxcywh_to_xyxy(outputs_boxes)
            temp_boxes = torch.einsum("bqtx,x->bqtx", temp_boxes, size_xyxy)
            outputs_mask = boxes_to_mask(
                temp_boxes, (mask_h, mask_w), (mask_h, mask_w)
            )  # bqthw
        else:
            mask_embed = self.mask_embed(decoder_output)
            mask_source = mask_features.sigmoid() if self.mask_sigmoid else mask_features
            outputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_source)

            if self.box_head_in_mask:
                # Imported here so `T` no longer depends on whatever the
                # wildcard import at the top of the file happens to expose.
                # NOTE(review): assumes `T` was meant to be
                # torchvision.transforms — confirm.
                from torchvision import transforms as T

                b, q, t, _, _ = outputs_mask.shape
                resize_h, resize_w = self.mask_resize_dim
                # Collapse (b, q, t) so every mask is resized independently,
                # then flatten each resized mask into one feature vector.
                resized_and_flattened_mask = (
                    T.Resize(self.mask_resize_dim)(outputs_mask.flatten(0, -3))
                    .reshape(b, q, t, resize_h, resize_w)
                    .flatten(-2)
                )
                outputs_boxes = self.mask_embed.box_layers(
                    resized_and_flattened_mask
                ).sigmoid()

        b, q, t, _, _ = outputs_mask.shape

        # NOTE: prediction is of higher-resolution
        # [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]
        resized_mask = F.interpolate(
            outputs_mask.flatten(0, 1),
            size=attn_mask_target_size,
            mode="bilinear",
            align_corners=False,
        ).view(b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])

        if has_box_embed:
            # Full-resolution mask, one copy per attention head.
            attn_mask_full = (
                outputs_mask.flatten(3).unsqueeze(1).repeat(1, self.num_heads, 1, 1, 1)
                < 0.5
            ).bool()
        else:
            # for mask head output. otherwise attn_mask is already between 0 and 1
            resized_mask = resized_mask.sigmoid()

        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed
        # to attend while ``False`` values will be unchanged.
        next_attn_mask = (
            (
                resized_mask.flatten(2)
                .unsqueeze(1)
                .repeat(1, self.num_heads, 1, 1)
                .flatten(0, 1)
                < 0.5
            )
            .bool()
            .detach()
        )

        if has_box_embed:
            return (
                outputs_class,
                outputs_mask,
                next_attn_mask,
                outputs_boxes,
                attn_mask_full,
            )
        if self.box_head_in_mask:
            return outputs_class, outputs_mask, next_attn_mask, outputs_boxes, mask_embed
        return outputs_class, outputs_mask, next_attn_mask, mask_embed

    @torch.jit.unused
    def _set_aux_loss(
        self, outputs_class, outputs_seg_masks, outputs_boxes, obj_queries, mask_embeds
    ):
        """Build one output dict per layer for every layer except the last.

        Args:
            outputs_class: per-layer class logits, or ``None`` when mask
                classification is disabled.
            outputs_seg_masks: per-layer mask predictions.
            outputs_boxes: per-layer box predictions, or ``None`` when no
                boxes are predicted.
            obj_queries: per-layer query tensors.
            mask_embeds: per-layer mask embeddings.

        Returns:
            A list of dicts, one per layer except the last:

            * boxes given: ``pred_logits``, ``pred_boxes``, ``obj_queries``,
              ``mask_embed``
            * no boxes, mask classification on: ``pred_logits``,
              ``pred_masks``, ``obj_queries``, ``mask_embed``
            * otherwise: ``pred_masks`` only
        """
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if outputs_boxes is not None:
            # NOTE(review): this branch does not emit "pred_masks" even though
            # masks are available — confirm that is intended.
            return [
                {"pred_logits": a, "pred_boxes": b, "obj_queries": c, "mask_embed": d}
                for a, b, c, d in zip(
                    outputs_class[:-1],
                    outputs_boxes[:-1],
                    obj_queries[:-1],
                    mask_embeds[:-1],
                )
            ]

        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b, "obj_queries": c, "mask_embed": d}
                for a, b, c, d in zip(
                    outputs_class[:-1],
                    outputs_seg_masks[:-1],
                    obj_queries[:-1],
                    mask_embeds[:-1],
                )
            ]

        return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
