from functools import partial
from multiprocessing import Pool
import sys
import pdb
import os
from densecap_eval import DenseCapEvaluator
import external.TrackEval.trackeval as trackeval
from external.mots_tools.mots_vis.visualize_mots_box import (
    load_seqmap,
    process_sequence,
)
from detectron2.evaluation import DatasetEvaluator
from mask2former_video.utils import file_helper
import json
from collections import namedtuple
import torch

# One tracked object in one frame, as handed to the KITTI-format exporter.
# Fields: t (frame), box, track_id, class_ (trailing underscore avoids the
# `class` keyword), score, mask, caption, image_id.
TrackElement = namedtuple(
    "TrackElement", "t box track_id class_ score mask caption image_id"
)


class CHOTAEvaluator(DatasetEvaluator):
    """Export per-video tracks/captions and evaluate them with TrackEval.

    ``process`` writes, for every video:
      * caption records to ``<output_dir>/cap_jsons/<vid_name>.pkl``,
      * predicted tracks to ``<output_dir>/pred_txt``,
      * ground-truth tracks to ``<output_dir>/gt_txt``
    (the latter two through ``file_helper.export_tracking_result_in_kitti_format``).
    ``evaluate`` reads these files back, writes a seqmap and runs the
    module-level ``eval`` (HOTA / CLEAR / Identity, plus caption results).
    """

    def __init__(
        self, output_dir=".", eval_cap=True, dataset="burst", eval_resume=False
    ):  # dataset: burst or vidstg
        """
        Args:
            output_dir: Directory receiving all intermediate files and the seqmap.
            eval_cap: Accepted for API compatibility; not used by this class.
            dataset: ``"burst"`` or ``"vidstg"``. Caption metrics are only
                computed for ``"vidstg"``.
            eval_resume: Stored as ``self.eval_resume``; not used by this class.
        """
        self.output_dir = output_dir
        self.seqmap = []
        # Joined (not concatenated) so the seqmap lands inside output_dir even
        # when it has no trailing separator (e.g. the default ".").
        self.seqmap_filename = os.path.join(self.output_dir, "sequences.seqmap")
        self.evaluator = DenseCapEvaluator(
            special_token_list=[], output_dir=self.output_dir
        )
        self.dataset = dataset  # burst or vidstg
        self.eval_resume = eval_resume

    @staticmethod
    def _video_name_and_image_ids(inp, ops):
        """Return ``(vid_name, image_ids)`` for one video input.

        For ``.mp4`` inputs the name is ``<parent dir>_<file stem>`` and the
        per-frame ids come from the outputs' ``image_id``; otherwise the name
        is ``<grandparent dir>_<parent dir>`` and the ids are the frame file
        names.
        """
        first_name = inp["file_names"][0]
        parts = first_name.split("/")
        if ".mp4" in first_name:
            vid_name = parts[-2] + "_" + parts[-1].replace(".mp4", "")
            image_ids = [op["image_id"] for op in ops]
        else:
            vid_name = parts[-3] + "_" + parts[-2]
            image_ids = [name.split("/")[-1] for name in inp["file_names"]]
        return vid_name, image_ids

    @staticmethod
    def _rescale_box(box, h, w, in_h, in_w):
        """Rescale a 4-value box from an ``in_h x in_w`` image to ``h x w``.

        Indices 0/2 are scaled horizontally and 1/3 vertically (presumably
        x1, y1, x2, y2). Only the first four values are returned.
        """
        return [
            box[0] * w / in_w,
            box[1] * h / in_h,
            box[2] * w / in_w,
            box[3] * h / in_h,
        ]

    def process(self, inputs, outputs):
        """Convert a batch of per-video outputs into on-disk evaluation files.

        Args:
            inputs: One dict per video. Keys read: ``height``, ``width``,
                ``file_names`` and ``image`` (its first element's last two
                dims give the resolution boxes are expressed in).
            outputs: One list of per-frame dicts per video. Keys read:
                ``boxes``, ``scores``, ``caps``, ``obj_ids``, ``frame``,
                ``target_boxes``, ``target_texts``, ``gt_ids`` and (for .mp4
                inputs) ``image_id``.
        """
        cap_dir = os.path.join(self.output_dir, "cap_jsons")

        for inp, ops in zip(inputs, outputs):
            h, w = inp["height"], inp["width"]
            in_h, in_w = inp["image"][0].shape[-2:]
            vid_name, image_ids = self._video_name_and_image_ids(inp, ops)

            video_tracks = []
            video_gt_tracks = []
            cap_records = []
            for op_i, op in enumerate(ops):
                # Drop a singleton dim 1 (e.g. (N, 1, 4) -> (N, 4)) once, and
                # use the result everywhere: iterating the unsqueezed tensor
                # yields nested lists and breaks the per-box indexing below.
                boxes = op["boxes"]
                if len(boxes.shape) > 2:
                    boxes = boxes.squeeze(1)

                if op["target_boxes"].sum() != 0 and -1 not in op["caps"]:
                    cap_records.append(
                        {
                            "scores": op["scores"],
                            "boxes": boxes,
                            "text": op["caps"],
                            "target_boxes": op["target_boxes"],
                            "target_text": op["target_texts"],
                            "image": None,
                            "img_id": None,
                            "vid_name": vid_name,
                        }
                    )

                frame_tracks = [
                    TrackElement(
                        t=op["frame"],
                        box=self._rescale_box(
                            obj_box.cpu().tolist(), h, w, in_h, in_w
                        ),
                        track_id=obj_id,
                        class_=1,
                        score=obj_score.item(),
                        mask=None,  # masks are not exported
                        caption=obj_caption,
                        image_id=image_ids[op_i],
                    )
                    for obj_box, obj_score, obj_caption, obj_id in zip(
                        boxes, op["scores"], op["caps"], op["obj_ids"]
                    )
                ]

                frame_gt = []
                seen_gt_boxes = []  # raw boxes already exported for this frame
                for obj_box, obj_caption, obj_id in zip(
                    op["target_boxes"], op["target_texts"], op["gt_ids"]
                ):
                    # Boxes whose coordinates don't sum to > 0 are skipped
                    # (presumably padding for absent objects).
                    if not obj_box.sum() > 0.0:
                        continue
                    raw_box = obj_box.cpu().tolist()
                    if raw_box in seen_gt_boxes:  # de-duplicate within a frame
                        continue
                    seen_gt_boxes.append(raw_box)
                    frame_gt.append(
                        TrackElement(
                            t=op["frame"],
                            # Rescaled exactly like the predictions so both
                            # tracks are in the same (h x w) frame. This
                            # assumes targets share the predictions' input
                            # resolution, as their pairing in cap_records
                            # implies.
                            box=self._rescale_box(raw_box, h, w, in_h, in_w),
                            track_id=obj_id,
                            class_=1,
                            score=1.0,
                            mask=None,  # masks are not exported
                            caption=obj_caption,
                            image_id=image_ids[op_i],
                        )
                    )

                if frame_tracks:
                    video_tracks.append(frame_tracks)
                if frame_gt:
                    video_gt_tracks.append(frame_gt)

            os.makedirs(cap_dir, exist_ok=True)
            torch.save(cap_records, os.path.join(cap_dir, vid_name + ".pkl"))
            file_helper.export_tracking_result_in_kitti_format(
                vid_name,
                video_tracks,
                out_folder=os.path.join(self.output_dir, "pred_txt"),
            )
            file_helper.export_tracking_result_in_kitti_format(
                vid_name,
                video_gt_tracks,
                out_folder=os.path.join(self.output_dir, "gt_txt"),
            )

    def evaluate(self):
        """Evaluate the files written by ``process``.

        Builds the seqmap from ``<output_dir>/gt_txt``, feeds the pickled
        caption records to the dense-caption evaluator (scored only for
        ``"vidstg"``), then runs the module-level ``eval``.

        Returns:
            The ``eval`` results with every metric value cast to ``float``.
        """
        print("Evaluating...")

        gt_dir = os.path.join(self.output_dir, "gt_txt")
        cap_dir = os.path.join(self.output_dir, "cap_jsons")

        self.seqmap = []
        # Sorted so the seqmap and evaluation order are deterministic.
        for gt_file in sorted(os.listdir(gt_dir)):
            if gt_file.startswith((".", "_")):
                continue
            vid_name = gt_file.replace(".txt", "")
            with open(os.path.join(gt_dir, gt_file), "r") as f:
                gt_lines = f.readlines()
            # First/last values of the first token of the GT file (presumably
            # the frame index) delimit the sequence in the seqmap.
            self.seqmap.append(
                [
                    vid_name,
                    "empty",
                    str(gt_lines[0].split(" ")[0]),
                    str(gt_lines[-1].split(" ")[0]),
                ]
            )
            for cap_results in torch.load(os.path.join(cap_dir, vid_name + ".pkl")):
                self.evaluator.add_result(**cap_results)

        with open(self.seqmap_filename, "w") as f:
            for entry in self.seqmap:
                print(*entry, file=f)

        # Caption metrics only for VidSTG with at least one predicted box;
        # any other dataset (including burst) gets no caption results.
        results_cap = []
        if self.dataset == "vidstg" and sum(
            len(k) for k in self.evaluator.pred_boxes
        ) > 0:
            results_cap = self.evaluator.evaluate()

        results = eval(
            # ``eval`` appends "gt_txt/" to this, so guarantee a trailing
            # separator.
            output_dir=os.path.join(self.output_dir, ""),
            caps=results_cap,
            seqmap_file=self.seqmap_filename,
        )
        if results_cap:
            print(
                "MAP: {:.3f} DET_MAP: {:.3f}, CapA: {:.3f}".format(
                    results_cap["map"],
                    results_cap["detmap"],
                    results["foreground"]["CapA"],
                )
            )

        for metrics in results.values():
            for met, value in metrics.items():
                metrics[met] = float(value)
        return results


