import torch
import torch.nn as nn
import numpy as np
import pdb, time, math
import functools
# from model.utils import get_sines

import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Dot-product self-attention across points, followed by average pooling.

    Query, key and value are produced by 1x1 convolutions that keep the
    channel count. ``forward`` squeezes dim 2 of each projection, so the
    input is presumably laid out as ``[B, C, 1, N]`` (TODO confirm with
    callers).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Projections are created in query/key/value order.
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the attention-mixed values averaged over the last axis."""
        q, k, v = (proj(x).squeeze(2) for proj in (self.query, self.key, self.value))

        # Pairwise scores between positions, normalised over the last axis.
        weights = self.softmax(q.transpose(1, 2) @ k)

        # Mix the values with the (transposed) weights, then average the
        # position axis away so one vector per batch element remains.
        mixed = v @ weights.transpose(1, 2)
        return mixed.mean(dim=-1)
class DGCNN(nn.Module):
    """Stack of four edge-convolution stages with a fusing 1x1 convolution.

    Every stage is Conv2d(1x1) -> BatchNorm2d -> ReLU -> Dropout. The first
    four stages output 16 channels each and take ``2 * channels`` inputs
    because ``get_graph_feature`` concatenates two feature groups per edge.
    The last stage maps the 64 concatenated channels to ``output_channels``.

    Args:
        input_channels: channel count of the raw per-point input.
        output_channels: channel count produced by the final stage.
        dropout_prob: dropout probability used in every stage.
    """

    def __init__(self, input_channels, output_channels, dropout_prob=0.3):
        super().__init__()

        def stage(n_in, n_out):
            # One Conv/BN/ReLU/Dropout unit; sub-module indices 0..3 match
            # the layout expected by existing state dicts.
            return nn.Sequential(
                nn.Conv2d(n_in, n_out, kernel_size=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob),
            )

        hidden = 16
        self.conv1 = stage(input_channels * 2, hidden)
        self.conv2 = stage(hidden * 2, hidden)
        self.conv3 = stage(hidden * 2, hidden)
        self.conv4 = stage(hidden * 2, hidden)
        self.conv5 = stage(hidden * 4, output_channels)

    def forward(self, x):
        """Run the four graph stages and fuse their pooled outputs."""
        pooled = []
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            edge_feat = block(get_graph_feature(x))
            # Max over the last (neighbour) axis; the pooled result feeds
            # the next stage and is also kept for the final concatenation.
            x = edge_feat.max(dim=-1, keepdim=False)[0]
            pooled.append(x)

        fused = torch.cat(pooled, dim=1)
        # conv5 is a Conv2d, so a singleton axis is inserted at dim 2.
        return self.conv5(fused.unsqueeze(2))

def knn(x, k):
    """Return the indices of the ``k`` nearest points for every point.

    ``x`` is used as ``[B, C, N]``: dim 1 is summed as the coordinate axis
    and the result has shape ``[B, N, k]``. Each point is its own nearest
    neighbour (distance zero), so it appears in its own index list.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = x.pow(2).sum(dim=1, keepdim=True)
    # Negative squared distance: -(|a|^2 - 2 a.b + |b|^2). Taking the top-k
    # of the negated distance selects the k smallest distances.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, idx = neg_sq_dist.topk(k=k, dim=-1)
    return idx

