import json
import os
import re
import subprocess

import cv2
import numpy as np
import torch
import torchvision
from visualization_utils import gen_bounding_box, gen_segmentation_view
from scipy.spatial.transform import Rotation as R, Slerp

# Presumably the root directory of the ScanNet data (TODO confirm); it is not
# referenced by any function visible in this part of the file.
SCANNET_PATH = "./"

def get_iou(a, b, epsilon=1e-5):
    """Compute the Intersection-over-Union (IoU) score of two boxes.

    Both boxes are given as four numbers [x1, y1, x2, y2] where
        x1, y1 is the upper left corner
        x2, y2 is the lower right corner

    Args:
        a:          (list of 4 numbers) first box [x1,y1,x2,y2]
        b:          (list of 4 numbers) second box [x1,y1,x2,y2]
        epsilon:    (float) added to the denominator to prevent division by zero

    Returns:
        (float) The IoU score. 0.0 is returned when the boxes do not overlap,
        and also when `b` is wider or taller than 1000 units.
    """
    # Signed extent of the intersection rectangle along each axis.
    overlap_w = min(a[2], b[2]) - max(a[0], b[0])
    overlap_h = min(a[3], b[3]) - max(a[1], b[1])
    if overlap_w < 0 or overlap_h < 0:
        # A negative extent means the boxes are disjoint.
        return 0.0

    # Oversized `b` boxes score zero outright.
    # NOTE(review): the 1000 limit looks like a sanity check against
    # degenerate boxes - confirm the intended units.
    b_w = b[2] - b[0]
    b_h = b[3] - b[1]
    if b_w > 1000 or b_h > 1000:
        return 0.0

    intersection = overlap_w * overlap_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    union = area_a + b_w * b_h - intersection
    return intersection / (union + epsilon)


def convert_from_uvd(u, v, d, cam_mat):
    """
    Back-project a pixel with known depth into 3D using pinhole intrinsics.

    `d` is divided by the norm of the viewing ray (x/z, y/z, 1), i.e. it is
    treated as the distance along the ray rather than the z coordinate. No
    extrinsics are applied, so the result is expressed in the camera frame.

    Args:
        u: (float) x coordinate of pixel
        v: (float) y coordinate of pixel
        d: (float) depth of pixel (distance along the viewing ray)
        cam_mat: (matrix) camera intrinsics; only entries [0][0], [1][1],
            [0][2] and [1][2] (fx, fy, cx, cy) are read

    Returns:
        a numpy array containing the (x, y, z) coordinate of the pixel
    """
    focal_x, focal_y = cam_mat[0][0], cam_mat[1][1]
    center_x, center_y = cam_mat[0][2], cam_mat[1][2]

    # Direction of the viewing ray, normalised so that its z component is 1.
    ray_x = (u - center_x) / focal_x
    ray_y = (v - center_y) / focal_y

    # Scale the ray so that its Euclidean length equals d.
    z = d / np.sqrt(1.0 + ray_x**2 + ray_y**2)
    return np.array([ray_x * z, ray_y * z, z])


