import functools
import math
import os
import time
from tkinter import W

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
from torch_scatter import segment_coo


from skimage.morphology import remove_small_holes
from cc3d import largest_k

# Directory containing this file; the extension sources are addressed relative to it.
parent_dir = os.path.dirname(os.path.abspath(__file__))

# Extension providing the ray sampling and alpha compositing routines used below
# (sample_pts_on_rays, alpha2weight, raw2alpha and their backward passes).
render_utils_cuda = load(
    name='render_utils_cuda',
    sources=[
        os.path.join(parent_dir, 'cuda/render_utils.cpp'),
        os.path.join(parent_dir, 'cuda/render_utils_kernel.cu'),
    ],
    verbose=True,
)

# Extension providing total_variation_add_grad for the feature grid regulariser.
total_variation_cuda = load(
    name='total_variation_cuda',
    sources=[
        os.path.join(parent_dir, 'cuda/total_variation.cpp'),
        os.path.join(parent_dir, 'cuda/total_variation_kernel.cu'),
    ],
    verbose=True,
)

class Deformation(nn.Module):
    """MLP that predicts a 3D offset for embedded points conditioned on a time feature.

    The network consumes ``cat([points_embedding, time_feature])`` and outputs a
    3-vector which is added to the first three channels of the point embedding.

    Args:
        D: depth setting; the trunk has ``D - 1`` linear layers followed by a
            final ``Linear(W, 3)``.
        W: hidden width.
        input_ch: channel count of the embedded points.
        input_ch_views: stored only; not used by this module.
        input_ch_time: channel count of the time feature.
        skips: indices of trunk layers after which the point embedding is
            concatenated back onto the hidden state. ``None`` means no skips.
    """

    def __init__(self, D=8, W=256, input_ch=27, input_ch_views=3, input_ch_time=9, skips=None):
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.input_ch_time = input_ch_time
        # Own a fresh list per instance: a mutable default argument would be
        # shared (and mutated) across every instance built with the default.
        self.skips = [] if skips is None else list(skips)
        self._time, self._time_out = self.create_net()

    def create_net(self):
        """Build the trunk (``nn.ModuleList``) and the final 3-channel head."""
        layers = [nn.Linear(self.input_ch + self.input_ch_time, self.W)]
        for i in range(self.D - 2):
            in_channels = self.W
            # The layer following a skip receives the re-concatenated point embedding.
            if i in self.skips:
                in_channels += self.input_ch
            layers.append(nn.Linear(in_channels, self.W))
        return nn.ModuleList(layers), nn.Linear(self.W, 3)

    def query_time(self, new_pts, t, net, net_final):
        """Run ``net`` then ``net_final`` on ``cat([new_pts, t])`` and return the result."""
        h = torch.cat([new_pts, t], dim=-1)
        for i, layer in enumerate(net):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([new_pts, h], -1)
        return net_final(h)

    def forward(self, input_pts, ts):
        """Return the first three channels of ``input_pts`` shifted by the predicted offset.

        Args:
            input_pts: embedded points with ``input_ch`` channels in the last dim.
            ts: time feature with ``input_ch_time`` channels in the last dim.
        """
        dx = self.query_time(input_pts, ts, self._time, self._time_out)
        # Slice on the last axis so inputs with extra leading dims also work.
        input_pts_orig = input_pts[..., :3]
        return input_pts_orig + dx

