import torch
from tqdm import tqdm, trange
import numpy as np
from .dvgo import get_rays_of_a_view
import os
import imageio
from .utils import to8b, rgb_lpips, rgb_ssim, gen_rand_colors, get_masked_img
import matplotlib.pyplot as plt
from .load_nvos import load_nvos_data
from PIL import Image


@torch.no_grad()
def render_viewpoints(model, render_poses, HW, Ks, ndc, render_kwargs,
                      gt_imgs=None, savedir=None, dump_images=False, cfg=None,
                      render_factor=0, render_video_flipy=False, render_video_rot90=0,
                      eval_ssim=False, eval_lpips_alex=False, eval_lpips_vgg=False,
                      distill_active=False, seg_mask=True, render_fct=0.0):
    '''Render images for the given viewpoints; run evaluation if gt given.

    Args:
        model: callable invoked as ``model(rays_o, rays_d, viewdirs, ...)``. Its
            result must be a dict holding at least 'rgb_marched', 'depth' and
            'alphainv_last' (plus 'f_marched' when ``distill_active`` and
            'seg_mask_marched' when ``seg_mask``).
        render_poses, HW, Ks: per-view camera-to-world poses, (H, W) sizes and
            intrinsics. All three must have the same length.
        ndc: forwarded to ``get_rays_of_a_view``.
        render_kwargs: extra keyword arguments forwarded to ``model``; must
            contain the key 'inverse_y'.
        gt_imgs: optional ground-truth images indexed like ``render_poses``.
            Metrics are only computed when given and ``render_factor == 0``.
        savedir, dump_images: when ``savedir`` is set and ``dump_images`` is
            true, every RGB frame is written as ``NNN.png`` and every feature
            map as ``NNN.pt`` into ``savedir``.
        cfg: config object. ``cfg.data.flip_x`` and ``cfg.data.flip_y`` are
            read unconditionally, so it is required despite the ``None`` default.
        render_factor: when non-zero, ``HW`` and the first two rows of every
            ``K`` are divided by it (on copies of the inputs).
        render_video_flipy: flip rgb / depth / bgmap / seg outputs along axis 0.
        render_video_rot90: number of 90-degree rotations applied to the same
            outputs in the (0, 1) plane.
        eval_ssim, eval_lpips_alex, eval_lpips_vgg: enable the extra metrics.
        distill_active: also collect the 'f_marched' output.
        seg_mask: also collect the 'seg_mask_marched' output.
        render_fct: forwarded to ``model``.

    Returns:
        ``(rgbs, depths, bgmaps, features, segs)``. The first three are NumPy
        arrays stacked over views. ``features`` and ``segs`` are stacked NumPy
        arrays when collected and empty lists otherwise.
    '''
    assert len(render_poses) == len(HW) and len(HW) == len(Ks)

    if render_factor != 0:
        # Copy first so the caller's sizes / intrinsics are not scaled in place.
        HW = np.copy(HW)
        Ks = np.copy(Ks)
        HW = (HW/render_factor).astype(int)
        Ks[:, :2, :3] /= render_factor

    rgbs, features, segs, depths, bgmaps = [], [], [], [], []
    psnrs, ssims, lpips_alex, lpips_vgg = [], [], [], []

    # Which model outputs to keep; identical for every view, so build it once.
    keys = ['rgb_marched', 'depth', 'alphainv_last']
    if distill_active:
        keys.append('f_marched')
    if seg_mask:
        keys.append('seg_mask_marched')

    for i, c2w in enumerate(tqdm(render_poses)):
        H, W = HW[i]
        K = Ks[i]
        c2w = torch.Tensor(c2w)
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H, W, K, c2w, ndc, inverse_y=render_kwargs['inverse_y'],
                flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
        rays_o = rays_o.flatten(0, -2)
        rays_d = rays_d.flatten(0, -2)
        viewdirs = viewdirs.flatten(0, -2)
        # Query the model in chunks of 8192 rays and keep only the wanted outputs.
        render_result_chunks = [
            {k: v for k, v in model(ro, rd, vd, distill_active=distill_active, render_fct=render_fct, **render_kwargs).items() if k in keys}
            for ro, rd, vd in zip(rays_o.split(8192, 0), rays_d.split(8192, 0), viewdirs.split(8192, 0))
        ]
        render_result = {
            k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H, W, -1)
            for k in render_result_chunks[0].keys()
        }

        rgb = render_result['rgb_marched'].cpu().numpy()
        depth = render_result['depth'].cpu().numpy()
        bgmap = render_result['alphainv_last'].cpu().numpy()

        rgbs.append(rgb)
        depths.append(depth)
        bgmaps.append(bgmap)
        if distill_active:
            # Kept as a tensor because it is written with torch.save below.
            features.append(render_result['f_marched'].cpu())
        if seg_mask:
            # Stored as NumPy so the np.flip / np.rot90 post-processing works;
            # np.flip cannot reverse a torch tensor (negative-step slicing).
            segs.append(render_result['seg_mask_marched'].cpu().numpy())
        if i == 0:
            print('Testing', rgb.shape)

        if gt_imgs is not None and render_factor == 0:
            p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
            psnrs.append(p)
            if eval_ssim:
                ssims.append(rgb_ssim(rgb, gt_imgs[i], max_val=1))
            if eval_lpips_alex:
                lpips_alex.append(rgb_lpips(rgb, gt_imgs[i], net_name='alex', device=c2w.device))
            if eval_lpips_vgg:
                lpips_vgg.append(rgb_lpips(rgb, gt_imgs[i], net_name='vgg', device=c2w.device))

    if len(psnrs):
        print('Testing psnr', np.mean(psnrs), '(avg)')
        if eval_ssim: print('Testing ssim', np.mean(ssims), '(avg)')
        if eval_lpips_vgg: print('Testing lpips (vgg)', np.mean(lpips_vgg), '(avg)')
        if eval_lpips_alex: print('Testing lpips (alex)', np.mean(lpips_alex), '(avg)')

    # Per-list comprehensions: `segs` may be empty (seg_mask=False), in which
    # case its transform is simply a no-op instead of an IndexError.
    # Feature maps are not flipped / rotated.
    if render_video_flipy:
        rgbs = [np.flip(x, axis=0) for x in rgbs]
        depths = [np.flip(x, axis=0) for x in depths]
        bgmaps = [np.flip(x, axis=0) for x in bgmaps]
        segs = [np.flip(x, axis=0) for x in segs]

    if render_video_rot90 != 0:
        rgbs = [np.rot90(x, k=render_video_rot90, axes=(0, 1)) for x in rgbs]
        depths = [np.rot90(x, k=render_video_rot90, axes=(0, 1)) for x in depths]
        bgmaps = [np.rot90(x, k=render_video_rot90, axes=(0, 1)) for x in bgmaps]
        segs = [np.rot90(x, k=render_video_rot90, axes=(0, 1)) for x in segs]

    if savedir is not None and dump_images:
        for i in trange(len(rgbs)):
            rgb8 = to8b(rgbs[i])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)
        for i in trange(len(features)):
            filename = os.path.join(savedir, '{:03d}.pt'.format(i))
            feature = features[i].to(torch.device("cpu"))
            torch.save(feature, filename)

    rgbs = np.array(rgbs)
    depths = np.array(depths)
    bgmaps = np.array(bgmaps)
    # Stack only when something was collected; otherwise the empty list is returned.
    if len(features): features = np.stack(features)
    if len(segs): segs = np.stack(segs)

    return rgbs, depths, bgmaps, features, segs

