import zlib

import cv2
import numpy as np

def check_pt_in_box(center, bbox):
    """Tell whether a 2-D point falls inside an axis-aligned box.

    Args:
        center: ``(x, y)`` point.
        bbox: ``(x_min, y_min, x_max, y_max)`` box; all four edges count as inside.

    Returns:
        True when ``x_min <= x <= x_max`` and ``y_min <= y <= y_max``.
    """
    return bbox[0] <= center[0] <= bbox[2] and bbox[1] <= center[1] <= bbox[3]

def calc_pointdis(center, bb, img_width=640, img_height=480):
    """Penalty describing how far a point lies outside a bounding box.

    A point inside ``bb`` (edges inclusive) scores 0.0. Otherwise each axis on
    which the point falls outside the box's span contributes the gap to the
    nearer box edge, divided by ``img_width + img_height`` and capped at 1.0.
    An axis on which the point lies within the span contributes 0. The sum of
    both contributions is capped at 1.0.

    Args:
        center: ``(x, y)`` point.
        bb: ``(x_min, y_min, x_max, y_max)`` box.
        img_width: image width used in the normaliser.
        img_height: image height used in the normaliser.

    Returns:
        The penalty; between 0 and 1 for positive image sizes.
    """
    cx, cy = center[0], center[1]
    within_x = bb[0] <= cx <= bb[2]
    within_y = bb[1] <= cy <= bb[3]
    if within_x and within_y:
        return 0.0

    norm = img_width + img_height
    # Per-axis gap to the closer edge, normalised and capped at 1.0.
    gap_x = 0.0 if within_x else min(min(abs(bb[0] - cx), abs(bb[2] - cx)) / norm, 1.0)
    gap_y = 0.0 if within_y else min(min(abs(bb[1] - cy), abs(bb[3] - cy)) / norm, 1.0)
    # Upper limit for the combined penalty is 1.0.
    return min(gap_x + gap_y, 1.0)


def get_intrinsics(calib):
    """Build a 4x4 homogeneous intrinsics matrix from calibration values.

    Args:
        calib: sequence whose first four entries are ``fx, fy, cx, cy``;
            any further entries are ignored here.

    Returns:
        float64 ``(4, 4)`` identity matrix with ``fx``/``fy`` on the first two
        diagonal entries and ``cx``/``cy`` in column 2 of rows 0 and 1.
    """
    fx, fy, cx, cy = calib[:4]
    intrinsics = np.identity(4)
    intrinsics[0, 0], intrinsics[1, 1] = fx, fy
    intrinsics[:2, 2] = (cx, cy)
    return intrinsics


def undistort_img(img, calib):
    """Remove lens distortion from an image with ``cv2.undistort``.

    Args:
        img: image array passed to OpenCV.
        calib: sequence ``[fx, fy, cx, cy, *dist]``; the first four build the
            3x3 camera matrix and the remainder are passed as the distortion
            coefficients.

    Returns:
        The undistorted image returned by ``cv2.undistort``.
    """
    camera_matrix = get_intrinsics(calib)[:3, :3]
    dist_coeffs = calib[4:]
    return cv2.undistort(img, camera_matrix, dist_coeffs)


def expand_bbox(box, ratio, img_width=640, img_height=480):
    """Scale a box about its centre and clip it to the image.

    Every corner's offset from the box centre is multiplied by ``ratio``.
    Afterwards ``x_min``/``y_min`` are raised to 0 and ``x_max``/``y_max`` are
    lowered to ``img_width``/``img_height`` when they fall outside.

    Args:
        box: indexable ``(x_min, y_min, x_max, y_max)``.
        ratio: scale factor applied to the corner offsets.
        img_width: upper limit for ``x_max``.
        img_height: upper limit for ``y_max``.

    Returns:
        np.ndarray ``[x_min, y_min, x_max, y_max]`` of the scaled, clipped box.
    """
    mid = np.array([(box[0] + box[2]) / 2.0, (box[1] + box[3]) / 2.0])
    anchors = (mid[0], mid[1], mid[0], mid[1])
    scaled = np.array([(box[k] - anchors[k]) * ratio + anchors[k] for k in range(4)])

    # "not >=" / "not <=" so that a NaN coordinate is replaced by the limit too.
    for k in (0, 1):
        if not scaled[k] >= 0.0:
            scaled[k] = 0.0
    for k, limit in ((2, img_width), (3, img_height)):
        if not scaled[k] <= limit:
            scaled[k] = limit
    return scaled