def eval(output_dir="", caps=None, seqmap_file=None):
    """Run TrackEval's VidSTG benchmark (HOTA, CLEAR, Identity) on exported tracks.

    Note: this module-level function shadows the built-in ``eval``; the name
    is kept for backward compatibility with existing callers.

    Args:
        output_dir: Folder containing ``gt_txt/`` and the ``pred_txt`` tracker
            folder. It may be given with or without a trailing separator.
        caps: Caption-evaluation results forwarded to
            ``trackeval.Evaluator.evaluate(caps=...)``. ``None`` means ``[]``.
        seqmap_file: Path of the seqmap listing the sequences to evaluate.

    Returns:
        The value returned by ``trackeval.Evaluator.evaluate``.

    Raises:
        Exception: if no metric is selected for evaluation.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if caps is None:
        caps = []

    default_eval_config = trackeval.Evaluator.get_default_eval_config()
    default_dataset_config = trackeval.datasets.VidSTG.get_default_dataset_config()

    default_dataset_config["TRACKERS_FOLDER"] = output_dir
    # os.path.join keeps this correct whether or not output_dir ends with a
    # separator; the trailing "" preserves the trailing "/".
    default_dataset_config["GT_FOLDER"] = os.path.join(output_dir, "gt_txt", "")
    default_dataset_config["SEQMAP_FILE"] = seqmap_file
    default_dataset_config["TRACKERS_TO_EVAL"] = ["pred_txt"]
    default_metrics_config = {"METRICS": ["HOTA", "CLEAR", "Identity"]}

    # Merge (later configs win on shared keys), then split back per consumer.
    config = {
        **default_eval_config,
        **default_dataset_config,
        **default_metrics_config,
    }
    eval_config = {k: v for k, v in config.items() if k in default_eval_config}
    dataset_config = {k: v for k, v in config.items() if k in default_dataset_config}
    metrics_config = {k: v for k, v in config.items() if k in default_metrics_config}

    evaluator = trackeval.Evaluator(eval_config)
    dataset_list = [trackeval.datasets.VidSTG(dataset_config)]

    metrics_list = [
        metric()
        for metric in (
            trackeval.metrics.HOTA,
            trackeval.metrics.CLEAR,
            trackeval.metrics.Identity,
        )
        if metric.get_name() in metrics_config["METRICS"]
    ]
    if not metrics_list:
        raise Exception("No metrics selected for evaluation")
    return evaluator.evaluate(dataset_list, metrics_list, caps=caps)


def vis(tracks_folder, output_folder, image_folder, seqmap_filename, dataset, maxf=50):
    """Render tracks for every sequence listed in ``seqmap_filename``.

    Every per-sequence frame limit returned by ``load_seqmap`` is overwritten
    (in place) with ``maxf`` before each sequence is passed to
    ``process_sequence``. ``dataset`` is accepted but not used.
    """
    sequences, frame_limits = load_seqmap(seqmap_filename)
    frame_limits.update(dict.fromkeys(frame_limits, maxf))

    for seq in sequences:
        process_sequence(
            seq,
            tracks_folder=tracks_folder,
            img_folder=image_folder,
            output_folder=output_folder,
            max_frames=frame_limits,
        )
