import os
import numpy as np
import tigre.algorithms as algs
import open3d as o3d
import sys
import argparse
import os.path as osp
import json
import pickle
from tqdm import trange
import copy

sys.path.append("./")
from utils.ct_utils import get_geometry, recon_volume
from utils.argument_utils import ParamGroup


class InitParams(ParamGroup):
    """Default settings used when building the initial point cloud.

    The values below are the defaults consumed by ``main`` and forwarded to
    ``init_gaussian_from_volume``. They are set before ``ParamGroup.__init__``
    is invoked with the parser; presumably ``ParamGroup`` exposes each
    attribute as a command-line option -- confirm in ``utils.argument_utils``.
    """

    def __init__(self, parser):
        # Reconstruction method name handed to ``recon_volume``.
        self.recon_method = "fdk"
        # Number of voxels sampled into the point cloud.
        self.n_points = 50000
        # Only voxels whose reconstructed value is above this are sampled.
        self.density_thresh = 0.05
        # Multiplier applied to the sampled voxel values before saving.
        self.density_rescale = 0.15
        # Must run after the defaults are assigned (the attributes exist on
        # ``self`` by the time the parent initializer sees the instance).
        super().__init__(parser, "Initialization Parameters")


def init_gaussian_from_volume(
    projs,
    angles,
    geo,
    recon_method,
    n_points,
    density_thresh,
    density_rescale,
):
    """Build an initial point cloud by sampling voxels of a reconstructed volume.

    The volume is reconstructed from the projections with ``recon_volume``.
    Voxels whose value exceeds ``density_thresh`` are the candidates, and
    ``n_points`` of them are drawn at random without replacement.

    Args:
        projs: Projection data forwarded to ``recon_volume``.
        angles: Projection angles forwarded to ``recon_volume``.
        geo: Geometry object providing ``offOrigin``, ``dVoxel`` and
            ``sVoxel`` (presumably a TIGRE geometry -- confirm against
            ``get_geometry``).
        recon_method: Reconstruction method name forwarded to
            ``recon_volume``.
        n_points: Number of points to sample.
        density_thresh: Voxels with a value strictly greater than this are
            eligible for sampling.
        density_rescale: Factor applied to the sampled voxel values before
            they are stored in the point cloud.

    Returns:
        An ``open3d.geometry.PointCloud`` whose points are the sampled voxel
        positions and whose three colour channels all hold the rescaled
        voxel value.

    Raises:
        ValueError: If fewer than ``n_points`` voxels exceed
            ``density_thresh``.
    """
    # The geometry is deep-copied so the caller's object stays untouched in
    # case the reconstruction modifies it.
    vol = recon_volume(projs, angles, copy.deepcopy(geo), recon_method)

    valid_indices = np.argwhere(vol > density_thresh)
    n_valid = valid_indices.shape[0]

    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, and this condition depends on the input data.
    if n_valid < n_points:
        raise ValueError(
            "Valid voxels less than target number of sampling. Check threshold "
            f"(found {n_valid} voxels above density_thresh={density_thresh}, "
            f"need {n_points})."
        )

    sampled_indices = valid_indices[
        np.random.choice(n_valid, n_points, replace=False)
    ]

    # Convert voxel indices to positions: scale by the voxel size, shift by
    # half of the volume extent and add the geometry's origin offset.
    sampled_positions = sampled_indices * geo.dVoxel - geo.sVoxel / 2 + geo.offOrigin
    sampled_densities = vol[
        sampled_indices[:, 0],
        sampled_indices[:, 1],
        sampled_indices[:, 2],
    ]
    sampled_densities = sampled_densities * density_rescale

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(sampled_positions)
    # The point cloud has no scalar channel in use here, so the rescaled value
    # is replicated across the three colour components.
    pcd.colors = o3d.utility.Vector3dVector(np.stack([sampled_densities] * 3, -1))

    return pcd


def main(args, init_args: InitParams):
    """Generate the initial point cloud for one dataset and write it as a PLY.

    Reads ``meta_data.json`` from ``args.data``, loads the training
    projections and their angles, builds the scanner geometry, samples the
    point cloud and saves it under ``args.output``.

    Args:
        args: Parsed command-line namespace; ``data`` and ``output`` are read.
        init_args: Initialization settings (``recon_method``, ``n_points``,
            ``density_thresh``, ``density_rescale``).
    """
    data_dir = args.data
    out_dir = args.output
    meta_file = osp.join(data_dir, "meta_data.json")

    # The output directory is created up front, before any input is read.
    os.makedirs(out_dir, exist_ok=True)

    with open(meta_file, "r") as fh:
        meta = json.load(fh)

    # Load every training projection, then collect the matching angles.
    loaded_projs = []
    for entry in meta["proj_train"]:
        loaded_projs.append(np.load(osp.join(data_dir, entry["file_path"])))
    projections = np.stack(loaded_projs, axis=0)

    angle_values = [entry["angle"] for entry in meta["proj_train"]]
    proj_angles = np.stack(angle_values, axis=0)

    scanner_geo = get_geometry(meta["scanner"])

    method = init_args.recon_method
    point_count = init_args.n_points
    cloud = init_gaussian_from_volume(
        projs=projections,
        angles=proj_angles,
        geo=scanner_geo,
        recon_method=method,
        n_points=point_count,
        density_thresh=init_args.density_thresh,
        density_rescale=init_args.density_rescale,
    )

    # File name encodes the reconstruction method and the number of points.
    ply_name = f"init_{method}_{point_count}.ply"
    o3d.io.write_point_cloud(osp.join(out_dir, ply_name), cloud)


if __name__ == "__main__":
    # Script entry point: InitParams is constructed on the parser first, then
    # the input and output locations are added.
    parser = argparse.ArgumentParser(description="Generate initialization parameters")
    init_parser = InitParams(parser)
    parser.add_argument(
        "--data",
        type=str,
        default="data/cone_ntrain_50_angle_360/0_chest_cone",
        help="Path to data.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="data/cone_ntrain_50_angle_360/0_chest_cone/init",
        help="Path to output.",
    )

    args = parser.parse_args()
    main(args, init_parser.extract(args))