def gen_bounding_box(
    obj_vertices, min_percentile=10, max_percentile=90, img_width=640, img_height=480, first_frame_ind=0,
):
    """Estimate a per-frame 2-D bounding box from projected object vertices.

    For each frame with at least 30 vertices, a core box is formed from the
    40th and 60th percentiles of the vertex coordinates (per axis, truncated
    to int32). Frames whose core box is not inside
    ``[0, img_width] x [0, img_height]`` are dropped. The frame at
    ``first_frame_ind`` keeps its core box. Every other frame stores the core
    box scaled about its centre by ``(max_percentile - min_percentile) / 10``
    through ``expand_bbox``, which also clips it to the image.

    Note: ``min_percentile``/``max_percentile`` only determine that scale
    factor; the percentiles actually sampled are fixed at 40 and 60.

    Args:
        obj_vertices: iterable of per-frame vertex arrays; presumably ``(N, 2)``
            image-plane coordinates — TODO confirm.
        min_percentile: lower term of the expansion ratio.
        max_percentile: upper term of the expansion ratio.
        img_width: image width for the bounds check and clipping.
        img_height: image height for the bounds check and clipping.
        first_frame_ind: frame index whose box is stored without expansion.

    Returns:
        dict mapping frame index -> ``[x_min, y_min, x_max, y_max]`` array.
    """
    frame_limits = (img_width, img_height, img_width, img_height)
    boxes = {}
    for frame_idx, verts in enumerate(obj_vertices):
        # Too few vertices to take percentiles from.
        if verts.shape[0] < 30:
            continue
        low = np.percentile(verts, 40, axis=0).astype(np.int32)
        high = np.percentile(verts, 60, axis=0).astype(np.int32)
        core_box = np.concatenate((low, high), axis=0)
        grown_box = expand_bbox(
            core_box,
            (max_percentile - min_percentile) / 10.0,
            img_width=img_width,
            img_height=img_height,
        )
        # Keep only frames whose un-expanded box lies fully inside the image.
        in_frame = all(0.0 <= coord <= limit for coord, limit in zip(core_box, frame_limits))
        if not in_frame:
            continue
        boxes[frame_idx] = core_box if frame_idx == first_frame_ind else grown_box
    return boxes


# Fixed colour palette for instance masks. Stored as tuples so callers cannot
# mutate the shared constant through a returned value.
_MASK_COLORS = (
    (0, 255, 0),
    (0, 0, 255),
    (255, 0, 0),
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (80, 70, 180),
    (250, 80, 190),
    (245, 145, 50),
    (70, 150, 250),
    (50, 190, 190),
)


def get_random_color(ind):
    """Return the palette colour for ``ind`` as a new 3-element list.

    Despite the name, the mapping is deterministic: ``ind`` indexes a fixed
    11-entry palette. Any integer is accepted; indices outside ``[0, 11)``
    wrap around modulo the palette size instead of raising ``IndexError``.
    Negative indices down to -11 give the same entry as plain list indexing.

    Args:
        ind: integer palette index.

    Returns:
        A fresh ``[c0, c1, c2]`` list, safe for the caller to modify.
    """
    return list(_MASK_COLORS[ind % len(_MASK_COLORS)])


def random_color_masks(image, ind):
    """Paint the pixels of a binary mask with a palette colour.

    Args:
        image: mask array; elements equal to 1 receive the colour, all other
            elements stay 0.
        ind: palette index forwarded to ``get_random_color``.

    Returns:
        ``(colored_mask, color)``: three uint8 planes shaped like ``image``,
        stacked along a new axis 2, and the palette entry that was used.
    """
    color = get_random_color(ind)
    red, green, blue = color
    hit = image == 1
    planes = []
    for value in (red, green, blue):
        plane = np.zeros_like(image).astype(np.uint8)
        plane[hit] = value
        planes.append(plane)
    return np.stack(planes, axis=2), color


