import torch
import torch.nn as nn
import numpy as np
import pdb, time, math
import functools
from model.utils import get_sines

def get_activation(activation):
    """Build a fresh activation module from its case-insensitive name.

    Args:
        activation: one of ``relu``, ``leakyrelu`` (alias ``leaky_relu``),
            ``sigmoid``, ``softplus``, ``gelu``, ``selu`` or ``mish``.

    Returns:
        A new ``nn.Module`` instance on every call. ReLU, LeakyReLU, SELU and
        Mish are created with ``inplace=True``.

    Raises:
        ValueError: if the name is not recognised (a subclass of
            ``Exception``, so existing ``except Exception`` handlers still work).
    """
    # Factories rather than instances so each caller gets its own module.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leakyrelu': lambda: nn.LeakyReLU(inplace=True),
        'leaky_relu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': lambda: nn.Sigmoid(),
        'softplus': lambda: nn.Softplus(),
        'gelu': lambda: nn.GELU(),
        'selu': lambda: nn.SELU(inplace=True),
        'mish': lambda: nn.Mish(inplace=True),
    }
    try:
        factory = factories[activation.lower()]
    except KeyError:
        raise ValueError("Activation Function Error") from None
    return factory()


def get_norm(norm, width, num_groups=1):
    """Build a normalization layer for ``width`` channels.

    Args:
        norm: ``'LN'`` (LayerNorm), ``'BN'`` (BatchNorm1d), ``'IN'``
            (InstanceNorm1d) or ``'GN'`` (GroupNorm).
        width: number of channels / features being normalized.
        num_groups: number of groups for ``'GN'`` only; must divide ``width``.
            Defaults to 1 (a single group spanning all channels).

    Returns:
        The corresponding ``nn.Module``.

    Raises:
        ValueError: if ``norm`` is not one of the supported codes.

    NOTE(review): InstanceNorm1d expects (N, C, L) or unbatched (C, L) input;
    applied directly after an ``nn.Linear`` on (N, width) data it treats the
    batch axis as channels -- confirm this is intended when using 'IN'.
    """
    if norm == 'LN':
        return nn.LayerNorm(width)
    if norm == 'BN':
        return nn.BatchNorm1d(width)
    if norm == 'IN':
        return nn.InstanceNorm1d(width)
    if norm == 'GN':
        # GroupNorm requires (num_groups, num_channels); the previous
        # ``nn.GroupNorm(width)`` call always raised a TypeError.
        return nn.GroupNorm(num_groups, width)
    raise ValueError("Normalization Layer Error")


class NeuralPCI_Layer(torch.nn.Module):
    """A single fully connected layer with optional normalization and activation.

    The stack is ``Linear(dim_in, dim_out)``, then ``get_norm(norm, dim_out)``
    when ``norm`` is truthy, then ``get_activation(act_fn)`` when ``act_fn`` is
    truthy, all held in one ``nn.Sequential`` named ``layer``.

    Args:
        dim_in: input feature width.
        dim_out: output feature width.
        norm: normalization code forwarded to ``get_norm`` (or falsy for none).
        act_fn: activation name forwarded to ``get_activation`` (or falsy for none).
    """

    def __init__(self, 
                 dim_in,
                 dim_out,
                 norm=None, 
                 act_fn=None
                 ):
        super().__init__()
        stages = [nn.Linear(dim_in, dim_out)]
        if norm:
            stages += [get_norm(norm, dim_out)]
        if act_fn:
            stages += [get_activation(act_fn)]
        # The attribute name ``layer`` and the stage order are relied on by
        # code that iterates ``.layer`` (e.g. zero-initialising the Linear) and
        # by saved state_dict keys.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the Linear -> [norm] -> [activation] stack to ``x``."""
        return self.layer(x)


