import numpy as np
from MeshUtils import *
import os
from PIL import Image
import BboxTools as bbt
from ProcessCameraParameters import get_anno, get_transformation_matrix
from scipy.linalg import logm
from pytorch3d.renderer import PerspectiveCameras


# Object category to visualize; selects the crop size and all per-category paths below.
cate = 'motorbike'
device = 'cuda:0'

# Crop shape for each supported category (an unsupported `cate` raises KeyError here).
# NOTE(review): the first entry is sliced along the image's row axis in get_img,
# so this is presumably (height, width) - confirm against BboxTools.box_by_shape.
crop_size = {'car': (256, 672), 'bus': (320, 800), 'motorbike': (512, 512), 'boat': (384, 704), 'aeroplane': (288, 768), 'bicycle': (512, 512)}[cate]

# Render on a square canvas whose side is the larger crop dimension, so the
# centered crop taken in get_img always fits inside the rendered image.
render_image_size = (max(crop_size), max(crop_size))


# Alternative KITTI configuration (uncomment to use instead of PASCAL3D):
# record_path = 'saved_records_unsup/kitti_20_fully_visible.npz'
# anno_path = '../KITTI/KITTI_train_distcrop/annotations/%s/' % cate
# image_path = '../KITTI/KITTI_train_distcrop/images/%s/' % cate
record_path = 'saved_records_unsup/PASCAL3D_%s_50.npz' % cate
anno_path = '../PASCAL3D/PASCAL3D_distcrop/annotations/%s/' % cate
image_path = '../PASCAL3D/PASCAL3D_distcrop/images/%s/' % cate

mesh_path = '../PASCAL3D/PASCAL3D+_release1.1/CAD/%s/' % cate

# Visualizations go to a folder named after the record file (without directory or
# extension); basename/splitext work regardless of how deep record_path is nested.
save_path = 'saved_visual_unsup/%s/' % os.path.splitext(os.path.basename(record_path))[0]