def get_graph_feature(x, k=20, idx=None):
    """Build per-edge features ``(neighbour - centre, centre)`` for a point set.

    Args:
        x: tensor whose dim 0 is the batch and dim 2 the point axis; it is
            reshaped to ``[B, C, N]`` before use.
        k: neighbour count. It is NOT used when ``idx`` is None: the count is
            then picked from the number of points (N if N < 8, 8 if N < 100,
            16 if N < 500, otherwise 32). When ``idx`` is given, the count is
            read from ``idx`` itself.
        idx: optional precomputed neighbour indices of shape ``[B, N, k]``.

    Returns:
        Tensor of shape ``[B, 2 * C, N, k]``; the first ``C`` channels hold
        ``neighbour - centre`` and the last ``C`` channels the centre itself.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)

    if idx is None:
        if num_points < 8:
            k = num_points
        elif num_points < 100:
            k = 8
        elif num_points < 500:
            k = 16
        else:
            k = 32
        idx = knn(x, k=k)
    else:
        # Previously the default k=20 was used for the reshape below even
        # when the supplied indices had a different neighbour count, which
        # made the call fail unless the caller repeated k by hand.
        k = idx.size(-1)

    # Offset every batch element so the indices address the flattened
    # [B * N, C] point table.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    points = x.transpose(2, 1).contiguous()  # [B, N, C]
    neighbours = points.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)
    # expand() gives the same values as repeat() without copying; the
    # concatenation below materialises the result anyway.
    centres = points.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)
    feature = torch.cat((neighbours - centres, centres), dim=3).permute(0, 3, 1, 2)
    return feature

def get_activation(activation):
    """Create an activation module from its case-insensitive name.

    Recognised names: relu, leakyrelu, sigmoid, softplus, gelu, selu, mish.
    ReLU, LeakyReLU, SELU and Mish are created with ``inplace=True``.

    Raises:
        Exception: if the name is not one of the recognised ones.
    """
    # name -> (class name inside torch.nn, construct with inplace=True?)
    # The class is looked up lazily so an unavailable activation only fails
    # when it is actually requested.
    registry = {
        'relu': ('ReLU', True),
        'leakyrelu': ('LeakyReLU', True),
        'sigmoid': ('Sigmoid', False),
        'softplus': ('Softplus', False),
        'gelu': ('GELU', False),
        'selu': ('SELU', True),
        'mish': ('Mish', True),
    }
    entry = registry.get(activation.lower())
    if entry is None:
        raise Exception("Activation Function Error")

    cls_name, in_place = entry
    layer_cls = getattr(nn, cls_name)
    return layer_cls(inplace=True) if in_place else layer_cls()


def get_norm(norm, width, num_groups=1):
    """Create a normalisation layer for ``width`` features.

    Args:
        norm: one of ``'LN'`` (LayerNorm), ``'BN'`` (BatchNorm1d),
            ``'IN'`` (InstanceNorm1d) or ``'GN'`` (GroupNorm).
        width: number of features / channels to normalise.
        num_groups: number of groups for ``'GN'``; must divide ``width``.
            Ignored by the other layer types.

    Raises:
        Exception: if ``norm`` is not one of the supported codes.
    """
    if norm == 'LN':
        return nn.LayerNorm(width)
    elif norm == 'BN':
        return nn.BatchNorm1d(width)
    elif norm == 'IN':
        return nn.InstanceNorm1d(width)
    elif norm == 'GN':
        # GroupNorm takes (num_groups, num_channels); the former one-argument
        # call always failed with a TypeError.
        return nn.GroupNorm(num_groups, width)
    else:
        raise Exception("Normalization Layer Error")


class NeuralPCI_Layer(torch.nn.Module):
    """Linear layer optionally followed by normalisation, activation and dropout.

    Args:
        dim_in: input feature count of the linear layer.
        dim_out: output feature count of the linear layer.
        norm: normalisation code passed to ``get_norm``; falsy to skip.
        act_fn: activation name passed to ``get_activation``; falsy to skip.
            Dropout is only appended when an activation is requested.
        dropout: dropout probability used after the activation.

    NOTE: the dropout used to be built as ``nn.Dropout(0,3)``, i.e.
    ``p=0`` with a truthy ``inplace`` argument - a typo for ``0.3`` that
    silently disabled dropout. Pass ``dropout=0.0`` to reproduce the old
    effective behaviour.
    """

    def __init__(self, 
                 dim_in,
                 dim_out,
                 norm=None, 
                 act_fn=None,
                 dropout=0.3
                 ):
        super().__init__()
        layer_list = [nn.Linear(dim_in, dim_out)]
        if norm:
            layer_list.append(get_norm(norm, dim_out))
        if act_fn:
            layer_list.append(get_activation(act_fn))
            layer_list.append(nn.Dropout(dropout))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, x):
        """Apply the configured layer stack to ``x``."""
        return self.layer(x)


class NeuralPCI_Block(torch.nn.Module):
    """Stack of ``depth`` equally wide Linear(+norm)(+activation+dropout) units.

    Args:
        depth: number of repeated units; ``0`` yields an identity mapping.
        width: feature count of every linear layer.
        norm: normalisation code passed to ``get_norm``; falsy to skip.
        act_fn: activation name passed to ``get_activation``; falsy to skip.
            Dropout is only appended when an activation is requested.
        dropout: dropout probability used after each activation.

    NOTE: the dropout used to be built as ``nn.Dropout(0,3)``, i.e.
    ``p=0`` with a truthy ``inplace`` argument - a typo for ``0.3`` that
    silently disabled dropout. Pass ``dropout=0.0`` to reproduce the old
    effective behaviour.
    """

    def __init__(self, 
                 depth, 
                 width,
                 norm=None, 
                 act_fn=None,
                 dropout=0.3
                 ):
        super().__init__()
        layer_list = []
        for _ in range(depth):
            layer_list.append(nn.Linear(width, width))
            if norm:
                layer_list.append(get_norm(norm, width))
            if act_fn:
                layer_list.append(get_activation(act_fn))
                layer_list.append(nn.Dropout(dropout))
        self.mlp = nn.Sequential(*layer_list)

    def forward(self, x):
        """Apply the stacked units to ``x``."""
        return self.mlp(x)

class EdgeConv(nn.Module):
    """Dense pairwise edge convolution over a time-stamped node set.

    For every ordered node pair the two node features and the two learned
    time codes are stacked, passed through a 1x1 Conv1d and max-pooled.

    Args:
        in_channels: feature count per node.
        out_channels: feature count produced per node.
        time_channels: width of the learned encoding of each scalar time.
    """

    def __init__(self, in_channels, out_channels, time_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels * 2 + time_channels * 2, out_channels, kernel_size=1)
        self.time_encoding = nn.Linear(1, time_channels)

    def forward(self, x, t):
        """Return pooled edge features of shape ``[out_channels, M]``.

        ``x`` must be 2-D ``[C, M]``. ``t`` goes through ``Linear(1, ...)``,
        so its last dim is 1 - presumably ``[M, 1]`` (TODO confirm).
        """
        _, num_nodes = x.shape
        t_code = self.time_encoding(t).transpose(1, 0)

        def spread_last(a):
            # out[c, i, j] = a[c, i]
            return a.unsqueeze(2).expand(-1, -1, num_nodes)

        def spread_middle(a):
            # out[c, i, j] = a[c, j]
            return a.unsqueeze(1).expand(-1, num_nodes, -1)

        pairs = torch.cat(
            [spread_last(x), spread_middle(x), spread_last(t_code), spread_middle(t_code)],
            dim=0,
        )
        # Conv1d consumes [batch, channels, length].
        pairs = pairs.permute(2, 0, 1).contiguous()
        edges = self.conv(pairs).permute(1, 2, 0).contiguous()
        return edges.max(dim=1, keepdim=False)[0]


class GraphConvolutionalNetwork(nn.Module):
    """Two EdgeConv layers, a width-3 Conv1d and a 1x1 output projection.

    Args:
        input_dim: feature count per node; kept by both EdgeConv layers and
            by ``time_conv``.
        output_dim: feature count produced per node.
        time_steps: forwarded to ``EdgeConv`` as its ``time_channels``.
    """

    def __init__(self, input_dim, output_dim, time_steps):
        super().__init__()
        self.conv1 = EdgeConv(input_dim, input_dim, time_steps)
        self.conv2 = EdgeConv(input_dim, input_dim, time_steps)
        self.time_conv = nn.Conv1d(input_dim, input_dim, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(input_dim, output_dim, kernel_size=1)

    def forward(self, x, t):
        """Map node features ``x`` ([C, M]) and times ``t`` to ``[output_dim, M]``."""
        h = F.relu(self.conv1(x, t))
        h = F.relu(self.conv2(h, t))

        # [C, M] -> [M, C, 1]: every node becomes a batch item whose sequence
        # has length one, so the width-3 kernel only sees its zero padding
        # around that single element.
        h = h.unsqueeze(0).permute(2, 1, 0).contiguous()
        h = F.relu(self.time_conv(h))

        # Back to [C, M] for the unbatched 1x1 projection.
        h = h.permute(2, 1, 0).contiguous().squeeze(0)
        return self.conv3(h)
class SimpleGraphConvolutionalNetwork(nn.Module):
    """One EdgeConv layer that halves the width, then a 1x1 projection.

    Args:
        input_dim: feature count per node; the EdgeConv outputs
            ``input_dim // 2`` features.
        output_dim: feature count produced per node.
        time_steps: forwarded to ``EdgeConv`` as its ``time_channels``.
    """

    def __init__(self, input_dim, output_dim, time_steps):
        super().__init__()
        bottleneck = input_dim // 2
        self.conv1 = EdgeConv(input_dim, bottleneck, time_steps)
        self.conv2 = nn.Conv1d(bottleneck, output_dim, kernel_size=1)

    def forward(self, x, t):
        """Map node features ``x`` and times ``t`` to ``output_dim`` features per node."""
        hidden = F.relu(self.conv1(x, t))
        return self.conv2(hidden)

class GaussianMixtureModel(nn.Module):
    """Learnable diagonal Gaussian mixture returning soft-assigned means.

    Args:
        in_channels: dimensionality of the inputs and of every kernel mean.
        n_kernels: number of mixture components.
    """

    def __init__(self, in_channels, n_kernels):
        super().__init__()
        self.n_kernels = n_kernels
        initial_mu = torch.randn(n_kernels, in_channels) / np.sqrt(n_kernels)
        self.mu = nn.Parameter(initial_mu)
        self.sigma = nn.Parameter(torch.ones(n_kernels, in_channels))
        # Uniform mixing weights to start with.
        self.p = nn.Parameter(torch.ones(n_kernels) / n_kernels)

    def forward(self, x):
        """Return, for each row of ``x`` ([N, C]), the responsibility-weighted mean."""
        centres = self.mu.unsqueeze(0)
        offset = x.unsqueeze(1) - centres
        # Squared distance scaled per dimension by sigma^2.
        scaled_sq = (offset ** 2 / (self.sigma.unsqueeze(0) ** 2)).sum(dim=-1)
        logits = torch.log(self.p) - 0.5 * scaled_sq
        responsibility = torch.softmax(logits, dim=-1)
        return (responsibility.unsqueeze(-1) * centres).sum(dim=1)

class GaussianDeformationField(nn.Module):
    """Predict a per-point motion vector and motion feature from Gaussians.

    Each Gaussian contributes its standardised mean, scale and feature plus
    an 8-wide encoding of the time offset. Two graph networks turn these
    into a per-Gaussian motion vector and motion feature, which are then
    max-pooled over the Gaussians and refined per point.

    Args:
        num_gaussians: number of Gaussians expected by ``forward``.
        num_features: width of the per-Gaussian feature vectors. The graph
            networks expect ``num_features + 6 + 8`` inputs (mean and scale
            contribute 6 values, the time encoding 8).
    """

    def __init__(self, num_gaussians, num_features):
        super(GaussianDeformationField, self).__init__()
        self.num_gaussians = num_gaussians
        self.num_features = num_features
        # self.gmm = GaussianMixtureModel(num_features, self.num_gaussians)

        self.motion_field_gcn = SimpleGraphConvolutionalNetwork(num_features+6+8, 3, 4)
        self.motion_features_gcn = GraphConvolutionalNetwork(num_features+6+8, num_features, 4)

        self.points_mlp = NeuralPCI_Layer(dim_in=6+8, 
                                          dim_out=3, 
                                          norm=None,
                                          act_fn=None  #'leaky_relu'
                                          )
        # self.feature_mlp = nn.Linear(num_features*2, num_features)
        self.feature_mlp = nn.Sequential(
            nn.Linear(num_features, num_features),
            NeuralPCI_Block(2, num_features, norm=None, act_fn='leakyrelu'),
            nn.Linear(num_features, num_features)
        )

        self.TimeExpansion = nn.Sequential(
            nn.Linear(1, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.Dropout(0.3)
        )

    @staticmethod
    def _standardize(values):
        """Zero-mean / unit-std scaling along dim 0 (epsilon guards std == 0)."""
        mean = values.mean(dim=0, keepdim=True)
        std = values.std(dim=0, keepdim=True)
        return (values - mean) / (std + 1e-6)

    def forward(self, means_t, covs_t, features_t, pc_current, diff_times):
        """Return ``(motion_field_output, motion_features_output)``.

        Args:
            means_t: per-Gaussian means; dim 0 indexes the Gaussians.
            covs_t: per-Gaussian covariance matrices; only the diagonal of
                the last two dims is used.
            features_t: per-Gaussian feature vectors.
            pc_current: point cloud; dim 0 is the point axis. Presumably
                ``[N, 3]`` since ``points_mlp`` expects 6 + 8 inputs
                (TODO confirm).
            diff_times: time offsets fed to ``Linear(1, 8)``, so the last
                dim is 1 - presumably ``[num_gaussians, 1]``.

        Returns:
            Output of ``points_mlp`` on the pooled motion/point/time vector
            and output of ``feature_mlp`` on the pooled motion features.
        """
        num_points = pc_current.shape[0]
        time_feat = self.TimeExpansion(diff_times)

        # Scale of each Gaussian = sqrt of the covariance diagonal.
        scales_t = torch.sqrt(covs_t.diagonal(dim1=-2, dim2=-1))

        input_features = torch.cat(
            [
                self._standardize(means_t),
                self._standardize(scales_t),
                self._standardize(features_t),
                time_feat,
            ],
            dim=-1,
        )

        # The graph networks take [channels, nodes]; results come back as
        # one row per Gaussian.
        node_inputs = input_features.transpose(1, 0)
        motion_field = self.motion_field_gcn(node_inputs, diff_times).transpose(1, 0)
        motion_features = self.motion_features_gcn(node_inputs, diff_times).transpose(1, 0)
        # motion_features = self.gmm(motion_features)

        # Broadcast everything to [1, num_gaussians, num_points, *].
        pc_current_exp = pc_current.unsqueeze(0).unsqueeze(1).expand(-1, self.num_gaussians, -1, -1)
        gaussian_field_proj = motion_field.unsqueeze(0).unsqueeze(2).expand(-1, -1, num_points, -1)
        gaussian_feat_proj = motion_features.unsqueeze(0).unsqueeze(2).expand(-1, -1, num_points, -1)
        time_feat_proj = time_feat.unsqueeze(0).unsqueeze(2).expand(-1, -1, num_points, -1)
        fused_field = torch.cat([gaussian_field_proj, pc_current_exp, time_feat_proj], dim=-1)

        # Max over the Gaussian axis. Only the leading singleton is dropped:
        # a bare squeeze() would also remove the point (or feature) axis
        # whenever it happens to have size one.
        pooled_feat = torch.max(gaussian_feat_proj, dim=1)[0].squeeze(0)
        pooled_field = torch.max(fused_field, dim=1)[0].squeeze(0)

        motion_field_output = self.points_mlp(pooled_field)
        motion_features_output = self.feature_mlp(pooled_feat)
        return motion_field_output, motion_features_output


class ST_GMM(nn.Module):
    """Time-conditioned transformation of a set of Gaussians.

    A bank of learnable parameters is stored per RBF centre and Gaussian:
    12 transform values (3 for translation, 9 for a 3x3 matrix) and a
    feature offset. They are blended with weights that combine normalised
    RBF activations of the query time with a softmax attention computed
    from the Gaussian features.

    Args:
        num_gaussians: number of Gaussians.
        num_rbf_centers: number of RBF centres, evenly spaced in [0, 1].
        feature_dim: width of the per-Gaussian feature vectors.
    """

    def __init__(self, num_gaussians, num_rbf_centers, feature_dim):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.num_rbf_centers = num_rbf_centers
        self.feature_dim = feature_dim

        # Fixed centres, learnable widths.
        self.register_buffer('rbf_centers', torch.linspace(0, 1, num_rbf_centers))
        self.rbf_sigmas = nn.Parameter(torch.ones(num_rbf_centers))

        # Per (centre, gaussian) parameter banks.
        self.gmm_params = nn.Parameter(torch.randn(num_rbf_centers, num_gaussians, 12))
        self.feature_params = nn.Parameter(torch.randn(num_rbf_centers, num_gaussians, feature_dim))
        self.scale_params = nn.Parameter(torch.ones(num_gaussians))

        self.attention_mlp = nn.Sequential(
            nn.Linear(feature_dim, num_rbf_centers),
            nn.Softmax(dim=-1),
        )

    def rbf(self, x, c, s):
        """Gaussian radial basis function of ``x`` with centre ``c`` and width ``s``."""
        sq_offset = (x - c) ** 2
        return torch.exp(-sq_offset / (2 * s ** 2))

    def forward(self, means, covs, features, time):
        """Return the transformed ``(means, covs, features)`` for ``time``.

        ``time`` broadcasts against the RBF centres; presumably it is
        ``[num_gaussians, 1]`` so the weights come out as
        ``[num_gaussians, num_rbf_centers]`` (TODO confirm).
        """
        # Normalised RBF activations, re-weighted by feature attention.
        weights = self.rbf(time, self.rbf_centers, self.rbf_sigmas)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        weights = weights * self.attention_mlp(features)

        # Blend the per-centre banks into one parameter set per Gaussian.
        blended = torch.einsum('gc,cgp->gp', weights, self.gmm_params)
        feature_shift = torch.einsum('gc,cgf->gf', weights, self.feature_params)

        shift = blended[:, :3]
        transform = blended[:, 3:12].view(self.num_gaussians, 3, 3)
        per_gaussian_scale = self.scale_params.view(self.num_gaussians, 1, 1)

        means_t = means + shift
        # T * cov * T^T, then a learnable scalar per Gaussian.
        covs_t = (transform @ covs @ transform.transpose(-1, -2)) * per_gaussian_scale
        features_t = features + feature_shift
        return means_t, covs_t, features_t

class AdaptiveAttentionFusion(nn.Module):
    """Cross-attention that enriches ``feature1`` with content from ``feature2``.

    Args:
        feature_dim: width of ``feature1`` and of the fused output.
        dim_4DGS: width of ``feature2``.
        hidden_dim: width of the query/key space; also the score scaling.
    """

    def __init__(self, feature_dim, dim_4DGS, hidden_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Queries come from feature1, keys and values from feature2.
        self.query_layer = nn.Linear(feature_dim, hidden_dim)
        self.key_layer = nn.Linear(dim_4DGS, hidden_dim)
        self.value_layer = nn.Linear(dim_4DGS, feature_dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature1, feature2):
        """Return ``feature1`` plus its attention-weighted view of ``feature2``."""
        q = self.query_layer(feature1)
        k = self.key_layer(feature2)
        v = self.value_layer(feature2)

        # Scaled dot-product scores, normalised over the feature2 rows.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.hidden_dim)
        attended = torch.matmul(self.softmax(scores), v)

        # Residual connection keeps the original feature1 content.
        return attended + feature1
    
class FourierFeatureEncoder(nn.Module):
    """Learnable Fourier feature encoding of each input coordinate.

    Every coordinate is multiplied by ``output_dim // 2`` learnable
    frequencies and mapped to the sines and cosines of ``2 * pi`` times
    the product.

    Args:
        input_dim: number of input coordinates.
        output_dim: number of features produced per coordinate (sines plus
            cosines); the flattened output has ``input_dim * output_dim``
            columns when ``output_dim`` is even.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.frequencies = nn.Parameter(torch.randn(input_dim, output_dim // 2))

    def forward(self, x):
        """Encode ``x`` ([N, input_dim]) and flatten everything after dim 0."""
        phase = 2 * torch.pi * (x.unsqueeze(-1) * self.frequencies)
        encoded = torch.cat([phase.sin(), phase.cos()], dim=-1)
        return encoded.flatten(start_dim=1)

class NeuralPCI(torch.nn.Module):
    """Point-cloud interpolation network with a Gaussian motion prior.

    The input cloud is clustered into ``args.n_gaussians`` Gaussians. Their
    means, covariances and learned features are moved to the query time by
    ``ST_GMM``, turned into a per-point motion estimate by
    ``GaussianDeformationField`` and fused with an MLP encoding of the
    points to predict the final output.

    ``args`` must provide: ``dim_pc``, ``dim_time``, ``layer_width``,
    ``act_fn``, ``norm``, ``depth_encode``, ``depth_pred``, ``pe_mul``,
    ``dim_4DGS``, ``NeuralField``, ``Att_Fusion``, ``use_rrf`` (plus
    ``dim_rrf`` when enabled), ``n_gaussians`` and ``zero_init``.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul
        dim_4DGS = args.dim_4DGS
        self.NeuralField = args.NeuralField
        self.Att_Fusion = args.Att_Fusion

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            # Random projection for the optional random Fourier features.
            # Registered as a buffer so it follows ``.to()`` / ``.cuda()``
            # instead of being pinned to CUDA at construction time.
            # NOTE(review): it is non-persistent so existing checkpoints keep
            # loading with strict=True. The projection is therefore NOT
            # stored in checkpoints and is re-drawn on every construction.
            self.register_buffer(
                'transform',
                0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf]),
                persistent=False,
            )
        else:
            dim_rrf = dim_pc
        self.args_n_gaussians = args.n_gaussians
        self.st_gmm = ST_GMM(self.args_n_gaussians, 2, dim_4DGS)  # Assume 2 frames: current and predicted
        self.gaussian_deformation_field = GaussianDeformationField(self.args_n_gaussians, dim_4DGS)

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in=(dim_rrf + dim_time) * pe_mul,
                                           dim_out=layer_width,
                                           norm=norm,
                                           act_fn=act_fn
                                           )

        # Width of what is concatenated to the encoded points in forward():
        # 3 values for the predicted point and the encoded query time, plus
        # the motion features unless they are merged by attention instead.
        fused_width = layer_width + 3 + dim_time * pe_mul
        if not self.Att_Fusion:
            fused_width += dim_4DGS

        if self.NeuralField:
            # hidden layers
            self.hidden_encode = NeuralPCI_Block(depth=depth_encode,
                                                 width=layer_width,
                                                 norm=norm,
                                                 act_fn=act_fn
                                                 )
            # insert interpolation time
            self.layer_time = NeuralPCI_Layer(dim_in=fused_width,
                                              dim_out=layer_width,
                                              norm=norm,
                                              act_fn=act_fn
                                              )
            # hidden layers
            self.hidden_pred = NeuralPCI_Block(depth=depth_pred,
                                               width=layer_width,
                                               norm=norm,
                                               act_fn=act_fn
                                               )
            self.layer_output = NeuralPCI_Layer(dim_in=layer_width,
                                                dim_out=dim_pc,
                                                norm=norm,
                                                act_fn=None
                                                )
        else:
            # output layer applied directly to the fused vector
            self.layer_output = NeuralPCI_Layer(dim_in=fused_width,
                                                dim_out=dim_pc,
                                                norm=norm,
                                                act_fn=None
                                                )
        if self.Att_Fusion:
            self.attention_Fu = AdaptiveAttentionFusion(layer_width, dim_4DGS, layer_width)
        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    m.weight.data.zero_()
                    m.bias.data.zero_()

        self.dgcnn = DGCNN(input_channels=dim_pc, output_channels=dim_4DGS)
        self.SelfAttention = SelfAttention(dim_4DGS)

        # Clustering cache filled by gaussian_point_cloud().
        self.cluster_centers = None
        self.point_assignment = None
        self.gaussians_mean = None
        self.gaussians_cov = None
        self.sorted_pc_current = None

    def posenc(self, x):
        """
        sinusoidal positional encoding : N -> 3 * N
        [x] -> [x, sin(x), cos(x)]
        """
        return torch.cat((x, torch.sin(x), torch.cos(x)), dim=1)

    # def posenc_pred(self, x, max_freq=3):
    #     sines = get_sines(x, max_freq)
    #     enc = torch.cat([x, sines], dim=-1)
    #     return enc

    def forward(self, pc_current, time_current, time_pred, train=True):
        """Predict the per-point output for ``pc_current`` at ``time_pred``.

        Args:
            pc_current: point cloud with one point per row.
            time_current: time stamp of ``pc_current`` (scalar-like).
            time_pred: time stamp to predict for (scalar-like).
            train: unused; kept for interface compatibility.

        Returns:
            Output of the final layer, reordered to match the row order of
            ``pc_current``.
        """
        # Follow the input's device instead of forcing CUDA.
        device = pc_current.device
        num_points = pc_current.shape[0]

        def as_column(value, rows):
            # Scalar-like time -> detached float32 column with ``rows`` rows.
            column = torch.as_tensor(value, dtype=torch.float32, device=device).detach()
            return column.repeat(rows, 1)

        time_pred_gs = as_column(time_pred, self.args_n_gaussians)
        time_current_gs = as_column(time_current, self.args_n_gaussians)
        time_current = as_column(time_current, num_points)
        time_pred = as_column(time_pred, num_points)

        gaussians_mean, gaussians_cov, sorted_pc_current, gaussians_feature = self.gaussian_point_cloud(pc_current, self.args_n_gaussians)
        if self.args.use_rrf:
            sorted_pc_current = torch.matmul(2. * torch.pi * sorted_pc_current, self.transform)

        # Encode the (cluster-sorted) points together with their time stamp.
        x = torch.cat((sorted_pc_current, time_current), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)
        if self.NeuralField:
            x = self.hidden_encode(x)
        time_pred = self.posenc(time_pred)

        # Gaussian branch: move the Gaussians to the query time and derive a
        # per-point motion estimate from them.
        means_t_pred, covs_t_pred, features_t_pred = self.st_gmm(gaussians_mean, gaussians_cov, gaussians_feature, time_pred_gs)
        motion_field, motion_features = self.gaussian_deformation_field(means_t_pred, covs_t_pred, features_t_pred, sorted_pc_current, time_pred_gs - time_current_gs)
        pc_pred = sorted_pc_current + motion_field

        if self.Att_Fusion:
            fused = self.attention_Fu(x, motion_features)
            x = torch.cat((fused, pc_pred, time_pred), dim=1)
        else:
            x = torch.cat((pc_pred, x, motion_features, time_pred), dim=1)
        if self.NeuralField:
            x = self.layer_time(x)
            x = self.hidden_pred(x)
        x = self.layer_output(x)

        # The network works in cluster-sorted order; map back to input order.
        flow = self.reorder_flow(sorted_pc_current, pc_current, x)
        return flow

    def gaussian_point_cloud(self, pc_current, num_gaussians, num_iters=800, epsilon=1e-5):
        """Return Gaussian means, covariances, the sorted cloud and features.

        The clustering is cached on the instance and only recomputed when
        ``pc_current`` differs from the cached cloud. The per-Gaussian
        features go through trainable modules and are recomputed every call.

        Args:
            pc_current: point cloud with one 3-D point per row.
            num_gaussians: number of clusters / Gaussians.
            num_iters: soft k-means iterations used when (re)clustering.
            epsilon: guard against division by zero and covariance
                regulariser.

        Returns:
            ``(gaussians_mean, gaussians_cov, sorted_pc_current,
            gaussians_feature)``.
        """
        if self.sorted_pc_current is None or not self.are_point_clouds_identical(pc_current, self.sorted_pc_current):
            self._fit_gaussians(pc_current, num_gaussians, num_iters, epsilon)

        gaussians_feature = []
        for i in range(num_gaussians):
            local_points = self.sorted_pc_current[self.point_assignment == i]
            local_points2 = self.find_nearest_points(local_points, self.sorted_pc_current)
            local_points2 = local_points2.transpose(1, 0).unsqueeze(0).unsqueeze(-1)  # [1, num_channels, num_points, 1]
            local_feature = self.dgcnn(local_points2)
            att_local_feature = self.SelfAttention(local_feature)
            gaussians_feature.append(att_local_feature)
        gaussians_feature = torch.cat(gaussians_feature, dim=0)
        return self.gaussians_mean, self.gaussians_cov, self.sorted_pc_current, gaussians_feature

    def _fit_gaussians(self, pc_current, num_gaussians, num_iters, epsilon):
        """Cluster ``pc_current`` with soft k-means and fill the instance cache."""
        num_points = pc_current.shape[0]
        centers = pc_current[torch.randperm(num_points)[:num_gaussians]]

        # Soft k-means: responsibilities are a softmax of the negative
        # squared distance; centres are the responsibility-weighted means.
        for _ in range(num_iters):
            distances = torch.cdist(pc_current, centers)
            probs = torch.softmax(-distances ** 2, dim=1)
            centers = torch.matmul(probs.t(), pc_current) / (probs.sum(dim=0).unsqueeze(1) + epsilon)
        _, assignment = torch.max(probs, dim=1)

        # Re-seed empty clusters by splitting the most populated one.
        counts = torch.bincount(assignment, minlength=num_gaussians)
        empty = (counts == 0).nonzero(as_tuple=True)[0]
        while len(empty) > 0:
            max_idx = counts.argmax()
            crowded = pc_current[assignment == max_idx]
            new_centers = crowded[torch.randperm(crowded.shape[0])[:2]]
            centers[empty[0]] = new_centers[0]
            centers[max_idx] = new_centers[1]

            distances = torch.cdist(pc_current, centers)
            probs = torch.softmax(-distances ** 2, dim=1)
            _, assignment = torch.max(probs, dim=1)
            counts = torch.bincount(assignment, minlength=num_gaussians)
            empty = (counts == 0).nonzero(as_tuple=True)[0]

        # Covariances are computed AFTER the re-seeding so they describe the
        # final assignment (they used to be computed before it and went
        # stale whenever a cluster had to be re-seeded).
        covs = []
        for i in range(num_gaussians):
            centered = pc_current[assignment == i] - centers[i]
            outer = torch.matmul(centered.unsqueeze(-1), centered.unsqueeze(-2))
            cov = torch.sum(outer, dim=0) / (centered.shape[0] + epsilon)
            cov = cov + torch.eye(3, device=pc_current.device) * epsilon
            covs.append(cov)

        self.cluster_centers = centers
        self.gaussians_mean = centers
        self.gaussians_cov = torch.stack(covs, dim=0)

        # Sort the points by cluster and keep the labels in the SAME order.
        # The labels used to stay in input order, so masks built from them
        # selected unrelated points of the sorted cloud.
        order = torch.argsort(assignment)
        self.sorted_pc_current = pc_current[order]
        self.point_assignment = assignment[order]

    def find_nearest_points(self, local_points, sorted_pc_current, k=8):
        """Pad very small clusters with neighbours taken from the whole cloud.

        Clusters with at least 5 points are returned unchanged. Otherwise
        the ``k`` nearest points of every cluster point are gathered
        (nearest first) and returned as one ``[-1, 3]`` tensor. ``k`` is
        capped at the cloud size so tiny clouds no longer make ``topk`` fail.
        """
        if local_points.shape[0] >= 5:
            return local_points

        k = min(k, sorted_pc_current.shape[0])
        diff = local_points.unsqueeze(1) - sorted_pc_current.unsqueeze(0)
        dist_sq = diff.pow(2).sum(-1)
        _, nearest_indices = dist_sq.topk(k, dim=1, largest=False)
        return sorted_pc_current[nearest_indices].view(-1, 3)

    def are_point_clouds_identical(self, pc1, pc2):
        """Order-insensitive comparison of two point clouds.

        Each coordinate column is sorted independently before comparing, so
        row permutations of the same cloud compare equal. Clouds of
        different shape are reported as different instead of raising.
        """
        if pc1.shape != pc2.shape:
            return False
        sorted_pc1 = pc1.sort(dim=0)[0]
        sorted_pc2 = pc2.sort(dim=0)[0]
        return torch.allclose(sorted_pc1, sorted_pc2)

    def reorder_flow(self, sorted_pc_current, pc_current, flow):
        """Reorder ``flow`` from sorted-cloud order back to input order.

        Every input point takes the flow row of its nearest point in the
        sorted cloud. This builds the full pairwise distance matrix.
        """
        distances = torch.cdist(pc_current, sorted_pc_current)
        _, indices = torch.min(distances, dim=1)
        return flow[indices]