class NeuralPCI_Block(torch.nn.Module):
    """``depth`` repetitions of a width-preserving fully connected stage.

    Each stage is ``Linear(width, width)`` followed by an optional
    ``get_norm(norm, width)`` and an optional ``get_activation(act_fn)``. All
    stages are flattened into a single ``nn.Sequential`` named ``mlp``. With
    ``depth == 0`` the block is the identity.

    Args:
        depth: number of stages.
        width: feature width (input and output).
        norm: normalization code for ``get_norm`` (or falsy for none).
        act_fn: activation name for ``get_activation`` (or falsy for none).
    """

    def __init__(self, 
                 depth, 
                 width,
                 norm=None, 
                 act_fn=None
                 ):
        super().__init__()
        modules = []
        for _ in range(depth):
            modules.extend(self._stage(width, norm, act_fn))
        self.mlp = nn.Sequential(*modules)

    @staticmethod
    def _stage(width, norm, act_fn):
        """Return the modules of one Linear -> [norm] -> [activation] stage."""
        stage = [nn.Linear(width, width)]
        if norm:
            stage.append(get_norm(norm, width))
        if act_fn:
            stage.append(get_activation(act_fn))
        return stage

    def forward(self, x):
        """Run ``x`` through every stage in order."""
        return self.mlp(x)

class EdgeConv(nn.Module):
    """Max-pooled pairwise edge convolution over a fully connected node set.

    For an unbatched input ``x`` of shape ``(C, M)`` (C channels, M nodes) the
    output is ``out[:, j] = max_i conv([x[:, i]; x[:, j]])``, of shape
    ``(out_channels, M)``. Here ``conv`` is a 1x1 convolution over the
    concatenated pair, and the max runs over all M nodes ``i``.

    Args:
        in_channels: C, the channel count of ``x``.
        out_channels: channel count of the result.
    """

    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        # Kept as a Conv1d so parameter names/shapes match existing checkpoints.
        self.conv = nn.Conv1d(in_channels * 2, out_channels, kernel_size=1)
        
    def forward(self, x):
        """Compute the edge features for ``x`` of shape (C, M).

        The 1x1 convolution is affine: ``W_i x_i + W_j x_j + b``. The second
        and third terms do not depend on ``i``, so
        ``max_i(...) = max_i(W_i x_i) + W_j x_j + b``. This gives the same
        result as materialising all M*M pairs, but in O(M*C) memory instead
        of O(M^2*C).
        """
        C, M = x.shape  # enforce an unbatched 2-D (channels, nodes) input
        half = self.conv.in_channels // 2
        weight = self.conv.weight.squeeze(-1)  # (out_channels, 2 * in_channels)
        # The first half of the input channels sees x_i (the pooled node), the second half x_j.
        w_pooled, w_center = weight[:, :half], weight[:, half:]
        pooled = torch.matmul(w_pooled, x).max(dim=1, keepdim=True)[0]  # (out, 1)
        center = torch.matmul(w_center, x)  # (out, M); raises if C != in_channels
        x_edge = pooled + center
        if self.conv.bias is not None:
            x_edge = x_edge + self.conv.bias.unsqueeze(-1)
        return x_edge


class GraphConvolutionalNetwork(nn.Module):
    """Two ReLU-activated ``EdgeConv`` layers followed by a pointwise Conv1d head.

    Args:
        input_dim: channel count of the input; both EdgeConv layers preserve it.
        output_dim: channel count produced by the final 1x1 convolution.
    """

    def __init__(self, input_dim, output_dim):
        super(GraphConvolutionalNetwork, self).__init__()
        self.conv1 = EdgeConv(input_dim, input_dim)
        self.conv2 = EdgeConv(input_dim, input_dim)
        self.conv3 = nn.Conv1d(input_dim, output_dim, kernel_size=1)
        
    def forward(self, x):
        """Map node features ``x`` (unpacked as (C, M) by EdgeConv) to (output_dim, M)."""
        hidden = x
        for edge_layer in (self.conv1, self.conv2):
            hidden = torch.relu(edge_layer(hidden))
        return self.conv3(hidden)

