# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
"""
MaskFormer criterion.
"""
import logging

import torch
import torch.nn.functional as F
from torch import nn
from owviscap_video.modeling import box_ops
from detectron2.utils.comm import get_world_size
from detectron2.projects.point_rend.point_features import (
    get_uncertain_point_coords_with_randomness,
    point_sample,
)
import pdb
import time
from owviscap.utils.misc import is_dist_avail_and_initialized
from torchvision.utils import save_image
import random


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    DICE loss between predicted mask logits and binary targets.

    Args:
        inputs: Float tensor of raw logits. A sigmoid is applied here and the
                result is flattened from dim 1 onwards, giving one row per mask.
        targets: Float tensor of 0/1 labels. It is used as-is (not flattened),
                 so it has to line up with the flattened ``inputs``.
        num_masks: Normalisation constant the summed loss is divided by.
    Returns:
        Scalar tensor: the per-mask DICE losses summed and divided by
        ``num_masks``.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(-1)
    total = probs.sum(-1) + targets.sum(-1)
    # The +1 on both sides smooths the ratio and avoids 0/0 for empty masks.
    per_mask = 1 - (2 * overlap + 1) / (total + 1)

    return per_mask.sum() / num_masks


# TorchScript-compiled variant of ``dice_loss``; the compilation happens once,
# when this module is imported.
dice_loss_jit = torch.jit.script(dice_loss)  # type: torch.jit.ScriptModule


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy (on logits) between predicted masks and targets.

    Args:
        inputs: Float tensor of raw logits; dim 1 is averaged over, so it needs
                at least two dimensions.
        targets: Float tensor with the same shape as ``inputs`` holding the
                 binary label of every element (0 negative, 1 positive).
        num_masks: Normalisation constant the summed loss is divided by.
    Returns:
        Scalar tensor: element-wise BCE averaged over dim 1, summed over the
        remaining entries and divided by ``num_masks``.
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    per_mask = elementwise.mean(1)

    return per_mask.sum() / num_masks


# TorchScript-compiled variant of ``sigmoid_ce_loss``; the compilation happens
# once, when this module is imported.
sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss)  # type: torch.jit.ScriptModule


def calculate_uncertainty(logits):
    """
    Estimate uncertainty as the negated L1 distance between 0.0 and each logit.

    A logit close to 0.0 is the least confident prediction, so it receives the
    highest (closest to zero) score; every score is non-positive.
    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) holding logits. Dim 1
            must have size 1 (checked with an assert).
    Returns:
        scores (Tensor): A new tensor with the same shape as ``logits`` equal
            to ``-|logits|``, with the most uncertain locations having the
            highest uncertainty score.
    """
    assert logits.shape[1] == 1
    # torch.abs is out-of-place, so the input is never modified and no
    # defensive copy of the (potentially large) logits tensor is needed.
    return -torch.abs(logits)


