import os
import sys
from typing import NamedTuple
import numpy as np
import os.path as osp
import json
from tqdm import trange
from plyfile import PlyData, PlyElement

from scene.gaussian_model import BasicPointCloud

# Integer codes for the projection geometry named by the scanner
# configuration's "mode" entry.
mode_id = dict(parallel=0, cone=1)


class CameraInfo(NamedTuple):
    """Immutable record for one CT projection and the pose it was taken from.

    The field notes below describe the values as ``readCTameras`` in this
    module populates them.
    """

    # Projection index; test projections are offset by the number of train
    # projections so ids do not collide across splits.
    uid: int
    # Transpose of the 3x3 world-to-camera rotation.
    R: np.ndarray
    # World-to-camera translation (first three entries of the last column).
    T: np.ndarray
    # Full angular extent along detector axis 1, in radians.
    FovY: float
    # Full angular extent along detector axis 0, in radians.
    FovX: float
    # Projection data as returned by ``np.load`` on ``image_path``.
    image: np.ndarray
    # Location of the projection file on disk.
    image_path: str
    # Base name of ``image_path`` up to its first ".".
    image_name: str
    # Detector resolution along axis 0 / axis 1.
    width: int
    height: int
    # Projection geometry code, one of the values of ``mode_id``.
    mode: int
    # The scanner configuration dictionary the camera was derived from.
    scanner_cfg: dict


class SceneInfo(NamedTuple):
    """Everything loaded for one CT scene: camera lists, PLY path and metadata."""

    # Cameras built from the metadata's "proj_train" entries.
    train_cameras: list
    # Cameras built from "proj_test"; empty when evaluation is disabled.
    test_cameras: list
    # PLY path handed in by the caller, stored unchanged.
    ply_path: str
    # Parsed meta_data.json, with "vol" joined onto the scene directory.
    meta_data: dict


def readCTSceneInfo(path, eval, ply_path):
    """Load the description of a CT scene stored under ``path``.

    Reads ``meta_data.json`` from the scene directory, joins its ``"vol"``
    entry onto that directory, and builds the camera lists from it.

    Args:
        path: Scene directory containing ``meta_data.json``.
        eval: When truthy, test cameras are read as well as train cameras.
        ply_path: Stored as-is in the returned ``SceneInfo``.

    Returns:
        A ``SceneInfo`` holding the train/test cameras, ``ply_path`` and the
        parsed metadata.
    """
    with open(osp.join(path, "meta_data.json"), "r") as fp:
        scene_meta = json.load(fp)
    # Anchor the volume path at the scene directory.
    scene_meta["vol"] = osp.join(path, scene_meta["vol"])

    cameras = readCTameras(scene_meta, path, eval)
    return SceneInfo(
        train_cameras=cameras["train"],
        test_cameras=cameras["test"],
        ply_path=ply_path,
        meta_data=scene_meta,
    )