class GaussianMixtureModel(nn.Module):
    """Soft-assign inputs to learnable diagonal Gaussians and return the mean of the mixture.

    For each row of ``x`` the responsibilities are
    ``softmax_k(log p_k - 0.5 * sum_c (x_c - mu_kc)^2 / sigma_kc^2)``. The
    output is the responsibility-weighted average of the centres ``mu``.

    Args:
        in_channels: feature width of ``x`` and of each centre.
        n_kernels: number of mixture components K.
    """

    def __init__(self, in_channels, n_kernels):
        super().__init__()
        self.n_kernels = n_kernels
        self.mu = nn.Parameter(torch.randn(n_kernels, in_channels) / np.sqrt(n_kernels))
        self.sigma = nn.Parameter(torch.ones(n_kernels, in_channels))
        self.p = nn.Parameter(torch.ones(n_kernels) / n_kernels)

    def forward(self, x):
        """Return the responsibility-weighted centre for each row of ``x`` (N, in_channels) -> (N, in_channels)."""
        diff = x.unsqueeze(1) - self.mu.unsqueeze(0)
        z = torch.sum(diff ** 2 / (self.sigma.unsqueeze(0) ** 2), dim=-1)
        # ``p`` is an unconstrained parameter: once an entry reaches <= 0 during
        # training, log() would give -inf/NaN and poison the softmax. Clamping
        # at the smallest positive normal value leaves every positive prior
        # unchanged and gives non-positive priors ~zero weight instead.
        log_prior = torch.log(self.p.clamp_min(torch.finfo(self.p.dtype).tiny))
        resp = torch.softmax(log_prior - 0.5 * z, dim=-1)
        return torch.sum(resp.unsqueeze(-1) * self.mu.unsqueeze(0), dim=1)

class GaussianDeformationField(nn.Module):
    """Turn a set of Gaussians into per-point motion and feature outputs.

    Two graph networks read ``[mean, per-axis scale, feature]`` for every
    Gaussian. One produces a 3-vector (``motion_field``), the other a
    ``num_features``-vector (``motion_features``). Each is max-pooled over all
    Gaussians, concatenated with every point's own coordinates / features, and
    decoded by ``points_mlp`` / ``feature_mlp``.

    NOTE(review): the pooling runs over *all* Gaussians, so every point
    receives the same Gaussian summary. Per-point variation comes only from
    ``pc_current`` and ``x_feat``. This matches the original formulation.

    Args:
        num_gaussians: number of Gaussians G (stored as an attribute).
        num_features: feature width F of the Gaussians and of ``x_feat``.
    """

    def __init__(self, num_gaussians, num_features):
        super(GaussianDeformationField, self).__init__()
        self.num_gaussians = num_gaussians
        self.num_features = num_features

        # Per-Gaussian input = F feature channels + 3 mean + 3 scale channels.
        self.motion_field_gcn = GraphConvolutionalNetwork(num_features+6, 3)
        self.motion_features_gcn = GraphConvolutionalNetwork(num_features+6, num_features)

        # [pooled 3-D field, point xyz] -> 3-D output.
        self.points_mlp = NeuralPCI_Layer(dim_in=6, 
                                          dim_out=3, 
                                          norm=None,
                                          act_fn=None
                                          )
        # [pooled Gaussian feature, point feature] -> F.
        self.feature_mlp = nn.Sequential(
            nn.Linear(num_features*2, num_features),
            NeuralPCI_Block(3, num_features, norm=None, act_fn='leakyrelu'),
            nn.Linear(num_features, num_features)
        )

    def forward(self, means_t, covs_t, features_t, pc_current, x_feat):
        """Compute per-point motion field and motion features.

        Args:
            means_t: (G, 3) Gaussian centres.
            covs_t: (G, 3, 3) Gaussian covariances; only the diagonal is read.
            features_t: (G, F) Gaussian features.
            pc_current: (N, 3) point coordinates.
            x_feat: (N, F) per-point features.

        Returns:
            ``(motion_field_output, motion_features_output)`` of shapes
            (N, 3) and (N, F). The point axis is kept even when N == 1.
        """
        num_points = pc_current.shape[0]

        # Per-axis standard deviations of each Gaussian.
        # NOTE(review): a negative diagonal entry gives NaN here -- confirm the
        # upstream covariance scaling cannot become negative.
        scales_t = torch.sqrt(covs_t.diagonal(dim1=-2, dim2=-1))

        # (F + 6, G): channels-first layout for the graph networks.
        gaussian_tokens = torch.cat([means_t, scales_t, features_t], dim=-1).transpose(1, 0)
        motion_field = self.motion_field_gcn(gaussian_tokens).transpose(1, 0)        # (G, 3)
        motion_features = self.motion_features_gcn(gaussian_tokens).transpose(1, 0)  # (G, F)

        # max over Gaussians of [gaussian_part, point_part] equals
        # [max over Gaussians of gaussian_part, point_part], because the point
        # part does not depend on the Gaussian. Pool first instead of building
        # (G, N, .) concatenations, and expand to exactly N rows so N == 1 keeps
        # its batch axis (a bare ``.squeeze()`` would drop it).
        field_summary = motion_field.max(dim=0)[0].unsqueeze(0).expand(num_points, -1)
        feat_summary = motion_features.max(dim=0)[0].unsqueeze(0).expand(num_points, -1)
        pooled_field = torch.cat([field_summary, pc_current], dim=-1)   # (N, 6)
        pooled_feat = torch.cat([feat_summary, x_feat], dim=-1)         # (N, 2F)

        motion_field_output = self.points_mlp(pooled_field)
        motion_features_output = self.feature_mlp(pooled_feat)

        return motion_field_output, motion_features_output