# Model
class RGBNet(nn.Module):
    """Small colour head: a linear feature layer followed by a two-layer MLP.

    ``D``, ``pts_ch`` and ``times_ch`` are only recorded as attributes; the
    layers are sized from ``W``, ``h_ch``, ``views_ch`` and ``output_ch``.
    """

    def __init__(self, D=3, W=256, h_ch=256, views_ch=33, pts_ch=27, times_ch=17, output_ch=3):
        super().__init__()
        self.D, self.W = D, W
        self.input_ch = h_ch
        self.input_ch_views = views_ch
        self.input_ch_pts = pts_ch
        self.input_ch_times = times_ch
        self.output_ch = output_ch

        self.feature_linears = nn.Linear(h_ch, W)
        # The head sees the W-dim feature, optionally concatenated with the view embedding.
        head_layers = [
            nn.Linear(W + views_ch, W // 2),
            nn.ReLU(),
            nn.Linear(W // 2, output_ch),
        ]
        self.views_linears = nn.Sequential(*head_layers)

    def forward(self, input_h, input_views=None):
        """Map features (and an optional view embedding) to ``output_ch`` values.

        When ``input_views`` is omitted the module must have been built with
        ``views_ch == 0``; this is enforced with an assertion.
        """
        feature = self.feature_linears(input_h)
        if input_views is None:
            assert self.input_ch_views == 0
            return self.views_linears(feature)
        return self.views_linears(torch.cat([feature, input_views], dim=-1))

'''Model'''
class TiNeuVox(torch.nn.Module):
    def __init__(self, xyz_min, xyz_max,
                 num_voxels=0, num_voxels_base=0, add_cam=False,
                 alpha_init=None, fast_color_thres=0,
                 voxel_dim=0, defor_depth=3, net_width=128,
                 posbase_pe=10, viewbase_pe=4, timebase_pe=8, gridbase_pe=2,
                 feat_only = False, no_view_dir = True,
                 **kwargs):
        """Build the feature voxel grid and all MLPs.

        Args:
            xyz_min, xyz_max: corners of the axis-aligned scene bounding box.
            num_voxels: target voxel count of the feature grid.
            num_voxels_base: voxel count defining the reference voxel size.
            add_cam: if truthy, build ``camnet`` and widen the view input of
                the colour head by its output size.
            alpha_init: initial alpha used to derive the density bias shift
                ``log(1/(1-alpha_init) - 1)``; the default ``None`` is not
                usable, a number must be supplied.
            fast_color_thres: alpha/weight threshold used to drop samples.
            voxel_dim: channel count of the feature grid.
            defor_depth: ``D`` of the deformation network.
            net_width: hidden width shared by all MLPs.
            posbase_pe, viewbase_pe, timebase_pe, gridbase_pe: number of
                frequency bands for position, view, time and grid features.
            feat_only: if True, ``featurenet`` only sees the grid features.
            no_view_dir: if True, the colour head takes no view embedding
                (unless ``add_cam`` overrides the view channel count).
            **kwargs: accepted and ignored.
        """
        super(TiNeuVox, self).__init__()
        self.add_cam = add_cam
        self.voxel_dim = voxel_dim
        self.defor_depth = defor_depth
        self.net_width = net_width
        self.feat_only = feat_only
        self.no_view_dir = no_view_dir
        self.posbase_pe = posbase_pe
        self.viewbase_pe = viewbase_pe
        self.timebase_pe = timebase_pe
        self.gridbase_pe = gridbase_pe

        # Channel counts of the embedded inputs.
        times_ch = 2*timebase_pe+1
        views_ch = 0 if no_view_dir else 3+3*viewbase_pe*2
        # Plain int (the previous trailing comma turned this into a 1-tuple).
        pts_ch = 3+3*posbase_pe*2

        self.register_buffer('xyz_min', torch.tensor(xyz_min).float())
        self.register_buffer('xyz_max', torch.tensor(xyz_max).float())
        self.fast_color_thres = fast_color_thres
        # determine based grid resolution
        self.num_voxels_base = num_voxels_base
        self.voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.num_voxels_base).pow(1/3)

        # determine the density bias shift
        self.alpha_init = alpha_init
        self.act_shift = np.log(1/(1-alpha_init) - 1)
        print('TiNeuVox: set density bias shift to', self.act_shift)

        timenet_width = net_width
        timenet_output = voxel_dim+voxel_dim*2*gridbase_pe
        self.timenet = nn.Sequential(
            nn.Linear(times_ch, timenet_width), nn.ReLU(inplace=True),
            nn.Linear(timenet_width, timenet_output))
        # Truthiness check, matching how the other methods test add_cam.
        if self.add_cam:
            views_ch = 3+3*viewbase_pe*2+timenet_output
            self.camnet = nn.Sequential(
                nn.Linear(times_ch, timenet_width), nn.ReLU(inplace=True),
                nn.Linear(timenet_width, timenet_output))
            print('TiNeuVox: camnet', self.camnet)

        featurenet_width = net_width
        featurenet_depth = 1
        # Three grid scales are sampled and concatenated, hence the factor 3.
        grid_dim = voxel_dim*3+voxel_dim*3*2*gridbase_pe
        if self.feat_only:
            input_dim = grid_dim
        else:
            input_dim = grid_dim+timenet_output+pts_ch
        self.featurenet = nn.Sequential(
            nn.Linear(input_dim, featurenet_width), nn.ReLU(inplace=True),
            *[
                nn.Sequential(nn.Linear(featurenet_width, featurenet_width), nn.ReLU(inplace=True))
                for _ in range(featurenet_depth-1)
            ],
            )
        self.featurenet_width = featurenet_width
        self._set_grid_resolution(num_voxels)
        self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=pts_ch, input_ch_time=timenet_output)
        self.densitynet = nn.Linear(featurenet_width, 1)

        # Frequency bands 2**0 .. 2**(n-1) for each embedding.
        self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))
        self.register_buffer('grid_poc', torch.FloatTensor([(2**i) for i in range(gridbase_pe)]))
        self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)]))
        self.register_buffer('view_poc', torch.FloatTensor([(2**i) for i in range(viewbase_pe)]))

        self.feature= torch.nn.Parameter(torch.zeros([1, self.voxel_dim, *self.world_size],dtype=torch.float32))
        self.rgbnet = RGBNet(W=net_width, h_ch=featurenet_width, views_ch=views_ch, pts_ch=pts_ch, times_ch=times_ch)

        print('TiNeuVox: feature voxel grid', self.feature.shape)
        print('TiNeuVox: timenet mlp', self.timenet)
        print('TiNeuVox: deformation_net mlp', self.deformation_net)
        print('TiNeuVox: densitynet mlp', self.densitynet)
        print('TiNeuVox: featurenet mlp', self.featurenet)
        print('TiNeuVox: rgbnet mlp', self.rgbnet)

    def _set_grid_resolution(self, num_voxels):
        # Determine grid resolution
        self.num_voxels = num_voxels
        self.voxel_size = ((self.xyz_max - self.xyz_min).prod() / num_voxels).pow(1/3)
        self.world_size = ((self.xyz_max - self.xyz_min) / self.voxel_size).long()
        self.voxel_size_ratio = self.voxel_size / self.voxel_size_base
        print('TiNeuVox: voxel_size      ', self.voxel_size)
        print('TiNeuVox: world_size      ', self.world_size)
        print('TiNeuVox: voxel_size_base ', self.voxel_size_base)
        print('TiNeuVox: voxel_size_ratio', self.voxel_size_ratio)

    def get_kwargs(self):
        return {
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'num_voxels': self.num_voxels,
            'num_voxels_base': self.num_voxels_base,
            'alpha_init': self.alpha_init,
            'act_shift': self.act_shift,
            'voxel_size_ratio': self.voxel_size_ratio,
            'fast_color_thres': self.fast_color_thres,
            'voxel_dim':self.voxel_dim,
            'defor_depth':self.defor_depth,
            'net_width':self.net_width,
            'posbase_pe':self.posbase_pe,
            'viewbase_pe':self.viewbase_pe,
            'timebase_pe':self.timebase_pe,
            'gridbase_pe':self.gridbase_pe,
            'add_cam': self.add_cam,
            'no_view_dir': self.no_view_dir
        }

    @torch.no_grad()
    def get_alpha_mask(self, stepsize, time_sel=torch.tensor([[0.,]]), viewdir=torch.tensor([[0.,0.,0.]]), threshold=None, sampling_freq=1):
        """Evaluate alpha on a regular grid over the bounding box and threshold it.

        Args:
            stepsize: step size (in voxels) used to scale the density interval.
            time_sel: ``[1, 1]`` time value to evaluate at.
            viewdir: unused; kept for interface compatibility.
            threshold: alpha threshold; ``None`` falls back to
                ``self.fast_color_thres``.
            sampling_freq: multiplier on the per-axis grid resolution.

        Returns:
            ``(grid_xyz, mask)`` where ``grid_xyz`` is ``[X, Y, Z, 3]`` and
            ``mask`` is the boolean ``[X, Y, Z]`` result of ``alpha > threshold``.
            ``X, Y, Z`` are the sampled sizes, i.e. ``world_size`` scaled by
            ``sampling_freq``.
        """
        # Shared grid builder: casts the step counts to int and honours
        # sampling_freq, so the shapes below always match the sampled grid.
        grid = self.get_grid_xyz(sampling_freq)
        grid_shape = grid.shape[:-1]
        grid_xyz = grid.view(-1, 3)

        time_sel = time_sel.to(self.time_poc.device)

        times_emb = poc_fre(time_sel, self.time_poc)
        times_feature = self.timenet(times_emb).repeat(grid_xyz.shape[0], 1)

        rays_pts_emb = poc_fre(grid_xyz, self.pos_poc)
        ray_pts_delta = self.deformation_net(rays_pts_emb, times_feature)

        vox_feature_flatten = self.mult_dist_interp(ray_pts_delta)

        vox_feature_flatten_emb = poc_fre(vox_feature_flatten, self.grid_poc)
        if self.feat_only:
            h_feature = self.featurenet(vox_feature_flatten_emb)
        else:
            h_feature = self.featurenet(torch.cat((vox_feature_flatten_emb, rays_pts_emb, times_feature), -1))
        density_result = self.densitynet(h_feature)

        interval = stepsize * self.voxel_size_ratio
        alpha = self.activate_density(density_result, interval)

        # Without an explicit threshold use the model's own; this also avoids
        # comparing against None when fast_color_thres is 0.
        if threshold is None:
            threshold = self.fast_color_thres
        mask = alpha > threshold

        return grid_xyz.view(*grid_shape, 3), mask.view(*grid_shape)

    def get_grid_xyz(self, sampling_freq):
        
        return torch.stack(torch.meshgrid(
            torch.linspace(self.xyz_min[0], self.xyz_max[0], int(self.world_size[0] * sampling_freq)),
            torch.linspace(self.xyz_min[1], self.xyz_max[1], int(self.world_size[1] * sampling_freq)),
            torch.linspace(self.xyz_min[2], self.xyz_max[2], int(self.world_size[2] * sampling_freq)),
        ), -1)

    @torch.no_grad()
    def get_grid_as_point_cloud(self, stepsize, time_sel=torch.tensor([[0.,]]), viewdir=torch.tensor([[0.,0.,0.]]), cam_sel=None, threshold=None, canonical=False, sampling_freq=1, N_batch=2**20, blob_mask=None):
        """Query the model on a regular grid and return the occupied points.

        The grid is processed in chunks of ``N_batch`` points. For every chunk
        alpha is computed, points with ``alpha > threshold`` (optionally AND-ed
        with ``blob_mask``) are kept and their colour is evaluated.

        Args:
            stepsize: step size (in voxels) used to scale the density interval.
            time_sel: ``[1, 1]`` time value to evaluate at.
            viewdir: ``[1, 3]`` view direction used when the model uses views.
            cam_sel: camera selector fed to ``camnet``; required if ``add_cam``.
            threshold: alpha threshold; ``None`` falls back to
                ``self.fast_color_thres``.
            canonical: if True, skip the deformation network and sample the
                feature grid at the raw grid positions.
            sampling_freq: multiplier on the per-axis grid resolution.
            N_batch: maximum number of grid points processed at once.
            blob_mask: optional boolean mask with one entry per grid point.

        Returns:
            Tuple ``(xyz, alpha, rgb, h_feature, vox_feature, binary_volume,
            grid_xyz)``. The first five hold only the kept points;
            ``binary_volume`` is an ``[X, Y, Z]`` volume that is 1 where a point
            was kept; ``grid_xyz`` is the full ``[X, Y, Z, 3]`` grid.
        """
        grid_xyz = self.get_grid_xyz(sampling_freq)
        og_shape = grid_xyz.shape
        grid_xyz = grid_xyz.view(-1, 3)
        if blob_mask is not None:
            blob_mask = blob_mask.view(-1)

        time_sel = time_sel.to(self.time_poc.device)
        viewdir = viewdir.to(self.view_poc.device)

        times_emb = poc_fre(time_sel, self.time_poc)
        viewdir_emb = poc_fre(viewdir, self.view_poc)

        # Everything below is identical for every chunk, so compute it once.
        time_feature = self.timenet(times_emb)
        if self.add_cam:
            cam_feature = self.camnet(poc_fre(cam_sel, self.time_poc))
        interval = stepsize * self.voxel_size_ratio
        # Fall back to the model threshold; also avoids comparing against None.
        if threshold is None:
            threshold = self.fast_color_thres

        alpha_all = []
        rgb_all = []
        h_feature_all = []
        vox_feature_flatten_all = []
        grid_xyz_all = []
        binary_volume = []

        # Chunk boundaries are derived directly from the point count so the
        # final chunk is processed even when the count is an exact multiple of
        # N_batch (the previous boundary list dropped it in that case).
        n_points = len(grid_xyz)
        for start in range(0, n_points, N_batch):
            stop = min(start + N_batch, n_points)
            batch_grid_xyz = grid_xyz[start:stop]

            times_feature = time_feature.repeat(batch_grid_xyz.shape[0], 1)

            rays_pts_emb = poc_fre(batch_grid_xyz, self.pos_poc)
            if canonical:
                ray_pts_shifted = batch_grid_xyz
            else:
                ray_pts_shifted = self.deformation_net(rays_pts_emb, times_feature)

            vox_feature_flatten = self.mult_dist_interp(ray_pts_shifted)

            vox_feature_flatten_emb = poc_fre(vox_feature_flatten, self.grid_poc)
            if self.feat_only:
                h_feature = self.featurenet(vox_feature_flatten_emb)
            else:
                h_feature = self.featurenet(torch.cat((vox_feature_flatten_emb, rays_pts_emb, times_feature), -1))

            density_result = self.densitynet(h_feature)
            alpha = self.activate_density(density_result, interval)

            mask = (alpha > threshold)
            if blob_mask is not None:
                mask = mask & blob_mask[start:stop]

            # Per-chunk occupancy, in alpha's dtype (1 where kept, 0 elsewhere).
            batch_binary_volume = torch.zeros_like(alpha)
            batch_binary_volume[mask] = True

            alpha = alpha[mask]
            h_feature = h_feature[mask]
            batch_grid_xyz = batch_grid_xyz[mask]
            vox_feature_flatten = vox_feature_flatten[mask]

            if self.no_view_dir:
                rgb_logit = self.rgbnet(h_feature)
            else:
                viewdir_emb_batch = viewdir_emb
                if self.add_cam:
                    viewdir_emb_batch = torch.cat((viewdir_emb_batch, cam_feature), -1)
                viewdir_emb_batch = viewdir_emb_batch.repeat(h_feature.shape[0], 1)
                rgb_logit = self.rgbnet(h_feature, viewdir_emb_batch)
            rgb = torch.sigmoid(rgb_logit)

            alpha_all.append(alpha)
            rgb_all.append(rgb)
            h_feature_all.append(h_feature)
            vox_feature_flatten_all.append(vox_feature_flatten)
            grid_xyz_all.append(batch_grid_xyz)
            binary_volume.append(batch_binary_volume)

        alpha_all = torch.concat(alpha_all, dim=0)
        rgb_all = torch.concat(rgb_all, dim=0)
        h_feature_all = torch.concat(h_feature_all, dim=0)
        vox_feature_flatten_all = torch.concat(vox_feature_flatten_all, dim=0)
        grid_xyz_all = torch.concat(grid_xyz_all, dim=0)
        binary_volume = torch.concat(binary_volume, dim=0).view(*og_shape[:3])

        return grid_xyz_all, alpha_all, rgb_all, h_feature_all, vox_feature_flatten_all, binary_volume, grid_xyz.view(*og_shape)

    @torch.no_grad()
    def scale_volume_grid(self, num_voxels):
        print('TiNeuVox: scale_volume_grid start')
        ori_world_size = self.world_size
        self._set_grid_resolution(num_voxels)
        print('TiNeuVox: scale_volume_grid scale world_size from', ori_world_size, 'to', self.world_size)
        self.feature = torch.nn.Parameter(
            F.interpolate(self.feature.data, size=tuple(self.world_size), mode='trilinear', align_corners=True))
 
    def feature_total_variation_add_grad(self, weight, dense_mode):
        """Invoke the total-variation extension on the feature grid and its gradient.

        ``weight`` is scaled by ``world_size.max() / 128`` and passed as the
        weight for all three axes. ``self.feature.grad`` must already exist.
        NOTE(review): the extension presumably accumulates into the gradient
        tensor in place; confirm against the kernel source.
        """
        axis_weight = weight * self.world_size.max() / 128
        feature = self.feature.float()
        feature_grad = self.feature.grad.float()
        total_variation_cuda.total_variation_add_grad(
            feature, feature_grad, axis_weight, axis_weight, axis_weight, dense_mode)

    def grid_sampler(self, xyz, *grids, mode=None, align_corners=True):
        '''Wrapper for the interp operation'''
        mode = 'bilinear'
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,1,-1,3)
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
        ret_lst = [
            F.grid_sample(grid, ind_norm, mode=mode, align_corners=align_corners).reshape(grid.shape[1],-1).T.reshape(*shape,grid.shape[1])
            for grid in grids
        ]
        for i in range(len(grids)):
            if ret_lst[i].shape[-1] == 1:
                ret_lst[i] = ret_lst[i].squeeze(-1)
        if len(ret_lst) == 1:
            return ret_lst[0]
        return ret_lst
    
    def activate_density(self, density, interval=None):
        """Convert raw density to alpha via ``Raw2Alpha``.

        ``interval`` defaults to ``self.voxel_size_ratio``. The result has the
        shape of ``density`` with a trailing singleton axis squeezed away.
        """
        if interval is None:
            interval = self.voxel_size_ratio
        flat_alpha = Raw2Alpha.apply(density.flatten(), self.act_shift, interval)
        return flat_alpha.reshape(density.shape).squeeze(-1)

    def mult_dist_interp(self, ray_pts_delta):

        x_pad = math.ceil((self.feature.shape[2]-1)/4.0)*4-self.feature.shape[2]+1
        y_pad = math.ceil((self.feature.shape[3]-1)/4.0)*4-self.feature.shape[3]+1
        z_pad = math.ceil((self.feature.shape[4]-1)/4.0)*4-self.feature.shape[4]+1
        grid = F.pad(self.feature.float(),(0,z_pad,0,y_pad,0,x_pad))
        # three 
        vox_l = self.grid_sampler(ray_pts_delta, grid)
        vox_m = self.grid_sampler(ray_pts_delta, grid[:,:,::2,::2,::2])
        vox_s = self.grid_sampler(ray_pts_delta, grid[:,:,::4,::4,::4])
        vox_feature = torch.cat((vox_l,vox_m,vox_s),-1)

        if len(vox_feature.shape)==1:
            vox_feature_flatten = vox_feature.unsqueeze(0)
        else:
            vox_feature_flatten = vox_feature
        
        return vox_feature_flatten


    def get_mask(self, rays_o, rays_d, near, far, stepsize, **render_kwargs):
        '''Return, per ray, whether any sampled point lies inside the bounding box.

        ``rays_o`` / ``rays_d`` may have any leading shape; the boolean result
        has that same leading shape. Extra keyword arguments are ignored.
        '''
        out_shape = rays_o.shape[:-1]
        origins = rays_o.reshape(-1, 3).contiguous()
        directions = rays_d.reshape(-1, 3).contiguous()
        stepdist = stepsize * self.voxel_size
        sampled = render_utils_cuda.sample_pts_on_rays(
            origins, directions, self.xyz_min, self.xyz_max, near, far, stepdist)
        # Only the first three outputs are needed; the sampled points are unused.
        _, mask_outbbox, ray_id = sampled[:3]
        hit = torch.zeros([len(origins)], dtype=torch.bool)
        hit[ray_id[~mask_outbbox]] = 1
        return hit.reshape(out_shape)

    def sample_ray(self, rays_o, rays_d, near, far, stepsize, is_train=False, **render_kwargs):
        '''Sample query points along rays, keeping only those inside the bounding box.

        Input:
            rays_o, rays_d:   both [N, 3], ray origins and directions.
            near, far:        near and far distance of the rays.
            stepsize:         sample spacing expressed in voxels.
            is_train:         accepted but not used here.
        Output:
            ray_pts:          [M, 3] sampled points inside the bounding box.
            ray_id:           [M]    index of the ray each point belongs to.
            step_id:          [M]    step index of each point along its ray.
            mask_inbbox:      boolean mask over all samples before filtering.
        '''
        origins = rays_o.contiguous()
        directions = rays_d.contiguous()
        stepdist = stepsize * self.voxel_size
        # The extension returns seven values; the last three are not needed here.
        ray_pts, mask_outbbox, ray_id, step_id, _n_steps, _t_min, _t_max = render_utils_cuda.sample_pts_on_rays(
            origins, directions, self.xyz_min, self.xyz_max, near, far, stepdist)
        mask_inbbox = ~mask_outbbox
        return ray_pts[mask_inbbox], ray_id[mask_inbbox], step_id[mask_inbbox], mask_inbbox

    def forward(self, rays_o, rays_d, viewdirs, times_sel, cam_sel=None, bg_points_sel=None, global_step=None, render_until=None, canonical_t=0, threshold=0.05, **render_kwargs):
        '''Volume rendering
        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @viewdirs: [N, 3] viewing direction to compute positional embedding for MLP.
        @times_sel: per-ray time values; their features are indexed with ray_id.
        @cam_sel:  per-ray camera selector, only read when self.add_cam is set.
        @bg_points_sel: optional extra points that are passed through the
                   deformation network; the result is returned as 'bg_points_delta'.
        @global_step: only used to derive the is_train flag given to sample_ray.
        @render_until, @canonical_t, @threshold: not used by the active code
                   (they only appear in the commented-out experiments below).
        @render_kwargs: must provide 'near', 'far' and 'stepsize' (consumed by
                   sample_ray) and 'bg' (background colour).

        Returns a dict with the rendered colour ('rgb_marched'), per-sample
        'weights', 'raw_alpha', 'raw_rgb', 'ray_id', normalised step position
        's', 'n_max', the remaining transmittance 'alphainv_last', the warped
        sample positions 'ray_pts_delta' and 'depth'.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'

        ret_dict = {}
        N = len(rays_o)
        # Embed the per-ray inputs; poc_fre is the embedding helper defined elsewhere in this file.
        times_emb = poc_fre(times_sel, self.time_poc)
        viewdirs_emb = poc_fre(viewdirs, self.view_poc)
        times_feature = self.timenet(times_emb)
        if self.add_cam:
            cam_emb= poc_fre(cam_sel, self.time_poc)
            cams_feature=self.camnet(cam_emb)
        # sample points on rays
        ray_pts, ray_id, step_id, mask_inbbox= self.sample_ray(
                rays_o=rays_o, rays_d=rays_d, is_train=global_step is not None, **render_kwargs)

        # pts deformation 
        rays_pts_emb = poc_fre(ray_pts, self.pos_poc)
        # Only warp non canonical points
        # time_mask = (times_sel[ray_id] != canonical_t).squeeze(-1)
        # ray_pts_delta = torch.zeros_like(ray_pts)
        # ray_pts_delta[time_mask] = ray_pts_delta[time_mask] + self.deformation_net(rays_pts_emb[time_mask], times_feature[ray_id][time_mask])
        # inv_time_mask = ~time_mask
        # ray_pts_delta[inv_time_mask] = ray_pts_delta[inv_time_mask] + ray_pts[inv_time_mask]

        # Despite the name, ray_pts_delta holds the warped positions
        # (Deformation.forward returns points + offset), not the offsets.
        ray_pts_delta = self.deformation_net(rays_pts_emb, times_feature[ray_id])

        if bg_points_sel is not None:
            bg_points_sel_emb = poc_fre(bg_points_sel, self.pos_poc)
            # Pairs background point k with the time feature of ray k (first rows of times_feature).
            bg_points_sel_delta = self.deformation_net(bg_points_sel_emb, times_feature[:(bg_points_sel_emb.shape[0])])
            ret_dict.update({'bg_points_delta': bg_points_sel_delta})
        # voxel query interp
        vox_feature_flatten=self.mult_dist_interp(ray_pts_delta)

        # From here on times_feature is per sample point rather than per ray.
        times_feature = times_feature[ray_id]
        vox_feature_flatten_emb = poc_fre(vox_feature_flatten, self.grid_poc)
        if self.feat_only:
            h_feature = self.featurenet(vox_feature_flatten_emb)
        else:
            h_feature = self.featurenet(torch.cat((vox_feature_flatten_emb, rays_pts_emb, times_feature), -1))

        density_result = self.densitynet(h_feature)


        # forward_warping_mask = (density_result > threshold).squeeze(-1)
        # rays_pts_delta_emb = poc_fre(ray_pts_delta[forward_warping_mask], self.pos_poc)
        # ray_pts_hat = self.forward_warp(rays_pts_delta_emb.detach(), times_feature[ray_id][forward_warping_mask].detach())

        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        alpha = self.activate_density(density_result, interval)
        
        # First pruning pass: drop samples whose alpha is below the threshold.
        if self.fast_color_thres > 0:
            mask = (alpha > self.fast_color_thres)
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            alpha = alpha[mask]
            h_feature=h_feature[mask]

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, N)
        # Second pruning pass: drop samples whose compositing weight is below the threshold.
        if self.fast_color_thres > 0:
            mask = (weights > self.fast_color_thres)
            weights = weights[mask]
            alpha = alpha[mask]
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            h_feature=h_feature[mask]

        # Colour is only evaluated for the samples that survived both passes.
        if self.no_view_dir:
            rgb_logit = self.rgbnet(h_feature)
        else:
            viewdirs_emb_reshape = viewdirs_emb[ray_id]
            if self.add_cam:
                viewdirs_emb_reshape=torch.cat((viewdirs_emb_reshape, cams_feature[ray_id]), -1)
            rgb_logit = self.rgbnet(h_feature, viewdirs_emb_reshape)
        rgb = torch.sigmoid(rgb_logit)

        # Ray marching
        # Sum the weighted colours of all samples belonging to the same ray.
        rgb_marched = segment_coo(
                src=(weights.unsqueeze(-1) * rgb),
                index=ray_id,
                out=torch.zeros([N, 3]),
                reduce='sum')
        # Blend in the background using the transmittance left after the last sample.
        rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])

        # Upper bound on steps per ray, used to normalise step_id into s.
        N_samples = int(np.linalg.norm(np.array(self.world_size.cpu())+1) /  render_kwargs['stepsize']) + 1
        s = (step_id+0.5) / N_samples

        ret_dict.update({
            'alphainv_last': alphainv_last,
            'weights': weights,
            'rgb_marched': rgb_marched,
            # 'ray_pts_hat': ray_pts_hat,
            'raw_alpha': alpha,
            'raw_rgb': rgb,
            'ray_id': ray_id,
            's': s,
            'n_max': N_samples,
            'ray_pts_delta': ray_pts_delta
        })

        # Depth is the weight-averaged step index (in step units) and carries no gradient.
        with torch.no_grad():
            depth = segment_coo(
                    src=(weights * step_id),
                    index=ray_id,
                    out=torch.zeros([N]),
                    reduce='sum')
        ret_dict.update({'depth': depth})
        return ret_dict
    
    def ray_density(self, rays_o, rays_d, times_sel, cam_sel=None, bg_points_sel=None, global_step=None, threshold=0.1, **render_kwargs):
        '''Compute compositing weights along rays without deformation or colour.

        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @times_sel: per-ray time values; their features are indexed with ray_id.
        @cam_sel, @bg_points_sel, @threshold: accepted but not used.
        @render_kwargs: must provide 'near', 'far' and 'stepsize'.

        The feature grid is sampled at the undeformed sample positions.
        Returns a dict with 'weights', 'ray_id', 's' (normalised step
        position) and 'n_max'.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'

        n_rays = len(rays_o)
        per_ray_time = self.timenet(poc_fre(times_sel, self.time_poc))

        # sample points on rays (the in-bbox mask is not needed here)
        ray_pts, ray_id, step_id, _ = self.sample_ray(
                rays_o=rays_o, rays_d=rays_d, is_train=global_step is not None, **render_kwargs)

        pts_emb = poc_fre(ray_pts, self.pos_poc)
        grid_feature = self.mult_dist_interp(ray_pts)
        per_point_time = per_ray_time[ray_id]
        grid_emb = poc_fre(grid_feature, self.grid_poc)

        if self.feat_only:
            net_input = grid_emb
        else:
            net_input = torch.cat((grid_emb, pts_emb, per_point_time), -1)
        density_result = self.densitynet(self.featurenet(net_input))

        stepsize = render_kwargs['stepsize']
        alpha = self.activate_density(density_result, stepsize * self.voxel_size_ratio)

        prune = self.fast_color_thres > 0
        if prune:
            # Drop samples whose alpha is below the threshold.
            keep = (alpha > self.fast_color_thres)
            ray_id, step_id, alpha = ray_id[keep], step_id[keep], alpha[keep]

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, n_rays)
        if prune:
            # Drop samples whose compositing weight is below the threshold.
            keep = (weights > self.fast_color_thres)
            weights, ray_id, step_id = weights[keep], ray_id[keep], step_id[keep]

        N_samples = int(np.linalg.norm(np.array(self.world_size.cpu())+1) /  stepsize) + 1
        s = (step_id+0.5) / N_samples

        return {
            'weights': weights,
            'ray_id': ray_id,
            's': s,
            'n_max': N_samples,
        }

