import sys
import numpy as np
import cv2
from PIL import Image
import json
import pdb
import torch
import os
import numpy
from functools import partial
import time
from multiprocessing import Pool
import pycocotools.mask as cocomask

# Conversion settings:
#   split       -- which BURST split to convert.
#   all_frames  -- if True, list every frame of each video (and emit no track
#                  annotations); otherwise list only the annotated frames.
#   open_world  -- if True, collapse all categories into a single
#                  "foreground" class and append "_ow" to the output name.
split = "train"
all_frames = False
open_world = True

# Source annotations: the open-world train split uses only the common-class
# file; every other configuration reads all_classes.json.
if split == "train" and open_world:
    json_file = "../../data/burst/annotations/" + split + "/common_classes.json"
else:
    json_file = "../../data/burst/annotations/" + split + "/all_classes.json"

# Output path: ../../data/burst/<split>[_allframes][_ow].json
output_file = "../../data/burst/" + split
if all_frames:
    output_file += "_allframes"
if open_world:
    output_file += "_ow"
output_file += ".json"


def _empty_rle_counts(height, width):
    """Return the compressed COCO RLE ``counts`` string of an all-zero mask.

    Args:
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        The UTF-8-decoded ``counts`` produced by ``pycocotools.mask.encode``.
    """
    # pycocotools requires a Fortran-ordered uint8 array.
    empty_image = np.asfortranarray(np.zeros((height, width), dtype=np.uint8))
    return cocomask.encode(empty_image)["counts"].decode(encoding="UTF-8")


def _collect_tracks(sequence, height, width):
    """Gather per-frame segmentations, bboxes and areas for every track in a sequence.

    Args:
        sequence: One entry of ``json_data["sequences"]``; its ``"segmentations"``
            is a list (one item per annotated frame) of dicts mapping
            track id -> ``{"rle": ...}``.
        height: Frame height, used as the RLE ``size``.
        width: Frame width, used as the RLE ``size``.

    Returns:
        Dict mapping track id -> ``{"segmentations": [...], "bboxes": [...],
        "areas": [...]}``, each list holding one entry per element of
        ``sequence["segmentations"]``. Frames where a track is absent get an
        empty-mask RLE with ``None`` bbox and area.
    """
    frames = sequence["segmentations"]
    # dict.fromkeys keeps first-appearance order, so track (and therefore
    # annotation id) order is reproducible across runs; iterating a set of
    # str keys is not, because str hashing is randomized per process.
    track_ids = list(dict.fromkeys(tid for frame in frames for tid in frame))
    tracks = {
        tid: {"segmentations": [], "bboxes": [], "areas": []} for tid in track_ids
    }
    empty_counts = None  # encoded lazily, at most once per sequence
    for frame in frames:
        for tid in track_ids:
            if tid in frame:
                seg = {"size": [height, width], "counts": frame[tid]["rle"]}
                bbox = cocomask.toBbox(seg).tolist()
                area = int(cocomask.area(seg))
            else:
                if empty_counts is None:
                    empty_counts = _empty_rle_counts(height, width)
                seg = {"size": [height, width], "counts": empty_counts}
                bbox = None
                area = None
            tracks[tid]["segmentations"].append(seg)
            tracks[tid]["bboxes"].append(bbox)
            tracks[tid]["areas"].append(area)
    return tracks


# Load the BURST annotation file selected by the configuration above.
with open(json_file) as json_fp:
    json_data = json.load(json_fp)

output_data = {}
if not open_world:
    output_data["categories"] = json_data["categories"]
else:
    # Open-world mode: a single class-agnostic category for every object.
    output_data["categories"] = [{"id": 1, "name": "foreground"}]

output_data["videos"] = []
# Stays empty in all_frames mode; only annotated frames produce track entries.
output_data["annotations"] = []
annotation_id = 0
for i, sequence in enumerate(json_data["sequences"]):
    print(i, "/", len(json_data["sequences"]))
    height = sequence["height"]
    width = sequence["width"]

    image_paths = "all_image_paths" if all_frames else "annotated_image_paths"
    length = len(sequence[image_paths])

    if not all_frames:
        tracks_seq = _collect_tracks(sequence, height, width)
        for track_id, track in tracks_seq.items():
            if open_world:
                category_id = 1
            else:
                category_id = sequence["track_category_ids"][track_id]

            output_data["annotations"].append(
                {
                    "height": height,
                    "width": width,
                    "length": length,
                    "category_id": category_id,
                    "segmentations": track["segmentations"],
                    "bboxes": track["bboxes"],
                    "video_id": sequence["id"],
                    "iscrowd": 0,
                    "id": annotation_id,
                    "areas": track["areas"],
                }
            )
            annotation_id += 1

    video = {
        "width": width,
        "height": height,
        "length": length,
        "id": sequence["id"],
        "file_names": [
            os.path.join(sequence["dataset"], sequence["seq_name"], k)
            for k in sequence[image_paths]
        ],
    }
    output_data["videos"].append(video)

# Use a context manager so the output is flushed and closed deterministically.
with open(output_file, "w") as out_fp:
    json.dump(output_data, out_fp)