class ST_GMM(nn.Module):
    """Time-conditioned per-Gaussian translation, linear map and feature offset.

    The query time is embedded with ``num_rbf_centers`` Gaussian radial basis
    functions. Their centres are fixed on ``linspace(0, 1)`` and their widths
    are learnable. After normalising over centres, the RBF weights are
    multiplied by a softmax attention computed from each Gaussian's feature.
    The result blends per-centre parameter banks into a translation, a 3x3
    matrix and a feature offset for every Gaussian.

    NOTE(review): the 3x3 matrix is unconstrained (not forced to be a
    rotation), and ``scale_params`` is an unconstrained learnable gain.

    Args:
        num_gaussians: number of Gaussians G handled per call.
        num_rbf_centers: number of temporal RBF centres R.
        feature_dim: per-Gaussian feature width F.
    """

    def __init__(self, num_gaussians, num_rbf_centers, feature_dim):
        super(ST_GMM, self).__init__()
        self.num_gaussians = num_gaussians
        self.num_rbf_centers = num_rbf_centers
        self.feature_dim = feature_dim
        self.register_buffer('rbf_centers', torch.linspace(0, 1, num_rbf_centers))
        self.rbf_sigmas = nn.Parameter(torch.ones(num_rbf_centers))
        # 12 = 3 translation components + 9 entries of a 3x3 matrix.
        self.gmm_params = nn.Parameter(torch.randn(num_rbf_centers, num_gaussians, 12))
        self.feature_params = nn.Parameter(torch.randn(num_rbf_centers, num_gaussians, feature_dim))
        self.scale_params = nn.Parameter(torch.ones(num_gaussians))

        self.attention_mlp = nn.Sequential(
            nn.Linear(feature_dim, num_rbf_centers),
            nn.Softmax(dim=-1),
        )

    def rbf(self, x, c, s):
        """Unnormalised Gaussian kernel ``exp(-(x - c)^2 / (2 s^2))`` with broadcasting."""
        sq_dist = (x - c) ** 2
        return torch.exp(-sq_dist / (2 * s ** 2))

    def forward(self, means, covs, features, time):
        """Advance every Gaussian to ``time``.

        Args:
            means: (G, 3) centres.
            covs: (G, 3, 3) covariances.
            features: (G, F) features.
            time: query time, broadcast against the R RBF centres; presumably
                a (G, 1) column -- confirm against the caller.

        Returns:
            ``(means_t, covs_t, features_t)`` where ``means_t = means + t``,
            ``covs_t = A @ covs @ A^T * scale`` and
            ``features_t = features + offset``.
        """
        temporal = self.rbf(time, self.rbf_centers, self.rbf_sigmas)
        temporal = temporal / temporal.sum(dim=-1, keepdim=True)
        blend = temporal * self.attention_mlp(features)  # (G, R)

        params = torch.einsum('gc,cgp->gp', blend, self.gmm_params)
        feature_offset = torch.einsum('gc,cgf->gf', blend, self.feature_params)

        translation = params[:, :3]
        linear_map = params[:, 3:12].view(self.num_gaussians, 3, 3)
        gain = self.scale_params.view(self.num_gaussians, 1, 1)

        means_t = means + translation
        covs_t = (linear_map @ covs @ linear_map.transpose(-1, -2)) * gain
        features_t = features + feature_offset
        return means_t, covs_t, features_t