class Alphas2Weights(torch.autograd.Function):
    """Autograd wrapper around the extension's ``alpha2weight`` routines.

    ``forward`` returns ``(weights, alphainv_last)``; the tensors needed for
    the backward pass are only saved when ``alpha`` requires grad.
    """

    @staticmethod
    def forward(ctx, alpha, ray_id, N):
        weights, transmittance, alphainv_last, seg_start, seg_end = render_utils_cuda.alpha2weight(alpha, ray_id, N)
        if alpha.requires_grad:
            ctx.n_rays = N
            ctx.save_for_backward(alpha, weights, transmittance, alphainv_last, seg_start, seg_end)
        return weights, alphainv_last

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_weights, grad_last):
        # Saved order: alpha, weights, T, alphainv_last, i_start, i_end.
        saved = ctx.saved_tensors
        grad_alpha = render_utils_cuda.alpha2weight_backward(
            saved[0], saved[1], saved[2], saved[3], saved[4], saved[5],
            ctx.n_rays, grad_weights, grad_last)
        # No gradients for ray_id and N.
        return grad_alpha, None, None


class Raw2Alpha(torch.autograd.Function):
    """Autograd wrapper around the extension's ``raw2alpha`` routines."""

    @staticmethod
    def forward(ctx, density, shift, interval):
        '''
        Intended mapping (the computation itself lives in the extension):
        alpha = 1 - exp(-softplus(density + shift) * interval)
              = 1 - exp(-log(1 + exp(density + shift)) * interval)
              = 1 - exp(log((1 + exp(density + shift)) ^ (-interval)))
              = 1 - (1 + exp(density + shift)) ^ (-interval)
        '''
        exp, alpha = render_utils_cuda.raw2alpha(density, shift, interval)
        if density.requires_grad:
            ctx.interval = interval
            ctx.save_for_backward(exp)
        return alpha

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_back):
        '''
        alpha' = interval * ((1 + exp(density + shift)) ^ (-interval-1)) * exp(density + shift)'
               = interval * ((1 + exp(density + shift)) ^ (-interval-1)) * exp(density + shift)
        '''
        (exp,) = ctx.saved_tensors
        grad_density = render_utils_cuda.raw2alpha_backward(
            exp, grad_back.contiguous(), ctx.interval)
        # No gradients for shift and interval.
        return grad_density, None, None