def gen_segmentation_view(img, boxes, pred_cls, masks):
    """Overlay coloured instance masks, boxes and class labels on an image.

    Each instance's mask is coloured with a palette entry chosen from its class
    label and blended onto ``img`` at 0.5 weight. A 1-px green rectangle and
    the label text are then drawn at the instance's box.

    Args:
        img: image to draw on. ``cv2.addWeighted`` needs it to match the
            coloured mask in shape and dtype (presumably HxWx3 uint8 — TODO
            confirm against callers).
        boxes: per-instance box; ``boxes[i][0]`` and ``boxes[i][1]`` are the two
            opposite ``(x, y)`` corners.
        pred_cls: per-instance class label; converted with ``str``.
        masks: per-instance masks; ``masks[i][0]`` is the mask whose pixels
            equal to 1 are coloured.

    Returns:
        The blended image with boxes and labels drawn on it.
    """
    for i, mask in enumerate(masks):
        label = str(pred_cls[i])
        # ``hash(str)`` is salted per interpreter process (PYTHONHASHSEED), so
        # the same class would get a different colour on every run. CRC32 of
        # the label is stable across runs.
        color_ind = zlib.crc32(label.encode("utf-8")) % 10
        rgb_mask, _ = random_color_masks(mask[0], color_ind)
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        corner_a, corner_b = boxes[i][0], boxes[i][1]
        pt1 = (int(corner_a[0]), int(corner_a[1]))
        pt2 = (int(corner_b[0]), int(corner_b[1]))
        cv2.rectangle(img, pt1, pt2, color=(0, 255, 0), thickness=1)
        cv2.putText(
            img,
            label,
            pt1,
            cv2.FONT_HERSHEY_COMPLEX,
            0.25,
            (0, 255, 0),
            thickness=1,
        )
    return img


def gen_frames(
    frames,
    poses,
    gt_bounding_box,
    obj_centers,
    base_img_path,
    scene_name,
    first_frame_ind,
    color=(255, 0, 0),
    calib=None,
    args=None,
):
    """Draw object centres and ground-truth boxes onto per-frame images.

    Frame indices ``0 .. poses.shape[0] - 1`` are visited in lockstep with
    ``obj_centers``. A frame is skipped unless it has a non-None entry in
    ``gt_bounding_box``. The frame image is taken from ``frames`` or loaded
    from disk and resized to 640x480. A circle (radius 10, thickness -1, i.e.
    filled) is drawn at the object centre and the box is drawn with
    thickness 4.

    Args:
        frames: dict mapping frame index -> image. It is updated in place and
            returned; ``None`` starts a new dict.
        poses: array whose first dimension bounds the frames visited.
        gt_bounding_box: mapping frame index -> ``[x_min, y_min, x_max, y_max]``
            array, or None to skip that frame.
        obj_centers: per-frame ``(x, y)`` centre to mark.
        base_img_path: ``str.format`` template for the image path. It is
            formatted with ``(args.scenario, scene_name, index)`` when ``args``
            is given, with ``(index + 8154,)`` when ``scene_name == "demo"``,
            and with ``(scene_name, index)`` otherwise.
        scene_name: scene identifier used in the path template.
        first_frame_ind: not used by this function.
        color: colour passed to the OpenCV circle and rectangle calls.
        calib: not used by this function.
        args: optional object providing ``scenario`` for the path template.

    Returns:
        ``frames``, with every drawn frame stored as a uint8 image.

    Raises:
        FileNotFoundError: if an image must be loaded and ``cv2.imread`` cannot
            read it.
        Exception: anything raised while drawing the centre circle is re-raised
            after printing the frame index.
    """
    img_width = 640
    img_height = 480
    if frames is None:
        frames = {}
    for img_ind, obj_cent in zip(range(poses.shape[0]), obj_centers):
        if img_ind not in gt_bounding_box:
            continue
        raw_bb = gt_bounding_box[img_ind]
        # Check for None *before* converting: None has no ``astype``.
        if raw_bb is None:
            continue
        gt_bb = raw_bb.astype(np.int32)

        if img_ind in frames:
            img = frames[img_ind]
        else:
            if args is not None:
                img_path = base_img_path.format(args.scenario, scene_name, img_ind)
            elif scene_name == "demo":
                # NOTE(review): demo frame files are numbered from a fixed
                # offset of 8154; its origin is not visible here.
                img_path = base_img_path.format(img_ind + 8154)
            else:
                img_path = base_img_path.format(scene_name, img_ind)
            img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise FileNotFoundError(f"Could not read image: {img_path}")
            img = cv2.resize(img, (img_width, img_height))

        try:
            cv2.circle(
                img,
                (int(obj_cent[0]), int(obj_cent[1])),
                radius=10,
                color=color,
                thickness=-1,
            )
        except Exception:
            print("Not able to plot", img_ind)
            raise

        # Plain Python ints are the safest type for OpenCV point arguments.
        top_left = (int(gt_bb[0]), int(gt_bb[1]))
        bottom_right = (int(gt_bb[2]), int(gt_bb[3]))
        img = cv2.rectangle(img, top_left, bottom_right, color, 4)
        frames[img_ind] = np.uint8(img)
    return frames