class AdaptiveAttentionFusion(nn.Module):
    """Residual scaled dot-product cross-attention from ``feature1`` to ``feature2``.

    Queries are projected from ``feature1``, keys and values from
    ``feature2``. The attended values are added back onto ``feature1``.

    Args:
        feature_dim: width of both inputs and of the output.
        hidden_dim: width of the query/key projections. The scores are
            divided by ``sqrt(hidden_dim)``.
    """

    def __init__(self, feature_dim, hidden_dim):
        super(AdaptiveAttentionFusion, self).__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        self.query_layer = nn.Linear(feature_dim, hidden_dim)
        self.key_layer = nn.Linear(feature_dim, hidden_dim)
        self.value_layer = nn.Linear(feature_dim, feature_dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature1, feature2):
        """Return ``feature1 + softmax(q k^T / sqrt(hidden_dim)) v``.

        NOTE: the score matrix is (rows of feature1) x (rows of feature2). For
        point-wise inputs this is quadratic in the number of points.
        """
        q = self.query_layer(feature1)
        k = self.key_layer(feature2)
        v = self.value_layer(feature2)

        scale = math.sqrt(self.hidden_dim)
        weights = self.softmax(q @ k.transpose(-2, -1) / scale)
        return weights @ v + feature1
    
class FourierFeatureEncoder(nn.Module):
    """Learnable Fourier-feature embedding applied to each input coordinate.

    Each of the ``input_dim`` coordinates is multiplied by its own row of
    ``output_dim // 2`` learnable frequencies. The sines and cosines of
    ``2*pi*x*f`` are concatenated per coordinate and then flattened from
    dim 1.

    NOTE: for a 2-D input the returned width is
    ``input_dim * 2 * (output_dim // 2)``. ``output_dim`` is therefore the
    width *per coordinate* (rounded down to an even number), not the total.

    Args:
        input_dim: size D of the last axis of ``x``.
        output_dim: sin+cos features per coordinate.
    """

    def __init__(self, input_dim, output_dim):
        super(FourierFeatureEncoder, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.frequencies = nn.Parameter(torch.randn(input_dim, output_dim // 2))

    def forward(self, x):
        """Embed ``x`` of shape (B, D) into (B, D * 2 * (output_dim // 2))."""
        phase = 2 * torch.pi * (x.unsqueeze(-1) * self.frequencies)
        per_coord = torch.cat((phase.sin(), phase.cos()), dim=-1)
        return per_coord.flatten(start_dim=1)

class NeuralPCI(torch.nn.Module):
    """Neural point-cloud interpolation field with a Gaussian motion branch.

    Pipeline:
    1. Encode ``(point, time_current)`` with a positional encoding and an MLP.
    2. Summarise the cloud as ``args.n_gaussians`` soft-clustered Gaussians.
    3. Advance the Gaussians to ``time_pred`` with ``ST_GMM``.
    4. Turn them into per-point motion with ``GaussianDeformationField``.
    5. Fuse the motion features into the encoding by attention.
    6. Decode ``[fused feature, displaced point, encoded time_pred]`` into a
       ``dim_pc``-dimensional output per point.

    Args:
        args: configuration namespace. Fields read here: ``dim_pc``,
            ``dim_time``, ``layer_width``, ``act_fn``, ``norm``,
            ``depth_encode``, ``depth_pred``, ``pe_mul``, ``use_rrf``,
            ``dim_rrf`` (only when ``use_rrf``), ``n_gaussians`` and
            ``zero_init``. ``pe_mul`` presumably equals 3 to match ``posenc``
            -- confirm against the config.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width 
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            # Random Fourier projection of the input coordinates. Registered as
            # a buffer so ``.to()`` / ``.cuda()`` move it with the model
            # instead of hard-coding CUDA. ``persistent=False`` keeps the
            # state_dict layout of existing checkpoints unchanged.
            # NOTE(review): because it is not saved, a re-instantiated model
            # draws a new projection; seed the RNG before construction when
            # reloading weights.
            self.register_buffer(
                'transform',
                0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf]),
                persistent=False,
            )
        else:
            dim_rrf = dim_pc
        self.args_n_gaussians = args.n_gaussians
        # 7 temporal RBF centres spread over [0, 1].
        self.st_gmm = ST_GMM(self.args_n_gaussians, 7, layer_width)
        self.gaussian_deformation_field = GaussianDeformationField(self.args_n_gaussians, layer_width)

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in=(dim_rrf + dim_time) * pe_mul, 
                                           dim_out=layer_width, 
                                           norm=norm,
                                           act_fn=act_fn
                                           )
        # hidden layers
        self.hidden_encode = NeuralPCI_Block(depth=depth_encode, 
                                             width=layer_width, 
                                             norm=norm,
                                             act_fn=act_fn
                                             )

        # insert interpolation time; the "+3" is the displaced point ``pc_pred``
        self.layer_time = NeuralPCI_Layer(dim_in=layer_width+3 + dim_time * pe_mul, 
                                          dim_out=layer_width, 
                                          norm=norm,
                                          act_fn=act_fn
                                          )

        # hidden layers
        self.hidden_pred = NeuralPCI_Block(depth=depth_pred, 
                                           width=layer_width, 
                                           norm=norm,
                                           act_fn=act_fn
                                           )

        # output layer
        self.layer_output = NeuralPCI_Layer(dim_in=layer_width, 
                                            dim_out=dim_pc, 
                                            norm=norm,
                                            act_fn=None
                                            )
        
        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    m.weight.data.zero_()
                    m.bias.data.zero_()

        self.attention_Fu = AdaptiveAttentionFusion(layer_width, layer_width)
    
    def posenc(self, x):
        """Sinusoidal positional encoding along dim 1: width D -> 3 * D.

        ``[x] -> [x, sin(x), cos(x)]``
        """
        sinx = torch.sin(x)
        cosx = torch.cos(x)
        x = torch.cat((x, sinx, cosx), dim=1)
        return x

    @staticmethod
    def _time_column(value, rows, device):
        """Broadcast a scalar time stamp to a detached (rows, 1) float32 tensor on ``device``."""
        return torch.as_tensor(value, dtype=torch.float32, device=device).repeat(rows, 1).detach()

    def forward(self, pc_current, time_current, time_pred, train=True):
        """Predict the per-point output at ``time_pred``.

        Args:
            pc_current: (N, dim_pc) point coordinates.
            time_current: time stamp of ``pc_current``; presumably a Python
                float or 0-d tensor -- confirm against the caller.
            time_pred: query time stamp, same convention as ``time_current``.
            train: unused; kept for interface compatibility.

        Returns:
            (N, dim_pc) tensor from ``layer_output`` (presumably the points
            interpolated to ``time_pred`` -- confirm against the loss).

        Raises:
            ValueError: if the cloud has fewer points than ``args.n_gaussians``.
                The clustering samples initial centres without replacement, so
                it could not produce that many Gaussians.
        """
        device = pc_current.device
        num_points = pc_current.shape[0]
        if num_points < self.args_n_gaussians:
            raise ValueError(
                f"point cloud has {num_points} points but n_gaussians={self.args_n_gaussians}"
            )

        # Time columns follow the input's device instead of hard-coding CUDA.
        time_pred_gs = self._time_column(time_pred, self.args_n_gaussians, device)
        time_current = self._time_column(time_current, num_points, device)
        time_pred = self._time_column(time_pred, num_points, device)

        # The random Fourier projection only feeds the MLP encoder. The
        # Gaussian branch and the displaced points work in raw xyz. Its
        # ``torch.eye(3)``, ``points_mlp(dim_in=6)`` and ``layer_time``'s "+3"
        # all assume 3-D coordinates.
        pc_input = pc_current
        if self.args.use_rrf:
            pc_input = torch.matmul(2. * torch.pi * pc_current, self.transform)

        x = torch.cat((pc_input, time_current), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)
        x = self.hidden_encode(x)
        time_pred = self.posenc(time_pred)

        means, covs, features = self.gaussian_point_cloud(pc_current, x, self.args_n_gaussians)
        means_t_pred, covs_t_pred, features_t_pred = self.st_gmm(means, covs, features, time_pred_gs)
        motion_field, motion_features = self.gaussian_deformation_field(
            means_t_pred, covs_t_pred, features_t_pred, pc_current, x)
        pc_pred = pc_current + motion_field

        fused = self.attention_Fu(x, motion_features)
        x = torch.cat((fused, pc_pred, time_pred), dim=1)

        x = self.layer_time(x)
        x = self.hidden_pred(x)
        return self.layer_output(x)
    
    def gaussian_point_cloud(self, pc_current, features, num_gaussians, num_iters=10, epsilon=1e-5):
        """Summarise a point cloud as soft-assigned Gaussians.

        The method runs ``num_iters`` soft k-means updates from randomly
        chosen points. Responsibilities are a softmax over negative squared
        distances, i.e. a fixed unit bandwidth. It then computes each
        cluster's responsibility-weighted covariance and feature average
        from the same responsibilities.

        Args:
            pc_current: (N, 3) points. 3-D is required by the ``torch.eye(3)``
                regulariser.
            features: (N, C) per-point features.
            num_gaussians: requested cluster count. At most N clusters are
                returned, because initial centres are sampled without
                replacement.
            num_iters: number of soft k-means updates (must be >= 1).
            epsilon: added to denominators and to the covariance diagonal.

        Returns:
            ``(means (M, 3), covs (M, 3, 3), features (M, C))`` with
            ``M = min(num_gaussians, N)``.

        Raises:
            ValueError: if ``num_iters < 1`` (no responsibilities would exist).
        """
        if num_iters < 1:
            raise ValueError("num_iters must be >= 1")
        num_points, _ = pc_current.shape

        centers = pc_current[torch.randperm(num_points)[:num_gaussians]]
        for _ in range(num_iters):
            distances = torch.cdist(pc_current, centers)
            probs = torch.softmax(-distances**2, dim=1)  # (N, M) responsibilities
            centers = torch.matmul(probs.t(), pc_current) / (probs.sum(dim=0).unsqueeze(1) + epsilon)

        mass = probs.sum(dim=0)  # (M,) total responsibility per cluster
        diff = pc_current.unsqueeze(1) - centers.unsqueeze(0)  # (N, M, 3)
        outer = torch.matmul(diff.unsqueeze(-1), diff.unsqueeze(-2))  # (N, M, 3, 3)
        covs = torch.sum(outer * probs.unsqueeze(-1).unsqueeze(-1), dim=0) / (mass.unsqueeze(-1).unsqueeze(-1) + epsilon)
        # Small diagonal jitter keeps every covariance positive definite.
        covs = covs + torch.eye(3, device=pc_current.device) * epsilon

        gaussian_features = torch.matmul(probs.t(), features) / (mass.unsqueeze(1) + epsilon)  # (M, C)

        return centers, covs, gaussian_features
        
        