def rotation_theta(theta):
    """Build the 3x3 rotation matrix for an in-plane angle ``theta`` (radians).

    The rotation acts on the first two axes and leaves the third untouched::

        [[cos, -sin, 0],
         [sin,  cos, 0],
         [0,    0,   1]]
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    first_row = [cos_t, -sin_t, 0]
    second_row = [sin_t, cos_t, 0]
    third_row = [0, 0, 1]
    return np.array([first_row, second_row, third_row])


def cal_err(gt, pred):
    """Return the geodesic angle (radians) between two 3x3 rotation matrices.

    Computes ``||logm(pred^T @ gt)||_F / sqrt(2)``, which equals the rotation angle
    of the relative rotation between ``pred`` and ``gt``.

    Args:
        gt: ground-truth 3x3 rotation matrix.
        pred: predicted 3x3 rotation matrix.

    Returns:
        The angular error as a real Python float.
    """
    relative = np.dot(np.transpose(pred), gt)
    # logm may return a complex matrix (tiny imaginary noise, or genuinely complex
    # entries for rotations near pi). The Frobenius norm uses element magnitudes,
    # so the result is always real and comparable against angle thresholds.
    return float(np.linalg.norm(logm(relative), 'fro') / np.sqrt(2.0))


def cal_rotation_matrix(theta, elev, azum, dis):
    """Compose the full 3x3 rotation for a (theta, elevation, azimuth, distance) pose.

    The result is ``rotation_theta(theta)`` (in-plane rotation) left-multiplied onto
    the upper-left 3x3 block of ``get_transformation_matrix(azum, elev, dis)``.

    Distances at or below 1e-10 are replaced by 0.5 before calling
    ``get_transformation_matrix``. NOTE(review): presumably that helper degenerates
    for a (near-)zero distance - confirm in ProcessCameraParameters.
    """
    safe_dis = 0.5 if dis <= 1e-10 else dis
    in_plane = rotation_theta(theta)
    extrinsic = get_transformation_matrix(azum, elev, safe_dis)
    return in_plane @ extrinsic[:3, :3]


def get_img(theta, elevation, azimuth, distance, this_meshes):
    """Render a mesh from a given viewpoint and return the centered crop as uint8.

    Uses the module-level ``phong_renderer`` (created in the ``__main__`` section)
    together with the ``device``, ``crop_size`` and ``render_image_size`` globals.

    Args:
        theta: in-plane rotation passed to ``campos_to_R_T``.
        elevation, azimuth: camera angles in radians (``degrees=False``).
        distance: camera distance from the object.
        this_meshes: pytorch3d ``Meshes`` to render; a clone is rendered.

    Returns:
        ``numpy.uint8`` image of the centered ``crop_size`` window, rescaled so the
        brightest value maps to 255. An all-black render stays black.
    """
    theta = theta * torch.ones(1).to(device).view(1, 1)
    C = camera_position_from_spherical_angles(distance, elevation, azimuth, degrees=False, device=device)
    R, T = campos_to_R_T(C, theta, device=device)
    image = phong_renderer(meshes_world=this_meshes.clone(), R=R, T=T)
    # Drop the alpha channel; keep RGB only.
    image = image[:, ..., :3]
    # Cut a crop_size window centered in the square render.
    box_ = bbt.box_by_shape(crop_size, (render_image_size[0] // 2,) * 2)
    bbox = box_.bbox
    image = image[:, bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], :]
    image = torch.squeeze(image).detach().cpu().numpy()
    # Normalize by the peak value; guard the zero peak, which would otherwise
    # produce NaNs whose uint8 cast is undefined.
    peak = image.max()
    if peak > 0:
        image = image / peak
    return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)


if __name__ == '__main__':
    # Per-image predictions keyed by image name; each entry unpacks to
    # (theta, distance, elevation, azimuth).
    # allow_pickle=True: only load record files produced by this project, never untrusted ones.
    records = np.load(record_path, allow_pickle=True)

    os.makedirs(save_path, exist_ok=True)
    cameras = PerspectiveCameras(focal_length=1.0 * 3000, principal_point=((render_image_size[0] / 2, render_image_size[1] / 2),), image_size=(render_image_size, ), device=device)

    # CAD models are loaded as 01.off, 02.off, ...; count only .off files so other
    # entries in the directory do not turn into missing-file loads.
    n_mesh = len([f for f in os.listdir(mesh_path) if f.endswith('.off')])

    all_meshes = []
    for i in range(n_mesh):
        verts, faces = load_off(os.path.join(mesh_path, '%02d.off' % (i + 1)), to_torch=True)
        verts = pre_process_mesh_pascal(verts)

        # Same color (1, 0.85, 0.85) on every vertex.
        verts_rgb = torch.ones_like(verts)[None] * torch.Tensor([1, 0.85, 0.85]).view(1, 1, 3)  # (1, V, 3)
        textures = Textures(verts_features=verts_rgb.to(device))
        meshes = Meshes(verts=[verts], faces=[faces], textures=textures)
        meshes = meshes.to(device)
        all_meshes.append(meshes)

    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=render_image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=0
    )
    # We can add a point light in front of the object.
    lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
    # Module-level on purpose: get_img reads this global.
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=HardPhongShader(device=device, lights=lights, cameras=cameras, blend_params=blend_params),
    )

    for name_ in records.keys():
        theta_pred, distance_pred, elevation_pred, azimuth_pred = records[name_]
        with np.load(os.path.join(anno_path, name_ + '.npz')) as fl_anno:
            theta_anno, elevation_anno, azimuth_anno, distance_anno = get_anno(fl_anno, 'theta', 'elevation',
                                                                               'azimuth', 'distance')
            # 1-based index into all_meshes.
            cad_idx = int(fl_anno['cad_index'])

        # NOTE(review): fixed +0.7 offset on the predicted distance; the reason is not
        # visible in this file - confirm against how the records were produced.
        distance_pred += 0.7
        anno_matrix = cal_rotation_matrix(theta_anno, elevation_anno, azimuth_anno, distance_anno)
        pred_matrix = cal_rotation_matrix(theta_pred, elevation_pred, azimuth_pred, distance_pred)

        # Non-finite rotations are treated as a large (pi/2) error so they get skipped.
        if np.all(np.isfinite(anno_matrix)) and np.all(np.isfinite(pred_matrix)):
            error_ = cal_err(anno_matrix, pred_matrix)
        else:
            error_ = np.pi / 2

        # Only visualize predictions within pi/18 (10 degrees) of the annotation;
        # a NaN error is skipped explicitly because NaN comparisons are always False.
        if np.isnan(error_) or error_ > np.pi / 18:
            continue

        render_image = get_img(theta_pred, elevation_pred, azimuth_pred, distance_pred, all_meshes[cad_idx - 1])
        with Image.open(os.path.join(image_path, name_ + '.JPEG')) as img:
            # Force 3 channels so grayscale photos concatenate with the RGB render.
            image = np.array(img.convert('RGB'))

        final = np.concatenate((render_image, image), axis=1)
        Image.fromarray(final).save(os.path.join(save_path, name_ + '.jpg'))
        print('Finish: ', name_)