@torch.no_grad()
def render_single_image(rvk, cfg, data_dict, idx, render_fct=0.0):
    '''Render only the view at position ``idx`` of ``data_dict``.

    ``rvk`` holds the keyword arguments for ``render_viewpoints``. Its 'model'
    entry is moved to the GPU and written back into the same dict.

    Returns the ``(rgb, seg)`` outputs of ``render_viewpoints`` for that view.
    '''
    rvk['model'] = rvk['model'].cuda()

    # A length-one slice keeps the leading "view" axis that render_viewpoints expects.
    one_view = slice(idx, idx + 1)
    rgb, _depth, _bgmap, _feature, seg = render_viewpoints(
        render_poses=data_dict['poses'][one_view],
        HW=data_dict['HW'][one_view],
        Ks=data_dict['Ks'][one_view],
        cfg=cfg,
        render_fct=render_fct,
        **rvk,
    )
    return rgb, seg

@torch.no_grad()
def render_images(render_viewpoints_kwargs: dict, cfg, data_dict: dict, render_fct=0.0):
    '''Render every pose in ``data_dict['poses']`` and return ``(rgbs, segs)``.

    The image size and intrinsics of the first test view are reused for all
    poses. The 'model' entry of ``render_viewpoints_kwargs`` is moved to the GPU
    and written back into the same dict.
    '''
    render_viewpoints_kwargs['model'] = render_viewpoints_kwargs['model'].cuda()

    all_poses = data_dict['poses']
    n_views = len(data_dict['poses'])
    # Take the first test view's entry and repeat it once per pose.
    shared_hw = data_dict['HW'][data_dict['i_test']][[0]].repeat(n_views, 0)
    shared_ks = data_dict['Ks'][data_dict['i_test']][[0]].repeat(n_views, 0)

    rgbs, _depths, _bgmaps, _features, segs = render_viewpoints(
        render_poses=all_poses,
        HW=shared_hw,
        Ks=shared_ks,
        cfg=cfg,
        render_fct=render_fct,
        **render_viewpoints_kwargs,
    )
    return rgbs, segs