''' Ray and batch
'''
def get_rays(H, W, K, c2w, inverse_y=False, flip_x=False, flip_y=False, mode='center'):
    """Generate one ray per pixel of an ``H x W`` image.

    Args:
        H, W: image height and width in pixels.
        K: 3x3 intrinsics; ``K[0][0]``, ``K[1][1]`` are the focal lengths and
            ``K[0][2]``, ``K[1][2]`` the principal point.
        c2w: camera-to-world matrix with at least 3 rows and 4 columns.
        inverse_y: if True use +y / +z camera axes instead of -y / -z.
        flip_x, flip_y: mirror the pixel coordinates horizontally / vertically.
        mode: ``'lefttop'`` (pixel corner), ``'center'`` (pixel centre) or
            ``'random'`` (uniform jitter inside each pixel).

    Returns:
        ``(rays_o, rays_d)``, both of shape ``[H, W, 3]``.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    cols, rows = torch.meshgrid(
        torch.linspace(0, W-1, W, device=c2w.device),
        torch.linspace(0, H-1, H, device=c2w.device))  # pytorch's meshgrid has indexing='ij'
    # Transpose from [W, H] to image layout [H, W].
    i = cols.t().float()
    j = rows.t().float()

    if mode == 'center':
        i, j = i+0.5, j+0.5
    elif mode == 'random':
        i = i+torch.rand_like(i)
        j = j+torch.rand_like(j)
    elif mode != 'lefttop':
        raise NotImplementedError

    if flip_x:
        i = i.flip((1,))
    if flip_y:
        j = j.flip((0,))

    x_cam = (i-K[0][2])/K[0][0]
    y_cam = (j-K[1][2])/K[1][1]
    z_cam = torch.ones_like(i)
    if inverse_y:
        dirs = torch.stack([x_cam, y_cam, z_cam], -1)
    else:
        dirs = torch.stack([x_cam, -y_cam, -z_cam], -1)

    # Rotate ray directions from camera frame to the world frame (R @ dir per pixel).
    rotation = c2w[:3,:3]
    rays_d = torch.sum(dirs[..., np.newaxis, :] * rotation, -1)
    # Every ray starts at the camera centre expressed in world coordinates.
    rays_o = c2w[:3,3].expand(rays_d.shape)
    return rays_o, rays_d

def get_rays_np(H, W, K, c2w):
    """NumPy ray generation: one ray through the top-left corner of each pixel.

    Uses the -y / -z camera convention. Returns ``(rays_o, rays_d)``, both of
    shape ``[H, W, 3]``; ``rays_o`` is a broadcast view of the camera centre.
    """
    pixel_x = np.arange(W, dtype=np.float32)
    pixel_y = np.arange(H, dtype=np.float32)
    i, j = np.meshgrid(pixel_x, pixel_y, indexing='xy')
    x_cam = (i-K[0][2])/K[0][0]
    y_cam = -(j-K[1][2])/K[1][1]
    z_cam = -np.ones_like(i)
    dirs = np.stack([x_cam, y_cam, z_cam], -1)
    # Rotate ray directions from camera frame to the world frame (R @ dir per pixel).
    rotation = c2w[:3,:3]
    rays_d = np.sum(dirs[..., np.newaxis, :] * rotation, -1)
    # Every ray starts at the camera centre expressed in world coordinates.
    rays_o = np.broadcast_to(c2w[:3,3], np.shape(rays_d))
    return rays_o, rays_d

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Re-express rays in normalised device coordinates.

    Origins are first moved along their rays onto the plane ``z = -near``;
    origins and directions are then projected using ``H``, ``W`` and
    ``focal``. Returns the transformed ``(rays_o, rays_d)``.
    """
    # Shift ray origins to near plane
    t = -(near + rays_o[...,2]) / rays_d[...,2]
    rays_o = rays_o + t[...,None] * rays_d

    ox, oy, oz = rays_o[...,0], rays_o[...,1], rays_o[...,2]
    dx, dy, dz = rays_d[...,0], rays_d[...,1], rays_d[...,2]
    scale_x = -1./(W/(2.*focal))
    scale_y = -1./(H/(2.*focal))

    # Projection
    o0 = scale_x * ox / oz
    o1 = scale_y * oy / oz
    o2 = 1. + 2. * near / oz

    d0 = scale_x * (dx/dz - ox/oz)
    d1 = scale_y * (dy/dz - oy/oz)
    d2 = -2. * near / oz

    return torch.stack([o0,o1,o2], -1), torch.stack([d0,d1,d2], -1)