def readCTameras(meta_data, source_path, eval):
    """Build a ``CameraInfo`` for every projection listed in the metadata.

    Args:
        meta_data: Parsed scene metadata. Reads ``"scanner"`` (keys ``DSO``,
            ``DSD``, ``sDetector``, ``nDetector``, ``mode``) and the per-split
            lists ``"proj_train"`` / ``"proj_test"``, whose entries provide
            ``"angle"`` and ``"file_path"``.
        source_path: Directory each entry's ``file_path`` is joined onto.
        eval: When truthy, the test split is read in addition to train.

    Returns:
        ``{"train": [...], "test": [...]}`` of ``CameraInfo``; the test list
        stays empty when ``eval`` is falsy.

    Raises:
        KeyError: If a required metadata key is missing or the scanner mode
            is not a key of ``mode_id``.
    """
    cam_cfg = meta_data["scanner"]
    splits = ["train", "test"] if eval else ["train"]

    # Everything derived purely from the scanner configuration is identical
    # for all frames, so compute it once instead of once per projection.
    FovX = np.arctan2(cam_cfg["sDetector"][0] / 2, cam_cfg["DSD"]) * 2
    FovY = np.arctan2(cam_cfg["sDetector"][1] / 2, cam_cfg["DSD"]) * 2
    mode = mode_id[cam_cfg["mode"]]
    width = cam_cfg["nDetector"][0]
    height = cam_cfg["nDetector"][1]
    DSO = cam_cfg["DSO"]

    cam_infos = {"train": [], "test": []}
    for split in splits:
        split_info = meta_data["proj_" + split]
        n_split = len(split_info)
        # Test uids continue after the train uids so they never collide.
        uid_offset = len(meta_data["proj_train"]) if split == "test" else 0

        for i_split, frame_info in enumerate(split_info):
            # "\r" rewinds to the start of the line so the counter updates
            # in place.
            sys.stdout.write(f"\rReading camera {i_split + 1}/{n_split} for {split}")
            sys.stdout.flush()

            # angle2pose yields a camera-to-world transform; invert it to get
            # world-to-camera and split that into R and T.
            c2w = angle2pose(DSO, frame_info["angle"])
            w2c = np.linalg.inv(c2w)
            R = np.transpose(
                w2c[:3, :3]
            )  # R is stored transposed due to 'glm' in CUDA code
            T = w2c[:3, 3]

            image_path = osp.join(source_path, frame_info["file_path"])
            image = np.load(image_path)

            cam_infos[split].append(
                CameraInfo(
                    uid=i_split + uid_offset,
                    R=R,
                    T=T,
                    FovY=FovY,
                    FovX=FovX,
                    image=image,
                    image_path=image_path,
                    image_name=osp.basename(image_path).split(".")[0],
                    width=width,
                    height=height,
                    mode=mode,
                    scanner_cfg=cam_cfg,
                )
            )
    # Terminate the in-place progress line.
    sys.stdout.write("\n")
    return cam_infos


def angle2pose(DSO, angle):
    """Build the 4x4 camera-to-world pose for one projection angle.

    The rotation is the product of three rotations about fixed axes,
    applied in this order:
      1. -90 degrees about the x-axis,
      2. +90 degrees about the z-axis,
      3. ``angle`` about the z-axis.
    The translation is ``(DSO * cos(angle), DSO * sin(angle), 0)``.

    Args:
        DSO: Radius of the circle in the xy-plane on which the pose origin
            is placed.
        angle: Rotation angle in radians (it is passed directly to
            ``np.cos`` / ``np.sin``).

    Returns:
        A ``(4, 4)`` array holding the rotation in the upper-left 3x3 block,
        the translation in the last column, and ``[0, 0, 0, 1]`` as the
        bottom row.
    """

    def rot_z(theta):
        # Rotation by ``theta`` about the z-axis.
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        return np.array(
            [
                [cos_t, -sin_t, 0.0],
                [sin_t, cos_t, 0.0],
                [0.0, 0.0, 1.0],
            ]
        )

    quarter_turn = np.pi / 2

    # Step 1: rotation by -90 degrees about the x-axis.
    cos_x, sin_x = np.cos(-quarter_turn), np.sin(-quarter_turn)
    about_x = np.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, cos_x, -sin_x],
            [0.0, sin_x, cos_x],
        ]
    )

    pose = np.eye(4)
    # Fixed-axis rotations compose right-to-left: step 1 is the rightmost factor.
    pose[:3, :3] = rot_z(angle).dot(rot_z(quarter_turn)).dot(about_x)
    pose[:3, 3] = (DSO * np.cos(angle), DSO * np.sin(angle), 0)
    return pose


def fetchPly(path):
    """Read a PLY file into a ``BasicPointCloud``.

    Positions come from the vertex ``x``/``y``/``z`` properties, colours from
    ``red``/``green``/``blue`` divided by 255, and the normals are all zero
    with the same shape as the positions.
    """
    vertex = PlyData.read(path)["vertex"]

    def columns(*names):
        # One row per property, transposed so each row of the result is a vertex.
        return np.vstack([vertex[name] for name in names]).T

    xyz = columns("x", "y", "z")
    rgb = columns("red", "green", "blue") / 255.0
    return BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros_like(xyz))