@torch.no_grad()
def get_masks(render_viewpoints_kwargs: dict, cfg, data_dict: dict, imageset: str, alpha_threshold):
    '''Render all dataset images and build a validity mask for each of them.

    A pixel is marked invalid (mask value 0 in every channel) when its rendered
    colour channels sum to exactly 3; all other pixels get mask value 1.

    Args:
        render_viewpoints_kwargs, cfg, data_dict: forwarded to ``render_images``.
        imageset, alpha_threshold: kept for interface compatibility only.
            ``render_images`` has no matching parameters, so they are not
            forwarded (passing them positionally raised a TypeError).

    Returns:
        Array of the same shape and dtype as the rendered RGB images.
    '''
    rgbs, _ = render_images(render_viewpoints_kwargs, cfg, data_dict)
    invalid_idx = rgbs.sum(-1) == (1 * 3)
    masks = np.ones_like(rgbs)
    masks[invalid_idx] = 0
    return masks

def render_opt_fn(render_type):
    '''Return the render function registered for ``render_type``.

    Raises NotImplementedError for an unrecognised ``render_type``.
    '''
    # Each entry pairs a type name with a thunk, so a renderer's name is only
    # resolved once its entry has actually been selected.
    table = (
        ('train', lambda: render_train),
        ('test', lambda: render_test),
        ('segment', lambda: render_segment),
        ('video', lambda: render_video),
        ('nvos', lambda: render_nvos),
        ('spin', lambda: render_images),
        ('replica', lambda: render_replica),
    )
    for type_name, resolve in table:
        if render_type == type_name:
            return resolve()
    raise NotImplementedError