def get_rays_of_a_view(H, W, K, c2w, ndc, inverse_y=False, flip_x=False, flip_y=False, mode='center'):
    """Return ray origins, ray directions and unit view directions of one view.

    The view directions are the normalized directions returned by
    ``get_rays`` and are computed before any NDC conversion. When ``ndc`` is
    truthy the origins and directions are passed through ``ndc_rays`` using
    ``K[0][0]`` as focal length and a near plane of ``1.``.

    Returns:
        ``(rays_o, rays_d, viewdirs)``.
    """
    ray_options = dict(inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y, mode=mode)
    origins, directions = get_rays(H, W, K, c2w, **ray_options)

    # Unit-length directions, always taken from the un-projected rays.
    unit_dirs = directions / directions.norm(dim=-1, keepdim=True)

    if not ndc:
        return origins, directions, unit_dirs

    origins, directions = ndc_rays(H, W, K[0][0], 1., origins, directions)
    return origins, directions, unit_dirs

@torch.no_grad()
def get_training_rays(rgb_tr, times, train_poses, HW, Ks, ndc, inverse_y=False, flip_x=False, flip_y=False):
    """Build dense per-pixel rays for every training view.

    All views must share one image size (asserted via ``HW``). Each view is
    rendered with its own intrinsics ``Ks[i]``, so the intrinsics do not need
    to be identical across views; when they are identical the result equals
    that of using ``Ks[0]`` everywhere.

    Args:
        rgb_tr: training images; only ``len`` and ``.device`` are used here and
            the object is returned unchanged.
        times: per-view time values, indexed as ``times[i]``.
        train_poses: per-view camera-to-world matrices.
        HW: per-view ``(H, W)`` pairs; all rows must be equal.
        Ks: per-view intrinsics.
        ndc: forwarded to ``get_rays_of_a_view``.
        inverse_y, flip_x, flip_y: forwarded to ``get_rays_of_a_view``.

    Returns:
        ``(rgb_tr, times_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz)`` where
        the ray tensors have shape ``[len(rgb_tr), H, W, 3]``, ``times_tr`` has
        shape ``[len(rgb_tr), H, W, 1]`` and ``imsz`` is ``[1] * len(rgb_tr)``.
    """
    print('get_training_rays: start')
    assert len(np.unique(HW, axis=0)) == 1
    assert len(rgb_tr) == len(train_poses) and len(rgb_tr) == len(Ks) and len(rgb_tr) == len(HW)
    H, W = HW[0]
    eps_time = time.time()
    device = rgb_tr.device
    n_views = len(rgb_tr)
    rays_o_tr = torch.zeros([n_views, H, W, 3], device=device)
    rays_d_tr = torch.zeros([n_views, H, W, 3], device=device)
    viewdirs_tr = torch.zeros([n_views, H, W, 3], device=device)
    times_tr = torch.ones([n_views, H, W, 1], device=device)

    imsz = [1] * n_views
    # Pair every pose with its own intrinsics instead of reusing Ks[0]; the
    # length assertion above guarantees zip() does not drop any view.
    for i, (c2w, K) in enumerate(zip(train_poses, Ks)):
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H=H, W=W, K=K, c2w=c2w, ndc=ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)
        rays_o_tr[i].copy_(rays_o.to(device))
        rays_d_tr[i].copy_(rays_d.to(device))
        viewdirs_tr[i].copy_(viewdirs.to(device))
        # times_tr starts as ones, so this writes the view's time into every pixel.
        times_tr[i] = times_tr[i]*times[i]
        # Drop the per-view temporaries before the next view is generated.
        del rays_o, rays_d, viewdirs
    eps_time = time.time() - eps_time
    print('get_training_rays: finish (eps time:', eps_time, 'sec)')
    return rgb_tr, times_tr,rays_o_tr, rays_d_tr, viewdirs_tr, imsz