def gen_gt_tracklet(cam_int, cam_exts, frame_ind, mask, depth, sampled_num=5000):
    """
    Generate a groundtruth tracklet by lifting the masked pixels of one frame
    to 3D and re-projecting them into every frame of the trajectory.

    Args:
        cam_int: (matrix) camera intrinsics. It is multiplied directly onto the
            4x4 extrinsics (`cam_int @ cam_ext`), so it needs 4 columns (e.g.
            3x4 or 4x4) with fx, fy, cx, cy at [0][0], [1][1], [0][2], [1][2]
        cam_exts: (Nx4x4 matrix) camera extrinsics of the tracklet. They are
            applied as-is to project world points and inverted to lift pixels,
            i.e. they act as world-to-camera transforms
        frame_ind: (int) index into cam_exts of the frame that mask and depth
            belong to
        mask: (HxW numpy array) object mask; any non-zero entry counts as part
            of the object
        depth: (HxW numpy array) image depth of that frame (scaled vs. camera
            extrinsics)
        sampled_num: (int) maximum number of masked pixels sampled (without
            replacement) to generate the tracklet

    Returns:
        img_obj_centers: list of N length-2 numpy arrays, the (x, y) projection
            of the mean 3D point into each frame
        img_sampled_coords: list of N (Mx2) numpy arrays of pixel coordinates
            (x, y); M can differ per frame because points that fall behind the
            camera are dropped

    Raises:
        ValueError: if the mask selects no pixel
    """
    # A 0/1 integer mask would otherwise be used as fancy indices by
    # `depth[mask]`; casting keeps it consistent with `mask.nonzero()`.
    mask = np.asarray(mask, dtype=bool)
    masked_depth = depth[mask]
    rows, cols = mask.nonzero()
    total = masked_depth.shape[0]
    if total == 0:
        raise ValueError("mask selects no pixels; cannot generate a tracklet")

    sampled_num = min(sampled_num, total)
    indices = np.random.choice(total, sampled_num, replace=False)
    selected_depth = masked_depth[indices]

    # Lift all sampled pixels at once; nonzero() yields (row, col) so the
    # column index is the x / u coordinate.
    cam_coords = convert_from_uvd(cols[indices], rows[indices], selected_depth, cam_int)
    homogeneous = np.vstack((cam_coords, np.ones((1, sampled_num))))
    # The inverse is the same for every point, so compute it only once.
    cam_to_world = np.linalg.inv(cam_exts[frame_ind])
    world_coords = (cam_to_world @ homogeneous).T
    world_obj_center = np.mean(world_coords, axis=0)

    img_obj_centers = []
    img_sampled_coords = []
    for ind in range(cam_exts.shape[0]):
        cam_ext = cam_exts[ind]
        projection = cam_int @ cam_ext

        img_coord = (projection @ world_coords.T).T
        # Keep only points in front of the camera with a usable homogeneous
        # coordinate.
        cam_depth = (cam_ext @ world_coords.T)[2]
        visible = np.logical_and(cam_depth > 0, img_coord[:, 2] != 0)
        img_sampled_coords.append(
            img_coord[visible, :2] / img_coord[visible, 2][:, None]
        )

        img_obj_center = projection @ world_obj_center
        img_obj_centers.append(img_obj_center[:2] / img_obj_center[2])
    return img_obj_centers, img_sampled_coords


def gen_center(cam_int, cam_exts, frame_ind, pt, depth):
    """
    Track a single image point through the trajectory: lift it to 3D in the
    reference frame and re-project it into every frame.

    Args:
        cam_int: (matrix) camera intrinsics; multiplied directly onto the
            extrinsics (`cam_int @ cam_ext`)
        cam_exts: (NxKxK matrix) camera extrinsics of the tracklet; each entry
            is inverted with np.linalg.inv, so it must be square
        frame_ind: (int) index into cam_exts of the frame pt belongs to
        pt: (numpy array) x, y pixel coordinates of the point
        depth: (float) image depth of point pt

    Returns:
        img_obj_centers: list of N length-2 numpy arrays containing the (x, y)
            location of the point in each frame
    """
    point_cam = convert_from_uvd(pt[0], pt[1], depth, cam_int)
    cam_to_world = np.linalg.inv(cam_exts[frame_ind])
    point_world = cam_to_world @ np.hstack((point_cam, 1))

    def project(extrinsic):
        # Perspective projection followed by the homogeneous divide.
        homogeneous = cam_int @ extrinsic @ point_world
        return homogeneous[:2] / homogeneous[2]

    return [project(cam_exts[idx]) for idx in range(cam_exts.shape[0])]


def get_maskrcnn_model():
    """
    Build torchvision's Mask R-CNN (ResNet-50 FPN backbone) with pretrained
    weights and return it switched to evaluation mode.
    """
    # eval() returns the module itself, so the call can be chained.
    return torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()