class VideoSetCriterion(nn.Module):
    """Set-prediction criterion for video queries with optional open-world queries.

    The loss is computed in two steps:
        1) ``matcher`` assigns ground-truth targets to the predicted queries;
        2) every requested loss (see ``get_loss``) is evaluated on the matched
           prediction / ground-truth pairs.

    When ``forward`` is called with ``openworld=True`` the leading queries are
    handled as a separate open-world group that is matched and supervised on
    its own; the loss keys of that group carry an ``_openworld`` suffix.
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        eos_coef,
        losses,
        num_points,
        oversample_ratio,
        importance_sample_ratio,
        box_loss,
        contrastive,
        contrastive_in_mask,
    ):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: callable invoked as ``matcher(outputs, targets)``; its result is consumed as
                one ``(src_indices, tgt_indices)`` pair of index tensors per batch element
            weight_dict: dict containing as key the names of the losses and as values their relative
                weight. It is only stored and printed here, never applied.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            num_points: number of points sampled per mask for the point-wise mask losses
            oversample_ratio: forwarded to ``get_uncertain_point_coords_with_randomness``
            importance_sample_ratio: forwarded to ``get_uncertain_point_coords_with_randomness``
            box_loss: if truthy the "boxes" loss is available, otherwise the "masks" loss is
            contrastive: if truthy the "contrastive" loss is available
            contrastive_in_mask: if truthy the contrastive loss reads ``outputs["mask_embed"]``,
                otherwise ``outputs["obj_queries"]``
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Per-class cross-entropy weights; the last slot is the no-object class.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

        # pointwise mask loss parameters
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.box_loss = box_loss  # true or false
        self.contrastive = contrastive
        self.contrastive_in_mask = contrastive_in_mask

    @staticmethod
    def _loss_suffix(openworld):
        """Return the suffix appended to loss keys of the open-world query group."""
        return "_openworld" if openworld else ""

    def loss_boxes(self, outputs, targets, indices, num_boxes, openworld=False):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes"; the selected target boxes must have the same
        shape as the matched predictions, which are indexed as [box, frame, 4] below.
        The boxes are converted from (center_x, center_y, w, h) to corner format for the GIoU term.
        NOTE(review): coordinates are presumably normalized by the image size - confirm with caller.

        Returns a dict with "loss_bbox" and "loss_giou" (plus "_openworld" when ``openworld``),
        both divided by ``num_boxes``.
        """
        assert "pred_boxes" in outputs
        suffix = self._loss_suffix(openworld)
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        ).to(src_boxes.device)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox" + suffix] = loss_bbox.sum() / num_boxes

        # Accumulate (1 - GIoU) of every matched pair over the frame dimension.
        # Starting from zeros keeps the result defined when there are no frames
        # (previously the accumulator was only bound inside the loop).
        num_frames = src_boxes.shape[1]
        loss_giou = src_boxes.new_zeros(src_boxes.shape[0])
        for t in range(num_frames):
            boxes_1 = box_ops.box_cxcywh_to_xyxy(src_boxes[:, t])
            boxes_2 = box_ops.box_cxcywh_to_xyxy(target_boxes[:, t])
            # Only the diagonal (prediction i vs. its own target i) is needed.
            loss_giou = (
                loss_giou
                + 1
                - torch.diag(box_ops.generalized_box_iou(boxes_1, boxes_2))
            )
        # Average over frames; max(..., 1) avoids a 0/0 when there are none.
        loss_giou = loss_giou / max(num_frames, 1)
        losses["loss_giou" + suffix] = loss_giou.sum() / num_boxes
        return losses

    def loss_labels(self, outputs, targets, indices, num_masks, openworld=False):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        Closed-world: every query is supervised, unmatched queries get the no-object class.
        Open-world: only the matched queries contribute to the loss.
        ``num_masks`` is unused here; it is accepted for a uniform loss signature.
        """
        assert "pred_logits" in outputs
        suffix = self._loss_suffix(openworld)

        src_logits = outputs["pred_logits"].float()

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        ).to(src_logits.device)
        # Default every query to the no-object class (index ``num_classes``).
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        if openworld:
            loss_ce = F.cross_entropy(
                src_logits[idx], target_classes[idx], self.empty_weight
            )
        else:
            # cross_entropy expects the class dimension at position 1.
            loss_ce = F.cross_entropy(
                src_logits.transpose(1, 2), target_classes, self.empty_weight
            )
        return {"loss_ce" + suffix: loss_ce}

    def loss_masks(self, outputs, targets, indices, num_masks, openworld=False):
        """Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.
        targets dicts must contain the key "masks"; the first two dims of the selected masks are
        flattened together below (presumably instance and frame - confirm against the data loader).

        Both losses are evaluated on ``num_points`` sampled points per mask rather than on the
        full-resolution masks. Returns "loss_mask" and "loss_dice" (plus "_openworld" suffix).
        """
        assert "pred_masks" in outputs
        suffix = self._loss_suffix(openworld)

        src_idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"][src_idx]
        # Modified to handle video
        # ``.to(src_masks)`` matches both the dtype and the device of the predictions.
        target_masks = torch.cat(
            [t["masks"][i] for t, (_, i) in zip(targets, indices)]
        ).to(src_masks)

        # No need to upsample predictions as we are using normalized coordinates :)
        # NT x 1 x H x W
        src_masks = src_masks.flatten(0, 1)[:, None]
        target_masks = target_masks.flatten(0, 1)[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                calculate_uncertainty,
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        # Sampled outside no_grad so gradients reach the predicted masks.
        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        return {
            "loss_mask"
            + suffix: sigmoid_ce_loss(point_logits, point_labels, num_masks),
            "loss_dice" + suffix: dice_loss(point_logits, point_labels, num_masks),
        }

    def loss_contrastive(self, outputs, targets, indices, num_masks, openworld=False):
        """Push query embeddings apart by rewarding a large distance to a random other query.

        Every query (dim 1 of the embedding tensor) is paired with a randomly chosen different
        query; the loss is the negated mean absolute difference of the L2-normalized embeddings,
        summed over dim 0. ``targets``, ``indices`` and ``num_masks`` are unused.

        Returns {"loss_contrastive" (+ "_openworld"): scalar tensor}. With fewer than two
        queries there is nothing to contrast against and the loss is zero.
        """
        suffix = self._loss_suffix(openworld)

        key = "mask_embed" if self.contrastive_in_mask else "obj_queries"
        queries = F.normalize(outputs[key], dim=-1)

        num_queries = queries.shape[1]
        if num_queries < 2:
            # A permutation without fixed points does not exist for a single
            # query, so the shuffle loop below would never terminate. Return a
            # zero that is still connected to the autograd graph.
            return {"loss_contrastive" + suffix: queries.sum() * 0.0}

        identity = list(range(num_queries))
        partners = list(identity)
        # Re-shuffle until no query is paired with itself.
        while any(i == j for i, j in zip(identity, partners)):
            random.shuffle(partners)
        loss = -(queries - queries[:, partners]).abs().mean(-1).mean(-1).sum()

        return {"loss_contrastive" + suffix: loss}

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image prediction indices into a (batch_idx, src_idx) pair for indexing."""
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image target indices into a (batch_idx, tgt_idx) pair for indexing."""
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks, openworld=False):
        """Dispatch to the loss named ``loss`` and return its dict of loss tensors.

        Available names: "labels"; "boxes" when ``box_loss`` is set, otherwise "masks";
        and "contrastive" when ``contrastive`` is set. Any other name fails the assertion.
        """
        loss_map = {"labels": self.loss_labels}
        if self.box_loss:
            loss_map["boxes"] = self.loss_boxes
        else:
            loss_map["masks"] = self.loss_masks
        if self.contrastive:
            loss_map["contrastive"] = self.loss_contrastive

        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks, openworld=openworld)

    def _match_query_groups(self, layer_outputs, targets, openworld, openworld_q):
        """Split one layer's outputs into query groups and run the matcher on each.

        Returns a list of ``(group_outputs, indices, is_openworld)`` tuples: the closed-world
        group (queries from ``openworld_q`` on) always comes first, followed by the open-world
        group (the first ``openworld_q`` queries) when ``openworld`` is set.
        """
        closed = {
            k: v[:, openworld_q:]
            for k, v in layer_outputs.items()
            if k != "aux_outputs"
        }
        groups = [(closed, self.matcher(closed, targets), False)]
        if openworld:
            opened = {
                k: v[:, :openworld_q]
                for k, v in layer_outputs.items()
                if k != "aux_outputs"
            }
            groups.append((opened, self.matcher(opened, targets), True))
        return groups

    def _losses_for_groups(self, groups, targets, num_masks):
        """Evaluate every requested loss on every query group and merge the resulting dicts."""
        losses = {}
        for loss in self.losses:
            for group_outputs, group_indices, is_openworld in groups:
                losses.update(
                    self.get_loss(
                        loss,
                        group_outputs,
                        targets,
                        group_indices,
                        num_masks,
                        openworld=is_openworld,
                    )
                )
        return losses

    def forward(self, outputs, targets, openworld=False):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format.
                      Every entry except "aux_outputs" is sliced along dim 1 (the query dim);
                      "aux_outputs", if present, is a list of such dicts, one per intermediate layer.
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
             openworld: if truthy, the first ``openworld_q`` queries form the open-world group, where
                      ``openworld_q`` is the largest perfect square not exceeding half the number of
                      queries in ``outputs["obj_queries"]``.
        Returns:
             (losses, indices_closedworld): the dict of all loss tensors (auxiliary layer ``i`` adds
             an ``_{i}`` suffix to its keys) and the closed-world matching indices.
        """
        if openworld:
            openworld_q = int((outputs["obj_queries"].shape[1] / 2) ** (1 / 2)) ** 2
        else:
            openworld_q = 0

        groups = self._match_query_groups(outputs, targets, openworld, openworld_q)
        indices_closedworld = groups[0][1]

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = self._losses_for_groups(groups, targets, num_masks)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                groups = self._match_query_groups(
                    aux_outputs, targets, openworld, openworld_q
                )
                # NOTE(review): this overwrites the indices matched on the main
                # outputs, so the returned indices belong to the LAST auxiliary
                # layer whenever aux outputs exist. Kept as-is because callers
                # may rely on it - confirm whether the main-output matching was
                # intended.
                indices_closedworld = groups[0][1]

                layer_losses = self._losses_for_groups(groups, targets, num_masks)
                losses.update({k + f"_{i}": v for k, v in layer_losses.items()})

        return losses, indices_closedworld

    def __repr__(self):
        """Return a multi-line summary of the criterion's configuration."""
        head = "Criterion " + self.__class__.__name__
        body = [
            f"matcher: {self.matcher.__repr__(_repr_indent=8)}",
            f"losses: {self.losses}",
            f"weight_dict: {self.weight_dict}",
            f"num_classes: {self.num_classes}",
            f"eos_coef: {self.eos_coef}",
            f"num_points: {self.num_points}",
            f"oversample_ratio: {self.oversample_ratio}",
            f"importance_sample_ratio: {self.importance_sample_ratio}",
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