@torch.no_grad()
def get_training_rays_flatten(rgb_tr_ori, times, train_poses, HW, Ks, ndc, inverse_y=False, flip_x=False, flip_y=False):
    """Concatenate the pixels and rays of all training views into flat tensors.

    Views may have different sizes; view ``k`` contributes ``H_k * W_k``
    consecutive rows to every output tensor.

    Args:
        rgb_tr_ori: per-view images; each must satisfy
            ``img.shape[:2] == (H, W)`` for the matching entry of ``HW``.
        times: per-view time values (a tensor; ``unsqueeze(-1)`` is applied).
        train_poses: per-view camera-to-world matrices.
        HW: per-view ``(H, W)`` pairs.
        Ks: per-view intrinsics.
        ndc: forwarded to ``get_rays_of_a_view``.
        inverse_y, flip_x, flip_y: forwarded to ``get_rays_of_a_view``.

    Returns:
        ``(rgb_tr, times_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz)`` where
        the colour/ray tensors have shape ``[N, 3]``, ``times_tr`` has shape
        ``[N, 1]`` and ``imsz`` lists the pixel count of each view.
    """
    print('get_training_rays_flatten: start')
    assert len(rgb_tr_ori) == len(train_poses) and len(rgb_tr_ori) == len(Ks) and len(rgb_tr_ori) == len(HW)
    started_at = time.time()
    device = rgb_tr_ori[0].device
    total_pixels = sum(im.shape[0] * im.shape[1] for im in rgb_tr_ori)

    # Pre-allocate all outputs once and fill them view by view.
    rgb_tr = torch.zeros([total_pixels, 3], device=device)
    rays_o_tr = torch.zeros_like(rgb_tr)
    rays_d_tr = torch.zeros_like(rgb_tr)
    viewdirs_tr = torch.zeros_like(rgb_tr)
    times_tr = torch.ones([total_pixels, 1], device=device)
    view_times = times.unsqueeze(-1)

    imsz = []
    cursor = 0
    for c2w, img, (H, W), K, view_time in zip(train_poses, rgb_tr_ori, HW, Ks, view_times):
        assert img.shape[:2] == (H, W)
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H=H, W=W, K=K, c2w=c2w, ndc=ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)
        pixel_count = H * W
        rows = slice(cursor, cursor + pixel_count)
        # times_tr starts as ones, so the product stores the view's time.
        times_tr[rows] = times_tr[rows] * view_time
        rgb_tr[rows].copy_(img.flatten(0, 1))
        rays_o_tr[rows].copy_(rays_o.flatten(0, 1).to(device))
        rays_d_tr[rows].copy_(rays_d.flatten(0, 1).to(device))
        viewdirs_tr[rows].copy_(viewdirs.flatten(0, 1).to(device))
        imsz.append(pixel_count)
        cursor += pixel_count

    # Every pre-allocated row must have been written exactly once.
    assert cursor == total_pixels
    elapsed = time.time() - started_at
    print('get_training_rays_flatten: finish (eps time:', elapsed, 'sec)')
    return rgb_tr, times_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz


@torch.no_grad()
def get_training_rays_in_maskcache_sampling(rgb_tr_ori, masks_tr_ori, times, train_poses, HW, Ks, ndc, model, render_kwargs, img_to_cam, i_train, inverse_y=False, flip_x=False, flip_y=False, **kwargs):
    """Collect training pixels and per-camera rays selected by ``model.get_mask``.

    Works in two passes:

    1. For every ``(pose, K)`` pair in ``train_poses`` / ``Ks`` the rays of one
       ``HW[0]``-sized view are generated, a boolean ``(H, W)`` mask is
       obtained from ``model.get_mask`` (evaluated in strips of ``CHUNK``
       rows), and the masked rays are appended to ``rays_o_tr``,
       ``rays_d_tr`` and ``viewdirs_tr``.
    2. For every image, the mask of its camera (``img_to_cam`` entry) selects
       pixels from the image and from ``masks_tr_ori``; the selection is
       appended to ``rgb_tr`` / ``masks_tr``, ``pix_to_ray`` records
       ``camera * H * W + flat_pixel_index`` for each kept pixel, and
       ``index_to_times`` maps each time value to a ``(start, end)`` row range.

    Args:
        rgb_tr_ori: training images; iterated per image and also read through
            ``.shape[1]`` / ``.shape[2]`` as the common image size.
        masks_tr_ori: per-image masks, indexed with the camera's boolean mask.
        times: per-image time values; each entry is read with ``.item()``
            (presumably a 1-D tensor -- TODO confirm against caller).
        train_poses: per-camera camera-to-world matrices.
        HW: per-image ``(H, W)`` pairs; ``HW[0]`` is used for ray generation.
        Ks: per-camera intrinsics; must expose ``.shape[0]``.
        ndc, inverse_y, flip_x, flip_y: forwarded to ``get_rays_of_a_view``.
        model: object providing ``get_mask(rays_o=..., rays_d=..., **render_kwargs)``.
        render_kwargs: extra keyword arguments for ``model.get_mask``.
        img_to_cam: per-image camera index; also used to index
            ``train_poses`` and ``Ks``.
        i_train: not used in this function.
        **kwargs: accepted and ignored.

    Returns:
        ``(rgb_tr, index_to_times, rays_o_tr, rays_d_tr, viewdirs_tr,
        pix_to_ray, masks_tr)``; all tensors are truncated to the number of
        pixels kept by the second pass.
    """
    print('get_training_rays_in_maskcache_sampling: start')
    # assert len(rgb_tr_ori) == len(train_poses) and len(rgb_tr_ori) == len(Ks) and len(rgb_tr_ori) == len(HW)
    # Poses and intrinsics are per camera; images and HW are per image.
    assert len(Ks) == len(train_poses) and len(rgb_tr_ori) == len(HW)

    CHUNK = 64  # number of image rows passed to model.get_mask at once
    DEVICE = rgb_tr_ori[0].device
    eps_time = time.time()
    N = sum(im.shape[0] * im.shape[1] for im in rgb_tr_ori) # N = len(i_train) * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2]
    # Per-pixel buffers, sized for the case that every pixel of every image is kept.
    rgb_tr = torch.zeros([N,3], device=DEVICE)
    masks_tr = torch.zeros([N, 1], device=DEVICE)
    pix_to_ray = torch.zeros([N], device=DEVICE, dtype=torch.int32)
    # Per-camera ray buffers (one slot per pixel of every camera). No device is
    # passed here, so they are created on torch's default device.
    # NOTE(review): the sibling TINEUVOX variant passes device=DEVICE -- confirm
    # whether the difference is intentional.
    rays_o_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3).float()
    rays_d_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3).float()
    viewdirs_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3).float()
    index_to_times = {}  # time value -> (first row, one-past-last row) in the per-pixel buffers
    times = times.unsqueeze(-1)
    # One boolean (H, W) mask per camera, filled in the first pass.
    masks = torch.zeros(Ks.shape[0], rgb_tr_ori.shape[1], rgb_tr_ori.shape[2], dtype=bool, device=DEVICE)
    # imsz = []
    

    H, W = HW[0]
    c = 0    # camera counter for the first pass
    top = 0  # write offset into the ray buffers; becomes a tensor after `top += n`
    
    # Pass 1: per camera, generate rays, query the mask and store the masked rays.
    for c2w, K in zip(train_poses, Ks):
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H=H, W=W, K=K, c2w=c2w.float(), ndc=ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)

        mask = torch.empty(H, W, device=DEVICE, dtype=torch.bool)
        for i in range(0, H, CHUNK):
            mask[i:i+CHUNK] = model.get_mask(
                    rays_o=rays_o[i:i+CHUNK], rays_d=rays_d[i:i+CHUNK], **render_kwargs).to(DEVICE)
        masks[c] = mask
        n = mask.sum()  # 0-dim tensor: number of rays kept for this camera

        # Masked rays are stored packed, i.e. without gaps for rejected pixels.
        rays_o_tr[top:top+n].copy_(rays_o[mask].to(DEVICE))
        rays_d_tr[top:top+n].copy_(rays_d[mask].to(DEVICE))
        viewdirs_tr[top:top+n].copy_(viewdirs[mask].to(DEVICE))
        

        top += n
        c += 1

    # Pass 2: per image, gather the pixels selected by the mask of its camera.
    # `top` restarts as a tensor so that `.item()` works from the first iteration.
    top = torch.tensor(0)
    for c2w, img, mask_tr_ori, (H, W), K, time_one, PI_iter in zip(train_poses[img_to_cam], rgb_tr_ori, masks_tr_ori, HW, Ks[img_to_cam], times, img_to_cam):
        assert img.shape[:2] == (H, W)
        mask = masks[PI_iter]
        n = mask.sum()

        time_one = time_one.item()
        # A repeated time value keeps its first start row and extends its end row.
        # NOTE(review): this describes one contiguous range only if images that
        # share a time value are adjacent in the input order -- confirm.
        if time_one in index_to_times.keys():
            index_to_times[time_one] = (index_to_times[time_one][0], (top+n).item())
        else:
            index_to_times[time_one] = (top.item(), (top+n).item())

        rgb_tr[top:top+n].copy_(img[mask])
        masks_tr[top:top+n].copy_(mask_tr_ori[mask])
        # NOTE(review): these indices address a dense camera*H*W layout, while
        # pass 1 stores rays packed by mask; the two agree only when every mask
        # is all True -- verify against the consumer of pix_to_ray.
        pix_to_ray[top:top+n].copy_((PI_iter * H * W + torch.arange(0, H * W).reshape(H, W))[mask])
        
        # imsz.append(n)
        c += 1  # not read after this point
        top += n

    print('get_training_rays_in_maskcache_sampling: ratio', top / N)
    # Truncate everything to the number of pixels kept in pass 2. The ray
    # buffers are cut with this same pixel count, not with the pass-1 ray count.
    rgb_tr = rgb_tr[:top]
    rays_o_tr = rays_o_tr[:top]
    rays_d_tr = rays_d_tr[:top]
    pix_to_ray = pix_to_ray[:top]
    viewdirs_tr = viewdirs_tr[:top]
    masks_tr = masks_tr[:top]
    eps_time = time.time() - eps_time
    print('get_training_rays_in_maskcache_sampling: finish (eps time:', eps_time, 'sec)')
    return rgb_tr, index_to_times, rays_o_tr, rays_d_tr, viewdirs_tr, pix_to_ray, masks_tr

