import torch

@torch.compile()
def message_passing(X, edge_index):
    """
    Computes the (unweighted) message passing for a single graph,
    Y = AX
    where A is the adjacency matrix described by edge_index, i.e. Y[t] is
    the sum of X[s] over every edge (s, t). Repeated edges are accumulated
    and no normalization is applied inside this function.
    Args:
        X: torch.Tensor, shape (N, D)
        edge_index: torch.Tensor, shape (2, E); row 0 holds the source
            nodes and row 1 holds the target nodes
    Returns:
        Y: torch.Tensor, shape (N, D)
    """
    src = edge_index[0].long()
    dst = edge_index[1].long()

    # One scatter-add over all edges replaces the per-edge Python loop;
    # accumulate=True sums the contributions of edges that share a target.
    Y = torch.zeros_like(X)
    Y.index_put_((dst,), X[src], accumulate=True)

    return Y

# @torch.compile()
def batch_message_passing(X, edge_index, edge_weight):
    """
    Computes the weighted message passing for a batch of graphs,
    Y = AX
    where A is the (weighted) adjacency matrix described by edge_index and
    edge_weight, i.e. Y[b, t] is the sum of edge_weight[e] * X[b, s] over
    every edge e = (b, t, s). Repeated edges are accumulated.
    Args:
        X: torch.Tensor, shape (B, N, D)
        edge_index: torch.Tensor, shape (3, E); rows are
            (batch, target, source)
        edge_weight: torch.Tensor, shape (E,); cast to the dtype and
            device of X before use
    Returns:
        Y: torch.Tensor, shape (B, N, D)
    """
    index = edge_index.to(device=X.device).long()
    batch, target, source = index[0], index[1], index[2]
    weight = edge_weight.to(device=X.device, dtype=X.dtype)

    # Weighted messages for all edges at once: (E, D)
    msg = X[batch, source] * weight.unsqueeze(-1)

    # One scatter-add replaces the per-edge Python loop; accumulate=True
    # sums the messages of edges that share the same (batch, target) pair.
    Y = torch.zeros_like(X)
    Y.index_put_((batch, target), msg, accumulate=True)
    return Y



class ApproxSmoothingLayer(torch.nn.Module):
    """
    Approximate smoothing layer.

    Starting from g = f, runs num_steps fixed-point iterations
        g <- f / (1 + coef) + (coef / (1 + coef)) * A_mat g
    whose fixed point solves ((1 + coef) I - coef A_mat) g = f, i.e. the
    linear system that ExactSmoothingLayer solves directly.

    Args:
        alpha: either a float, in which case coef = 1/(1 - alpha) - 1
            (= alpha / (1 - alpha)), or the string 'trainable', in which
            case coef is a learnable scalar initialised to 1.0.
        num_steps: number of fixed-point iterations performed in forward.
    Raises:
        ValueError: if alpha is neither a float nor 'trainable'.
    """

    def __init__(self, alpha=0.5, num_steps=1) -> None:
        super().__init__()
        self.alpha = alpha
        self.num_steps = num_steps

        if isinstance(self.alpha, float):
            self.coef = (1.0/(1.0-self.alpha)-1)
        elif self.alpha == 'trainable':
            self.coef = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        else:
            raise ValueError("alpha must be float or 'trainable'")

    def forward(self, f, A_mat):
        """
        input:
            f: tensor (batch_size, bag_size, d_dim)
            A_mat: sparse coo tensor (batch_size, bag_size, bag_size)
        output:
            g: tensor (batch_size, bag_size, d_dim)
        """
        # Workaround kept from the original implementation, which reports
        # that torch.bmm fails if d_dim = 1: duplicate the single feature
        # column and drop the copy again at the end.
        recover_f = f.shape[2] == 1
        if recover_f:
            f = torch.cat([f, f], dim=2) # (batch_size, bag_size, 2)

        # Weights of the fixed-point update. With coef = alpha/(1-alpha):
        #   f_weight = 1/(1+coef)    = 1 - alpha  (keeps the input signal)
        #   A_weight = coef/(1+coef) = alpha      (propagated signal)
        # so alpha = 0 leaves f untouched, matching the exact solve.
        f_weight = 1.0 / (1.0 + self.coef)
        A_weight = 1.0 - f_weight

        g = f
        for _ in range(self.num_steps):
            g = f_weight*f + A_weight*torch.bmm(A_mat, g) # (batch_size, bag_size, d_dim)

        if recover_f:
            g = g[:, :, :1] # (batch_size, bag_size, 1)

        return g
    
    # def forward(self, f, lap_mat):
    #     """
    #     input:
    #         f: tensor (batch_size, bag_size, d_dim)
    #         lap_mat: sparse coo tensor (batch_size, bag_size, bag_size)
    #     output:
    #         g: tensor (batch_size, bag_size, d_dim)
    #     """

    #     g = f
    #     alpha = 1.0 / (1.0 + self.coef)
    #     lap_mat = lap_mat.coalesce()
    #     print(lap_mat.shape)
    #     for _ in range(self.num_steps):
    #         g = (1.0 - alpha)*f + alpha*batch_message_passing(g, lap_mat.indices(), lap_mat.values()) # (batch_size, bag_size, 1)
    #     return g

class ExactSmoothingLayer(torch.nn.Module):
    """
    Exact smoothing layer: returns the solution g of the linear system
        ((1 + coef) I - coef A_mat) g = f.

    alpha is either a float, giving coef = 1/(1 - alpha) - 1, or the
    string 'trainable', giving a learnable scalar coef initialised to 1.0.
    Any other value raises ValueError.
    """

    def __init__(self, alpha=0.5) -> None:
        super().__init__()
        self.alpha = alpha

        if isinstance(alpha, float):
            coef = 1.0 / (1.0 - alpha) - 1
        elif alpha == 'trainable':
            coef = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        else:
            raise ValueError("alpha must be float or 'trainable'")
        self.coef = coef

    def forward(self, f, A_mat):
        """
        input:
            f: tensor (batch_size, bag_size, d_dim)
            A_mat: tensor (batch_size, bag_size, bag_size)
        output:
            g: tensor (batch_size, bag_size, d_dim)
        """
        n_bags, n_instances = f.shape[0], f.shape[1]

        # One identity matrix per bag: (batch_size, bag_size, bag_size)
        eye = torch.eye(n_instances, device=A_mat.device)
        eye_batch = eye.repeat(n_bags, 1, 1)

        # System matrix of the smoothing problem
        system = (1 + self.coef) * eye_batch - self.coef * A_mat
        return self._solve_system(system, f)

    def _solve_system(self, A, b):
        """
        Solves the batched linear system A x = b.
        input:
            A: tensor (batch_size, bag_size, bag_size)
            b: tensor (batch_size, bag_size, d_dim)
        output:
            x: tensor (batch_size, bag_size, d_dim)
        """
        return torch.linalg.solve(A, b)