@torch.no_grad()
def render_train(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs, is_seged_rgb=False):
    '''Render all training views and write the result videos.

    Output goes to ``<cfg.basedir>/<cfg.expname>/render_train_<ckpt_name>``:
    an RGB video, a binary seg video (``segs > 0``), a depth video and, unless
    ``is_seged_rgb`` is set, per-frame seg-over-RGB overlays plus their video.

    Returns the ``(rgbs, segs)`` arrays produced by ``render_viewpoints``.
    '''
    palette = gen_rand_colors(num_obj)
    out_dir = os.path.join(cfg.basedir, cfg.expname, f'render_train_{ckpt_name}')
    os.makedirs(out_dir, exist_ok=True)

    # Frames are dumped into a dedicated sub-folder only when dumping is on.
    frame_dir = out_dir
    if args.dump_images:
        frame_dir = os.path.join(out_dir, 'seged_img' if is_seged_rgb else 'full_img')
        os.makedirs(frame_dir, exist_ok=True)
    print('All results are dumped into', out_dir)

    train_ids = data_dict['i_train']
    rgbs, depths, bgmaps, _, segs = render_viewpoints(
        render_poses=data_dict['poses'][train_ids],
        HW=data_dict['HW'][train_ids],
        Ks=data_dict['Ks'][train_ids],
        gt_imgs=[data_dict['images'][i].cpu().numpy() for i in train_ids],
        cfg=cfg,
        savedir=frame_dir,
        dump_images=args.dump_images,
        eval_ssim=args.eval_ssim,
        eval_lpips_alex=args.eval_lpips_alex,
        eval_lpips_vgg=args.eval_lpips_vgg,
        distill_active=args.distill_active,
        **render_viewpoints_kwargs,
    )

    video_opts = dict(fps=30, quality=8)
    imageio.mimwrite(os.path.join(out_dir, 'video.rgb' + flag + e_flag + '.mp4'), to8b(rgbs), **video_opts)
    imageio.mimwrite(os.path.join(out_dir, 'video.seg' + flag + e_flag + '.mp4'), to8b(segs > 0), **video_opts)

    if not is_seged_rgb:
        overlay_dir = os.path.join(out_dir, 'masked_img')
        os.makedirs(overlay_dir, exist_ok=True)
        overlays = []
        for frame_idx, (rgb, seg) in enumerate(zip(rgbs, segs)):
            # Winner takes all: label = channel with the highest score; pixels
            # whose best score is <= 0.1 get the extra label `num_obj`.
            # NOTE(review): presumably palette[num_obj] is a background colour —
            # confirm against gen_rand_colors.
            best_score = np.max(seg, axis=-1)
            labels = np.argmax(seg, axis=-1)
            labels[best_score <= 0.1] = num_obj
            blended = 0.3 * rgb + 0.7 * (palette[labels])
            imageio.imwrite(os.path.join(overlay_dir, 'rgb_{:03d}.png'.format(frame_idx)), to8b(blended))
            overlays.append(blended)
        imageio.mimwrite(os.path.join(out_dir, 'video.seg_on_rgb' + e_flag + '.mp4'), to8b(overlays), **video_opts)

    # Depth is normalised by the global maximum and inverted.
    inv_depth = 1 - depths / np.max(depths)
    imageio.mimwrite(os.path.join(out_dir, 'video.depth' + flag + e_flag + '.mp4'), to8b(inv_depth), **video_opts)

    return rgbs, segs

@torch.no_grad()
def render_test(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs):
    '''Render all test views and write the result videos.

    Output goes to ``<cfg.basedir>/<cfg.expname>/render_test_<ckpt_name>``.
    With an empty ``e_flag`` the plain RGB video is written; otherwise a
    seg-over-RGB overlay video is written instead. A depth video is always
    written.

    Returns the ``(rgbs, segs)`` arrays produced by ``render_viewpoints``.
    '''
    palette = gen_rand_colors(num_obj)
    out_dir = os.path.join(cfg.basedir, cfg.expname, f'render_test_{ckpt_name}')
    os.makedirs(out_dir, exist_ok=True)
    print('All results are dumped into', out_dir)

    test_ids = data_dict['i_test']
    rgbs, depths, bgmaps, _, segs = render_viewpoints(
        render_poses=data_dict['poses'][test_ids],
        HW=data_dict['HW'][test_ids],
        Ks=data_dict['Ks'][test_ids],
        cfg=cfg,
        gt_imgs=[data_dict['images'][i].cpu().numpy() for i in test_ids],
        savedir=out_dir,
        dump_images=args.dump_images,
        eval_ssim=args.eval_ssim,
        eval_lpips_alex=args.eval_lpips_alex,
        eval_lpips_vgg=args.eval_lpips_vgg,
        distill_active=args.distill_active,
        **render_viewpoints_kwargs,
    )

    video_opts = dict(fps=30, quality=8)
    if e_flag != '':
        overlays = []
        for rgb, seg in zip(rgbs, segs):
            # Winner takes all: label = channel with the highest score; pixels
            # whose best score is <= 0.1 get the extra label `num_obj`.
            best_score = np.max(seg, axis=-1)[..., None]
            print(best_score.mean(), best_score.std(), best_score.min(), best_score.max())
            labels = np.argmax(seg, axis=-1)
            labels[best_score[:, :, 0] <= 0.1] = num_obj
            overlays.append(0.3 * rgb + 0.7 * (palette[labels]))
        imageio.mimwrite(os.path.join(out_dir, 'video.seg_on_rgb' + e_flag + '.mp4'), to8b(overlays), **video_opts)
    else:
        imageio.mimwrite(os.path.join(out_dir, 'video.rgb' + flag + e_flag + '.mp4'), to8b(rgbs), **video_opts)

    # Depth is normalised by the global maximum and inverted.
    inv_depth = 1 - depths / np.max(depths)
    imageio.mimwrite(os.path.join(out_dir, 'video.depth' + flag + e_flag + '.mp4'), to8b(inv_depth), **video_opts)

    return rgbs, segs