@torch.no_grad()
def get_training_rays_in_maskcache_samplingTINEUVOX(rgb_tr_ori, masks_tr_ori, times,train_poses, HW, Ks, ndc, model, render_kwargs, img_to_cam, i_train, inverse_y=False, flip_x=False, flip_y=False, **kwargs):
    """Collect masked training pixels, per-pixel times and per-camera rays.

    Same two-pass structure as ``get_training_rays_in_maskcache_sampling`` but
    returns a per-pixel ``times_tr`` tensor instead of a time-to-range dict:

    1. For every ``(pose, K)`` pair the rays of one ``HW[0]``-sized view are
       generated, a boolean mask is obtained from ``model.get_mask`` in strips
       of ``CHUNK`` rows, and the masked rays are appended to the ray buffers.
    2. For every image, the mask of its camera (``img_to_cam`` entry) selects
       pixels from the image and from ``masks_tr_ori``; the image's time value
       is written to the matching rows of ``times_tr`` and ``pix_to_ray``
       records ``camera * H * W + flat_pixel_index`` for each kept pixel.

    Args:
        rgb_tr_ori: training images; iterated per image and also read through
            ``.shape[1]`` / ``.shape[2]`` as the common image size.
        masks_tr_ori: per-image masks, indexed with the camera's boolean mask.
        times: per-image time values (a tensor; ``unsqueeze(-1)`` is applied).
        train_poses: per-camera camera-to-world matrices.
        HW: per-image ``(H, W)`` pairs; ``HW[0]`` is used for ray generation.
        Ks: per-camera intrinsics; must expose ``.shape[0]``.
        ndc, inverse_y, flip_x, flip_y: forwarded to ``get_rays_of_a_view``.
        model: object providing ``get_mask(rays_o=..., rays_d=..., **render_kwargs)``.
        render_kwargs: extra keyword arguments for ``model.get_mask``.
        img_to_cam: per-image camera index; also used to index
            ``train_poses`` and ``Ks``.
        i_train: not used in this function.
        **kwargs: accepted and ignored.

    Returns:
        ``(rgb_tr, times_tr, rays_o_tr, rays_d_tr, viewdirs_tr, pix_to_ray,
        masks_tr)``. All tensors except ``times_tr`` are truncated to the
        number of pixels kept by the second pass; ``times_tr`` keeps its full
        ``[N, 1]`` allocation.
    """
    print('get_training_rays_in_maskcache_sampling: start')
    # assert len(rgb_tr_ori) == len(train_poses) and len(rgb_tr_ori) == len(Ks) and len(rgb_tr_ori) == len(HW)
    # Poses and intrinsics are per camera; images and HW are per image.
    assert len(Ks) == len(train_poses) and len(rgb_tr_ori) == len(HW)

    CHUNK = 64  # number of image rows passed to model.get_mask at once
    DEVICE = rgb_tr_ori[0].device
    eps_time = time.time()
    N = sum(im.shape[0] * im.shape[1] for im in rgb_tr_ori) # N = len(i_train) * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2]
    # Per-pixel buffers, sized for the case that every pixel of every image is kept.
    rgb_tr = torch.zeros([N, 3], device=DEVICE)
    masks_tr = torch.zeros([N, 1], device=DEVICE)
    pix_to_ray = torch.zeros([N], device=DEVICE, dtype=torch.int32)
    # Per-camera ray buffers: one slot per pixel of every camera.
    rays_o_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3, device=DEVICE).float()
    rays_d_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3, device=DEVICE).float()
    viewdirs_tr = torch.zeros(Ks.shape[0] * rgb_tr_ori.shape[1] * rgb_tr_ori.shape[2], 3, device=DEVICE).float()
    # rays_o_tr = torch.zeros_like(rgb_tr)
    # rays_d_tr = torch.zeros_like(rgb_tr)
    # viewdirs_tr = torch.zeros_like(rgb_tr)
    # Starts as ones so that multiplying by a time value stores that value.
    times_tr = torch.ones([N,1], device=DEVICE)
    times = times.unsqueeze(-1)
    # One boolean (H, W) mask per camera, filled in the first pass.
    masks = torch.zeros(Ks.shape[0], rgb_tr_ori.shape[1], rgb_tr_ori.shape[2], dtype=bool, device=DEVICE)
    # imsz = []
    

    H, W = HW[0]
    c = 0    # camera counter for the first pass
    top = 0  # write offset into the ray buffers; becomes a tensor after `top += n`
    
    # Pass 1: per camera, generate rays, query the mask and store the masked rays.
    for c2w, K in zip(train_poses, Ks):
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H=H, W=W, K=K, c2w=c2w.float(), ndc=ndc, inverse_y=inverse_y, flip_x=flip_x, flip_y=flip_y)

        mask = torch.empty(H, W, device=DEVICE, dtype=torch.bool)
        for i in range(0, H, CHUNK):
            mask[i:i+CHUNK] = model.get_mask(
                    rays_o=rays_o[i:i+CHUNK], rays_d=rays_d[i:i+CHUNK], **render_kwargs).to(DEVICE)
        masks[c] = mask
        n = mask.sum()  # 0-dim tensor: number of rays kept for this camera

        # Masked rays are stored packed, i.e. without gaps for rejected pixels.
        rays_o_tr[top:top+n].copy_(rays_o[mask].to(DEVICE))
        rays_d_tr[top:top+n].copy_(rays_d[mask].to(DEVICE))
        viewdirs_tr[top:top+n].copy_(viewdirs[mask].to(DEVICE))
        

        # times_tr[top:top+n]=times_tr[top:top+n]*time_one
        # rgb_tr[top:top+n].copy_(img[mask])
        
        # imsz.append(n)
        top += n
        c += 1

    # Pass 2: per image, gather the pixels selected by the mask of its camera.
    top = torch.tensor(0)
    for c2w, img, mask_tr_ori, (H, W), K, time_one, PI_iter in zip(train_poses[img_to_cam], rgb_tr_ori, masks_tr_ori, HW, Ks[img_to_cam], times, img_to_cam):
        assert img.shape[:2] == (H, W)
        mask = masks[PI_iter]
        n = mask.sum()
        # index_to_times[top:top+n] = c

        # rgb_tr[top:top+n].copy_(img[mask])
        times_tr[top:top+n]=times_tr[top:top+n]*time_one
        rgb_tr[top:top+n].copy_(img[mask])
        masks_tr[top:top+n].copy_(mask_tr_ori[mask])
        # NOTE(review): these indices address a dense camera*H*W layout, while
        # pass 1 stores rays packed by mask; the two agree only when every mask
        # is all True -- verify against the consumer of pix_to_ray.
        pix_to_ray[top:top+n].copy_((PI_iter * H * W + torch.arange(0, H * W).reshape(H, W))[mask])
        
        # imsz.append(n)
        c += 1  # not read after this point
        top += n

    print('get_training_rays_in_maskcache_sampling: ratio', top / N)
    # Truncate to the number of pixels kept in pass 2. The ray buffers are cut
    # with this same pixel count, and times_tr is left at its full length.
    rgb_tr = rgb_tr[:top]
    rays_o_tr = rays_o_tr[:top]
    rays_d_tr = rays_d_tr[:top]
    pix_to_ray = pix_to_ray[:top]
    viewdirs_tr = viewdirs_tr[:top]
    masks_tr = masks_tr[:top]
    eps_time = time.time() - eps_time
    print('get_training_rays_in_maskcache_sampling: finish (eps time:', eps_time, 'sec)')
    return rgb_tr, times_tr, rays_o_tr, rays_d_tr, viewdirs_tr, pix_to_ray, masks_tr


# def batch_indices_generator(N, BS):
#     # torch.randperm on cuda produces incorrect results on my machine
#     idx, top = torch.LongTensor(np.random.permutation(N)), 0
#     while True:
#         if top + BS > N:
#             idx, top = torch.LongTensor(np.random.permutation(N)), 0
#         yield idx[top:top+BS]
#         top += BS

def batch_indices_generator(N, BS):
    """Yield an endless stream of random index batches.

    Every batch is a ``torch.LongTensor`` holding ``BS`` distinct indices
    drawn from ``range(N)``; batches are drawn independently of each other.
    Sampling per batch avoids keeping a full permutation of ``N`` around.
    """
    sampler = np.random.default_rng()
    while True:
        picked = sampler.choice(N, size=BS, replace=False)
        yield torch.LongTensor(picked)


def poc_fre(input_data, poc_buf):
    """Append sine/cosine features of ``input_data`` scaled by ``poc_buf``.

    Every element of the last dimension of ``input_data`` is multiplied by
    each entry of ``poc_buf``; the products are flattened into the last
    dimension and passed through ``sin`` and ``cos``.

    Returns:
        ``input_data`` concatenated with the sine and then the cosine features
        along the last dimension.
    """
    scaled = input_data.unsqueeze(-1) * poc_buf
    phases = scaled.flatten(-2)
    return torch.cat([input_data, torch.sin(phases), torch.cos(phases)], -1)