def get_prediction(pred, threshold):
    """
    helper function to retrieve maskrcnn prediction result over threshold

    Args:
        pred: (dict) one detection result holding the tensors "scores",
            "labels", "masks" and "boxes"; boxes are read as rows of
            (x1, y1, x2, y2)
        threshold: (float) detections with a score strictly above this value
            are kept

    Returns:
        masks: boolean numpy array of the kept masks, binarised at 0.5
        pred_boxes: (Kx2x2) numpy array, one [(x1, y1), (x2, y2)] per kept box
        pred_class: numpy array with the label of each kept detection
    """
    # .cpu() makes this work for models running on an accelerator as well;
    # it is a no-op for tensors that already live on the CPU.
    scores = pred["scores"].detach().cpu().numpy()
    labels = pred["labels"].detach().cpu().numpy()
    boxes = pred["boxes"].detach().cpu().numpy()
    masks = pred["masks"].detach().cpu().numpy() > 0.5

    keep = np.flatnonzero(scores > threshold)
    # Split each (x1, y1, x2, y2) row into its two corner points.
    corner_pairs = boxes.reshape(-1, 2, 2)
    return masks[keep], corner_pairs[keep], labels[keep]


def get_mask_for_first_image(img_path, model, mask_score=0.8, visualize=False):
    """
    helper function to run MaskRCNN on one image and retrieve its leading masks

    Args:
        img_path: (string) image file absolute path
        model: pytorch maskrcnn model from torchvision
        mask_score: (float) minimum score for a valid mask
        visualize: (boolean) whether to build the segmentation visualization

    Returns:
        masks: the first (at most three) masks above mask_score, with axis 1
            squeezed out; the image is resized to 640x480 before inference
        pred_cls: the class labels of those masks

    Raises:
        OSError: if the image cannot be read
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("Could not read image: {}".format(img_path))
    img = cv2.resize(img, (640, 480))
    # NOTE(review): cv2 loads images in BGR channel order and the channels are
    # passed to the model unchanged - confirm this is what the model expects.
    img_torch = torch.from_numpy(img / 255.0).permute(2, 0, 1).float()
    # Inference only: the outputs are detached downstream, so no autograd
    # graph is needed.
    with torch.no_grad():
        model_output = model([img_torch])
    masks, boxes, pred_cls = get_prediction(model_output[0], mask_score)
    if visualize:
        # The rendered view is neither returned nor stored here.
        gen_segmentation_view(np.copy(img), boxes, pred_cls, masks)
    # TEMP: just pick the leading masks for now
    # (presumably the most confident ones - TODO confirm the output ordering)
    return np.squeeze(masks[:3], axis=1), pred_cls[:3]


def accumulate_pairwise_pose(poses, world2cam=False):
    """
    accumulate pairwise relative poses into relative pose w.r.t first matrix

    Args:
        poses: (Nx4x4 matrices) relative transforms between consecutive frames;
            poses[i] is left-multiplied onto the pose accumulated up to frame i
            to obtain the pose of frame i + 1
        world2cam: (boolean) if True the accumulated transforms are returned
            as-is, otherwise their inverses are returned

    Returns:
        ((N+1)x4x4 numpy array) pose of every frame relative to the first one;
        entry 0 is the identity
    """
    cam_exts = np.zeros((poses.shape[0] + 1, 4, 4))

    cur_pose = np.eye(4)
    cam_exts[0] = cur_pose
    for i in range(poses.shape[0]):
        # Apply the i-th relative pose before storing frame i + 1; storing
        # first would duplicate the identity and drop the last relative pose.
        cur_pose = poses[i] @ cur_pose
        cam_exts[i + 1] = cur_pose if world2cam else np.linalg.inv(cur_pose)
    return cam_exts


def check_pt_in_frame(center, img_width=640, img_height=480):
    """
    Tell whether a point lies strictly inside the image.

    Args:
        center: (sequence) x, y coordinates of the point
        img_width: (number) image width
        img_height: (number) image height

    Returns:
        truthy when 0 < x < img_width and 0 < y < img_height (points exactly on
        the border count as outside)
    """
    return 0.0 < center[0] < img_width and 0.0 < center[1] < img_height


def check_box_in_frame(bbox, img_width=640, img_height=480):
    """
    Tell whether both corners of a box lie strictly inside the image.

    Args:
        bbox: (sequence of 4 numbers) [x1, y1, x2, y2]
        img_width: (number) image width
        img_height: (number) image height

    Returns:
        truthy when (x1, y1) and (x2, y2) both satisfy 0 < x < img_width and
        0 < y < img_height
    """
    first_corner_inside = 0.0 < bbox[0] < img_width and 0.0 < bbox[1] < img_height
    return first_corner_inside and (
        0.0 < bbox[2] < img_width and 0.0 < bbox[3] < img_height
    )


def check_pt_in_box(center, bbox):
    """
    Tell whether a point lies inside a box, borders included.

    Args:
        center: (sequence) x, y coordinates of the point
        bbox: (sequence of 4 numbers) [x1, y1, x2, y2]

    Returns:
        truthy when x1 <= x <= x2 and y1 <= y <= y2
    """
    return bbox[0] <= center[0] <= bbox[2] and bbox[1] <= center[1] <= bbox[3]


def _point_to_box_penalty(pred_center, gt_bb, img_width, img_height):
    """Return the PointDis penalty in [0, 1] between a predicted point and a box."""
    if check_pt_in_box(pred_center, gt_bb):
        # If within the bounding box, count distance as 0
        return 0.0

    # If not within the bounding box, count the distance to the nearest box
    # edge along each axis, normalised by (img_width + img_height).
    norm = img_width + img_height
    x0 = min(abs(gt_bb[0] - pred_center[0]), abs(gt_bb[2] - pred_center[0])) / norm
    y0 = min(abs(gt_bb[1] - pred_center[1]), abs(gt_bb[3] - pred_center[1])) / norm
    x0 = min(x0, 1.0)
    y0 = min(y0, 1.0)

    # No penalty along an axis on which the point already lies within the
    # extent of the box.
    if (pred_center[0] >= gt_bb[0]) and (pred_center[0] <= gt_bb[2]):
        x0 = 0.0
    elif (pred_center[1] >= gt_bb[1]) and (pred_center[1] <= gt_bb[3]):
        y0 = 0.0
    # A box touching the left/right image border drops the y penalty and one
    # touching the top/bottom border drops the x penalty.
    # NOTE(review): presumably this forgives boxes truncated by the image
    # border - confirm that the axes are paired as intended.
    if gt_bb[0] <= 0.0 or gt_bb[2] >= img_width - 1:
        y0 = 0.0
    if gt_bb[1] <= 0.0 or gt_bb[3] >= img_height - 1:
        x0 = 0.0

    # upperlimit for penalty is 1.0
    return min(x0 + y0, 1.0)


def calc_metric_for_scale(  # noqa: C901
    scale,
    gt_tracklet,
    cam_int,
    poses,
    mask,
    depth,
    first_frame_ind,
    metric="mIOU",
    img_width=640,
    img_height=480,
    local_only=False,
):
    """
    Calculate a tracking metric on tracklets inferred from a camera trajectory

    Args:
        scale: the parameter we are optimizing for. With metric "mIOU" the
            depth map is divided by it; with metric "PointDis" scale[0] is used
            directly as the depth of the tracked point
        gt_tracklet: ground truth tracking bounding boxes [x1, y1, x2, y2],
            looked up by frame index (`img_ind in gt_tracklet`); presumably a
            dict {frame index: numpy array or None} - TODO confirm
        cam_int: (matrix) camera intrinsics
        poses: (Nx4x4 matrices) series of poses relative to first frame
        mask: (HxW numpy 0-1 array) image masks for the first frame (used by
            "mIOU" only)
        depth: (HxW numpy array) image depth of the first frame (scaled vs.
            camera extrinsics, used by "mIOU" only)
        first_frame_ind: (int) index of the frame mask/depth belong to
        metric: (string) "mIOU" or "PointDis"
        img_width: (number) image width, used to normalise "PointDis"
        img_height: (number) image height, used to normalise "PointDis"
        local_only: unused

    Returns:
        The metric averaged over all frames that have both a ground truth box
        and a prediction, or np.inf when there is no such frame. For "mIOU"
        this is the NEGATIVE mean IoU (so that lower is better); for
        "PointDis" it is the mean normalised point-to-box distance in [0, 1].

    Raises:
        ValueError: if metric is neither "mIOU" nor "PointDis"
    """
    pose_np = poses

    bounding_boxes = None
    if metric == "PointDis":
        # Track the centre of the ground truth box of the reference frame.
        pt = (gt_tracklet[first_frame_ind][:2] + gt_tracklet[first_frame_ind][2:]) / 2.0
        obj_centers = gen_center(cam_int, pose_np, first_frame_ind, pt, scale[0])
    elif metric == "mIOU":
        scaled_depth = depth / scale
        obj_centers, obj_vertices = gen_gt_tracklet(
            cam_int, pose_np, first_frame_ind, mask, scaled_depth
        )
        bounding_boxes = gen_bounding_box(obj_vertices)
    else:
        raise ValueError("Unsupported metric: {!r}".format(metric))

    metric_val = 0.0
    metric_cnt = 0
    for img_ind in range(pose_np.shape[0]):
        if img_ind not in gt_tracklet:
            continue
        gt_bb = gt_tracklet[img_ind]
        pred_center = obj_centers[img_ind]
        if gt_bb is None or pred_center is None:
            continue
        if metric == "mIOU":
            # Index by frame so that skipped frames cannot shift the pairing
            # between predicted and ground truth boxes.
            pred_bb = bounding_boxes[img_ind]
            if pred_bb is None:
                continue
            metric_val -= get_iou(gt_bb, pred_bb)
        else:
            metric_val += _point_to_box_penalty(
                pred_center, gt_bb, img_width, img_height
            )
        metric_cnt += 1

    if metric_cnt == 0:
        return np.inf
    return metric_val / metric_cnt


def interpolate_pose(pose_converted, pose_num=None):
    """
    Densify a sparse set of key-frame poses into one pose per frame.

    Between key frames rotations are interpolated with Slerp and translations
    linearly. Frames before the first key frame copy the first key-frame pose
    and frames after the last key frame copy the last one.

    Args:
        pose_converted: (dict) {frame index (int): 4x4 numpy pose matrix}
        pose_num: (int) total number of frames to output; defaults to the last
            key-frame index + 1

    Returns:
        (pose_num x 4 x 4 numpy array) one pose per frame

    Raises:
        ValueError: if pose_converted is empty or a key-frame index falls
            outside [0, pose_num)
    """
    if not pose_converted:
        raise ValueError("pose_converted must contain at least one pose")

    key_times = sorted(pose_converted)
    min_ind = key_times[0]
    max_ind = key_times[-1] + 1
    if pose_num is None:
        pose_num = max_ind
    if min_ind < 0 or max_ind > pose_num:
        raise ValueError(
            "key frame indices {}..{} do not fit into {} frames".format(
                min_ind, max_ind - 1, pose_num
            )
        )

    key_rots = R.from_matrix(
        np.stack([pose_converted[t][:3, :3] for t in key_times])
    )
    key_trans = np.stack([pose_converted[t][:3, 3] for t in key_times])
    times = np.arange(min_ind, max_ind)

    if len(key_times) > 1:
        interp_rots = Slerp(key_times, key_rots)(times).as_matrix()
    else:
        # Slerp needs at least two rotations; a single key frame covers
        # exactly one frame, so there is nothing to interpolate.
        interp_rots = key_rots.as_matrix()
    # Piecewise-linear interpolation of each translation component.
    interp_trans = np.stack(
        [np.interp(times, key_times, key_trans[:, axis]) for axis in range(3)],
        axis=1,
    )

    interp_pose = np.zeros((pose_num, 4, 4))
    interp_pose[:, 3, 3] = 1.0
    interp_pose[min_ind:max_ind, :3, :3] = interp_rots
    interp_pose[min_ind:max_ind, :3, 3] = interp_trans
    # Pad both ends of the trajectory with the nearest key-frame pose.
    interp_pose[:min_ind, :3, :3] = interp_rots[0]
    interp_pose[:min_ind, :3, 3] = interp_trans[0]
    interp_pose[max_ind:, :3, :3] = interp_rots[-1]
    interp_pose[max_ind:, :3, 3] = interp_trans[-1]
    print("Interpolated {} poses out of {}".format(max_ind - min_ind, pose_num))
    return interp_pose


def get_frame_num(dataset, scene_name, scenario="kitchen"):
    """
    Return the number of frames of a scene.

    All arguments are currently ignored: the frame count is fixed at 1500.
    """
    # hard code pose_num for sampled version
    return 1500


def get_droid_slam_pose(scene_name, dataset='scannet', scenario="kitchen", world2cam=False, settings="", image_width=640, image_height=480):
    """
    Load the sample DROID-SLAM trajectory and expand it to one pose per frame.

    The stored poses are spread evenly over the frame range (one every `step`
    frames) and the frames in between are filled by interpolate_pose.

    Args:
        scene_name: forwarded to get_frame_num; the file location is hard-coded
        dataset: forwarded to get_frame_num
        scenario: forwarded to get_frame_num
        world2cam: (boolean) if True every loaded pose is inverted before
            interpolation
        settings: unused
        image_width: unused
        image_height: unused

    Returns:
        (frame_num x 4 x 4 numpy array) one pose per frame

    Raises:
        ValueError: if the pose file contains no pose
    """
    DROID_SLAM_PATH = "./sample_pose/droid_slam_sample.npy"
    droid_pose = np.load(DROID_SLAM_PATH)
    frame_num = get_frame_num(dataset, scene_name, scenario)
    if droid_pose.shape[0] == 0:
        raise ValueError("No poses found in {}".format(DROID_SLAM_PATH))

    # Rounding can yield 0 when there are more poses than frames; range()
    # requires a step of at least 1.
    step = max(1, round(float(frame_num) / droid_pose.shape[0]))

    pose_4x4 = {}
    # zip() stops at the shorter sequence, so a rounded-down step can never
    # index past the last stored pose.
    for ind, frame_ind in zip(range(droid_pose.shape[0]), range(0, frame_num, step)):
        # Each row holds the translation in columns 0-2 followed by the
        # quaternion that is handed to scipy's R.from_quat.
        cur_poses = np.eye(4)
        cur_poses[:3, :3] = R.from_quat(droid_pose[ind][3:]).as_matrix()
        cur_poses[:3, 3] = droid_pose[ind][:3]
        if world2cam:
            cur_poses = np.linalg.inv(cur_poses)
        pose_4x4[frame_ind] = cur_poses
    return interpolate_pose(pose_4x4, pose_num=frame_num)


def _load_orbslam_trajectory(txt_path, world2cam):
    """Parse a whitespace-separated trajectory file into {frame index: 4x4 pose}.

    Every data line holds at least 8 fields: the frame index, the translation
    (fields 1-3) and the quaternion (fields 4-7) that is handed to scipy's
    R.from_quat. Blank lines and lines starting with '#' are skipped.
    """
    poses_dict = {}
    with open(txt_path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields or fields[0].startswith("#"):
                continue
            if len(fields) < 8:
                raise ValueError(
                    "{}:{}: expected 8 fields, got {}".format(
                        txt_path, line_no, len(fields)
                    )
                )
            k = int(float(fields[0]))
            translation = np.array([float(x) for x in fields[1:4]])
            quaternion = np.array([float(x) for x in fields[4:8]])
            pose = np.eye(4)
            pose[:3, :3] = R.from_quat(quaternion).as_matrix()
            pose[:3, 3] = translation
            if world2cam:
                pose = np.linalg.inv(pose)
            poses_dict[k] = pose
    return poses_dict


def get_orbslam2_pose(scene_name, dataset='scannet', scenario="kitchen", world2cam=False):
    """
    Load the sample ORB-SLAM2 trajectory and expand it to one pose per frame.

    Args:
        scene_name: forwarded to get_frame_num; the file location is hard-coded
        dataset: forwarded to get_frame_num
        scenario: forwarded to get_frame_num
        world2cam: (boolean) if True every loaded pose is inverted before
            interpolation

    Returns:
        (frame_num x 4 x 4 numpy array) one pose per frame
    """
    frame_num = get_frame_num(dataset, scene_name, scenario)
    # hard code file location for supplemental
    poses_dict = _load_orbslam_trajectory(
        "./sample_pose/orbslam2_sample.txt", world2cam
    )
    return interpolate_pose(poses_dict, pose_num=frame_num)


def get_orbslam3_pose(scene_name, dataset='scannet', scenario="kitchen", world2cam=False):
    """
    Load the sample ORB-SLAM3 trajectory and expand it to one pose per frame.

    Args:
        scene_name: forwarded to get_frame_num; the file location is hard-coded
        dataset: forwarded to get_frame_num
        scenario: forwarded to get_frame_num
        world2cam: (boolean) if True every loaded pose is inverted before
            interpolation

    Returns:
        (frame_num x 4 x 4 numpy array) one pose per frame
    """
    frame_num = get_frame_num(dataset, scene_name, scenario)
    # hard code file location for supplemental
    poses_dict = _load_orbslam_trajectory(
        "./sample_pose/orbslam3_sample.txt", world2cam
    )
    return interpolate_pose(poses_dict, pose_num=frame_num)


def get_colmap_pose(
    scene_name, dataset="scannet", world2cam=False, mask=False, scenario="kitchen"
):
    """
    Load the sample COLMAP poses and expand them to one pose per frame.

    Each stored pose fills the top three rows of a 4x4 matrix, the sign of
    columns 1 and 2 is flipped (presumably a change of camera axis convention -
    TODO confirm) and the matrix is inverted before interpolation.

    Args:
        scene_name: forwarded to get_frame_num; the file location is hard-coded
        dataset: ignored, "ego4d" is always passed to get_frame_num
        world2cam: (boolean) if False the interpolated poses are inverted once
            more before being returned
        mask: unused
        scenario: forwarded to get_frame_num

    Returns:
        (frame_num x 4 x 4 numpy array) one pose per frame
    """
    # hard code file location for supplemental
    with open("./sample_pose/colmap_sample.json") as f:
        stored_poses = json.load(f)["poses"]

    def convert(rows):
        mat = np.eye(4)
        mat[:3, :] = np.array(rows)
        mat[:, 1:3] *= -1
        return np.linalg.inv(mat)

    pose_converted = {int(key): convert(rows) for key, rows in stored_poses.items()}
    frame_total = get_frame_num("ego4d", scene_name, scenario)
    colmap_pose = interpolate_pose(pose_converted, pose_num=frame_total)
    if world2cam:
        return colmap_pose
    for ind, pose in enumerate(colmap_pose):
        colmap_pose[ind] = np.linalg.inv(pose)
    return colmap_pose


def get_tartanvo_pose(scene_name, world2cam=False):
    """
    Load the sample TartanVO trajectory as a stack of 4x4 pose matrices.

    Args:
        scene_name: unused, the file location is hard-coded
        world2cam: (boolean) if True every pose is inverted

    Returns:
        numpy array with one 4x4 pose per row of the stored trajectory
    """
    # hard code file location for supplemental
    tartan_pose = np.load("./sample_pose/tartan_sample.npy")

    def to_matrix(row):
        # Columns 0-2 hold the translation, the remaining columns the
        # quaternion that is handed to scipy's R.from_quat.
        mat = np.eye(4)
        mat[:3, :3] = R.from_quat(row[3:]).as_matrix()
        mat[:3, 3] = row[:3]
        return np.linalg.inv(mat) if world2cam else mat

    return np.array(
        [to_matrix(tartan_pose[ind]) for ind in range(tartan_pose.shape[0])]
    )


def get_particlesfm_pose(
    scene_name, dataset="scannet", setting="30fps", world2cam=False, scenario="kitchen"
):
    """
    Load the sample ParticleSfM poses and expand them to one pose per frame.

    Every file in the sample directory whose name starts with digits is read
    as a whitespace-separated 3x4 matrix; the leading digits are the frame
    index. Files that cannot be read or parsed, that do not hold a 3x4 matrix
    or that contain NaN are skipped.

    Args:
        scene_name: forwarded to get_frame_num; the file location is hard-coded
        dataset: forwarded to get_frame_num
        setting: unused
        world2cam: (boolean) if True the stored matrices are used as-is,
            otherwise they are inverted
        scenario: forwarded to get_frame_num

    Returns:
        (frame_num x 4 x 4 numpy array) one pose per frame
    """
    # hard code file location for supplemental
    PARTICLE_SFM_PATH = "./sample_pose/particle_sfm_sample/"
    gt_poses = {}
    for file_name in os.listdir(PARTICLE_SFM_PATH):
        m = re.match(r"[0-9]+", file_name)
        if m is None:
            # No frame index in the name, so this is not a pose file.
            continue
        try:
            with open(os.path.join(PARTICLE_SFM_PATH, file_name)) as f:
                rows = [
                    [float(tok) for tok in line.split()]
                    for line in f
                    if line.strip()
                ]
            mat = np.array(rows, dtype=float)
        except (OSError, ValueError):
            # Unreadable, non-numeric or ragged content: skip this file.
            continue
        if mat.shape != (3, 4) or np.any(np.isnan(mat)):
            continue
        pose = np.eye(4)
        pose[:3, :] = mat
        gt_poses[int(m.group(0))] = pose if world2cam else np.linalg.inv(pose)
    return interpolate_pose(
        gt_poses, pose_num=get_frame_num(dataset, scene_name, scenario)
    )

def get_pose(  # noqa: C901
    scene_name, model_name, dataset="scannet", world2cam=False, scenario="kitchen", args=None
):
    """
    Load the predicted camera poses of one pose-estimation method.

    Args:
        scene_name: (string) scene identifier, forwarded to the loader
        model_name: (string) one of "monodepth2", "colmap", "droid_slam",
            "orbslam2", "orbslam3", "particle_sfm", "tartanvo"
        dataset: (string) forwarded to the loaders that accept it
        world2cam: (boolean) forwarded to the loader
        scenario: (string) forwarded to the loaders that accept it
        args: unused

    Returns:
        numpy array of poses, or None when model_name is not supported

    Raises:
        Exception: whatever the selected loader raises is re-raised after a
            diagnostic message has been printed
    """
    loaders = {
        "monodepth2": lambda: accumulate_pairwise_pose(
            np.load("./sample_pose/monodepth2_sample.npy"), world2cam=world2cam
        ),
        # NOTE(review): the original passed mask=(model_name == "colmap_mask"),
        # which is always False here - confirm whether a "colmap_mask" model
        # was meant to be supported.
        "colmap": lambda: get_colmap_pose(
            scene_name,
            dataset=dataset,
            world2cam=world2cam,
            mask=False,
            scenario=scenario,
        ),
        "droid_slam": lambda: get_droid_slam_pose(
            scene_name, dataset=dataset, world2cam=world2cam, settings="", scenario=scenario
        ),
        "orbslam2": lambda: get_orbslam2_pose(
            scene_name, dataset=dataset, world2cam=world2cam, scenario=scenario
        ),
        "orbslam3": lambda: get_orbslam3_pose(
            scene_name, dataset=dataset, world2cam=world2cam, scenario=scenario
        ),
        "particle_sfm": lambda: get_particlesfm_pose(
            scene_name, dataset=dataset, world2cam=world2cam, scenario=scenario
        ),
        "tartanvo": lambda: get_tartanvo_pose(scene_name, world2cam=world2cam),
    }
    loader = loaders.get(model_name)
    if loader is None:
        return None

    try:
        pred_poses = loader()
    except Exception as e:
        print("Failed to load {} for {}".format(scene_name, model_name))
        print(e)
        # Bare raise keeps the original traceback intact.
        raise
    return np.array(pred_poses)