@torch.no_grad()
def render_segment(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs):
    '''Render a strided subset of the test views and write the result videos.

    Every ``len(i_test) // 3 + 1``-th test view is rendered. Output goes to
    ``<cfg.basedir>/<cfg.expname>/render_segment_<ckpt_name>``. With an empty
    ``e_flag`` the plain RGB video is written; otherwise a seg-over-RGB overlay
    video is written instead. A depth video is always written.

    Returns the ``(rgbs, segs)`` arrays produced by ``render_viewpoints``.
    '''
    palette = gen_rand_colors(num_obj)
    out_dir = os.path.join(cfg.basedir, cfg.expname, f'render_segment_{ckpt_name}')
    os.makedirs(out_dir, exist_ok=True)
    print('All results are dumped into', out_dir)

    stride = (len(data_dict['i_test']) // 3) + 1
    picked_ids = data_dict['i_test'][::stride]
    rgbs, depths, bgmaps, _, segs = render_viewpoints(
        render_poses=data_dict['poses'][picked_ids],
        HW=data_dict['HW'][picked_ids],
        Ks=data_dict['Ks'][picked_ids],
        cfg=cfg,
        gt_imgs=[data_dict['images'][i].cpu().numpy() for i in picked_ids],
        savedir=out_dir,
        dump_images=args.dump_images,
        eval_ssim=args.eval_ssim,
        eval_lpips_alex=args.eval_lpips_alex,
        eval_lpips_vgg=args.eval_lpips_vgg,
        distill_active=args.distill_active,
        **render_viewpoints_kwargs,
    )

    video_opts = dict(fps=30, quality=8)
    if e_flag != '':
        overlays = []
        for rgb, seg in zip(rgbs, segs):
            # Winner takes all: label = channel with the highest score; pixels
            # whose best score is <= 0 get the extra label `num_obj`.
            # NOTE(review): the sibling renderers use 0.1 as cut-off — confirm
            # the 0 here is intended.
            best_score = np.max(seg, axis=-1)
            labels = np.argmax(seg, axis=-1)
            labels[best_score <= 0] = num_obj
            overlays.append(0.3 * rgb + 0.7 * (palette[labels]))
        imageio.mimwrite(os.path.join(out_dir, 'video.seg_on_rgb' + e_flag + '.mp4'), to8b(overlays), **video_opts)
    else:
        imageio.mimwrite(os.path.join(out_dir, 'video.rgb' + flag + e_flag + '.mp4'), to8b(rgbs), **video_opts)

    # Depth is normalised by the global maximum and inverted.
    inv_depth = 1 - depths / np.max(depths)
    imageio.mimwrite(os.path.join(out_dir, 'video.depth' + flag + e_flag + '.mp4'), to8b(inv_depth), **video_opts)

    return rgbs, segs

@torch.no_grad()
def render_video(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs, is_seged_rgb=False):
    '''Render the ``data_dict['render_poses']`` trajectory and write videos.

    Output goes to ``<cfg.basedir>/<cfg.expname>/render_video_<ckpt_name>``:
    an RGB video, a colour-mapped depth video and, unless ``is_seged_rgb`` is
    set, per-frame seg-over-RGB overlays plus their video. The image size and
    intrinsics of the first test view are reused for every pose.

    Args:
        args: run options; reads ``dump_images``, ``render_video_factor``,
            ``render_video_flipy`` and ``render_video_rot90``.
        cfg: config object; reads ``basedir`` and ``expname`` and is forwarded
            to ``render_viewpoints``.
        ckpt_name: used in the output directory name.
        flag, e_flag: suffixes inserted into the video file names.
        num_obj: number of objects; also the label given to low-score pixels.
        data_dict: provides 'render_poses', 'HW', 'Ks' and 'i_test'.
        render_viewpoints_kwargs: extra keyword arguments for
            ``render_viewpoints``.
        is_seged_rgb: selects the 'seged_img' dump folder and skips the overlay.

    Returns:
        The ``(rgbs, segs)`` arrays produced by ``render_viewpoints``.
    '''
    rand_colors = gen_rand_colors(num_obj)
    testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_video_{ckpt_name}')
    os.makedirs(testsavedir, exist_ok=True)

    # Always define the dump directory (same default as render_train); it used
    # to be unbound, and the call below crashed, whenever dump_images was off.
    seg_img_dir = testsavedir
    if args.dump_images:
        seg_img_dir = os.path.join(testsavedir, 'seged_img' if is_seged_rgb else 'full_img')
        os.makedirs(seg_img_dir, exist_ok=True)
    print('All results are dumped into', testsavedir)

    n_poses = len(data_dict['render_poses'])
    rgbs, depths, bgmaps, _, segs = render_viewpoints(
            render_poses=data_dict['render_poses'],
            HW=data_dict['HW'][data_dict['i_test']][[0]].repeat(n_poses, 0),
            Ks=data_dict['Ks'][data_dict['i_test']][[0]].repeat(n_poses, 0),
            cfg=cfg,
            render_factor=args.render_video_factor,
            render_video_flipy=args.render_video_flipy,
            render_video_rot90=args.render_video_rot90,
            savedir=seg_img_dir, dump_images=args.dump_images,
            **render_viewpoints_kwargs)

    imageio.mimwrite(os.path.join(testsavedir, 'video.rgb'+flag+e_flag+'.mp4'), to8b(rgbs), fps=30, quality=8)

    if not is_seged_rgb:
        seg_on_rgb = []
        masked_img_dir = os.path.join(testsavedir, 'masked_img')
        os.makedirs(masked_img_dir, exist_ok=True)
        for i, (rgb, seg) in enumerate(zip(rgbs, segs)):
            # Winner takes all: label = channel with the highest score; pixels
            # whose best score is <= 0.1 get the extra label `num_obj`.
            max_logit = np.max(seg, axis=-1)
            tmp_seg = np.argmax(seg, axis=-1)
            tmp_seg[max_logit <= 0.1] = num_obj
            rendered_rgb = 0.3*rgb + 0.7*(rand_colors[tmp_seg])
            imageio.imwrite(os.path.join(masked_img_dir, 'rgb_{:03d}.png'.format(i)), to8b(rendered_rgb))
            seg_on_rgb.append(rendered_rgb)
        imageio.mimwrite(os.path.join(testsavedir, 'video.seg_on_rgb'+e_flag+'.mp4'), to8b(seg_on_rgb), fps=30, quality=8)

    # Depth visualisation: blend depth towards 1 by the bgmap weight, rescale
    # to the 5th..95th percentile of pixels with bgmap < 0.1, then colour-map.
    depths_vis = depths * (1-bgmaps) + bgmaps
    dmin, dmax = np.percentile(depths_vis[bgmaps < 0.1], q=[5, 95])
    depth_vis = plt.get_cmap('rainbow')(1 - np.clip((depths_vis - dmin) / (dmax - dmin), 0, 1)).squeeze()[..., :3]
    imageio.mimwrite(os.path.join(testsavedir, 'video.depth'+flag+e_flag+'.mp4'), to8b(depth_vis), fps=30, quality=8)

    return rgbs, segs

@torch.no_grad()
def render_nvos(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs):
    '''Render the NVOS target view and print IoU / accuracy against its mask.

    The target index and ground-truth mask come from ``load_nvos_data``. The
    rendered RGB and seg images are saved under ``tmp_masks/``. The rendered
    seg output is binarised (values <= 0 become 0, everything else 1) and
    compared with the ground-truth mask resized to the rendered resolution by
    nearest-neighbour interpolation.

    Only ``cfg``, ``data_dict`` and ``render_viewpoints_kwargs`` are used; the
    remaining parameters exist to match the other render functions.

    Returns:
        ``(None, None)``; the metrics are only printed.
    '''
    ref_ind, ref_pose, _, _, target_ind, target_pose, target_mask = load_nvos_data(cfg.data.datadir, cfg.data.factor)
    rgb, seg = render_single_image(render_viewpoints_kwargs, cfg, data_dict, idx = target_ind)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('tmp_masks', exist_ok=True)
    plt.imshow(to8b(rgb[0]))
    plt.savefig('tmp_masks/rendered_target_rgb_nvos.jpg')
    plt.imshow(to8b(seg[0]))
    plt.savefig('tmp_masks/rendered_target_seg_nvos.jpg')

    h,w = seg[0].shape[0], seg[0].shape[1]
    # interpolate needs two leading dims; add them, resize, then strip them again.
    target_mask = torch.from_numpy(np.array(target_mask)).unsqueeze(0).unsqueeze(0).float()
    target_mask = torch.nn.functional.interpolate(target_mask, (h,w), mode = 'nearest').squeeze(0).squeeze(0)
    print(target_mask.min(), target_mask.max())

    # Binarise the rendered seg output in place: <= 0 -> 0, everything else -> 1.
    seg_res = torch.from_numpy(seg[0]).squeeze(-1)
    seg_res[seg_res <= 0] = 0
    seg_res[seg_res != 0] = 1
    print(seg_res.min(), seg_res.max())

    # The sum is 2 where both masks are set and non-zero where at least one is.
    IoU = torch.count_nonzero((target_mask + seg_res == 2)) / torch.count_nonzero((target_mask + seg_res != 0))
    Acc = torch.count_nonzero((target_mask == seg_res)) / (h*w)
    print("nvos IoU is", IoU)
    print("nvos Acc is", Acc)
    return None, None


@torch.no_grad()
def render_replica(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs):
    '''Render the target view and print IoU / accuracy against its mask.

    The target index and ground-truth mask come from ``load_nvos_data``. The
    rendered RGB and seg images are saved under ``tmp_masks/``. The rendered
    seg output is binarised (values <= 0 become 0, everything else 1) and
    compared with the ground-truth mask resized to the rendered resolution by
    nearest-neighbour interpolation.

    Only ``cfg``, ``data_dict`` and ``render_viewpoints_kwargs`` are used; the
    remaining parameters exist to match the other render functions.

    Returns:
        ``(None, None)``, like ``render_nvos``; the metrics are only printed.
    '''
    # NOTE(review): render_nvos unpacks seven values from this same call while
    # this function used to unpack five, so one of the two had to fail. The
    # starred target accepts either layout, assuming the last three items are
    # (target_ind, target_pose, target_mask) — confirm against load_nvos_data.
    _ref_ind, _ref_pose, *_extra, target_ind, _target_pose, target_mask = load_nvos_data(cfg.data.datadir, cfg.data.factor)
    rgb, seg = render_single_image(render_viewpoints_kwargs, cfg, data_dict, idx = target_ind)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('tmp_masks', exist_ok=True)
    plt.imshow(to8b(rgb[0]))
    plt.savefig('tmp_masks/rendered_target_rgb_nvos.jpg')
    plt.imshow(to8b(seg[0]))
    plt.savefig('tmp_masks/rendered_target_seg_nvos.jpg')

    h,w = seg[0].shape[0], seg[0].shape[1]
    # interpolate needs two leading dims; add them, resize, then strip them again.
    target_mask = torch.from_numpy(np.array(target_mask)).unsqueeze(0).unsqueeze(0).float()
    target_mask = torch.nn.functional.interpolate(target_mask, (h,w), mode = 'nearest').squeeze(0).squeeze(0)
    print(target_mask.min(), target_mask.max())

    # Binarise the rendered seg output in place: <= 0 -> 0, everything else -> 1.
    seg_res = torch.from_numpy(seg[0]).squeeze(-1)
    seg_res[seg_res <= 0] = 0
    seg_res[seg_res != 0] = 1
    print(seg_res.min(), seg_res.max())

    # The sum is 2 where both masks are set and non-zero where at least one is.
    IoU = torch.count_nonzero((target_mask + seg_res == 2)) / torch.count_nonzero((target_mask + seg_res != 0))
    Acc = torch.count_nonzero((target_mask == seg_res)) / (h*w)
    print("nvos IoU is", IoU)
    print("nvos Acc is", Acc)
    # Every other renderer returned by render_opt_fn yields a 2-tuple.
    return None, None