"""
This file contains the implementation of graph sparsification.
Reference:
[1] Spielman D A, Srivastava N. Graph sparsification by effective resistances[C]//Proceedings of the fortieth annual ACM symposium on Theory of computing. 2008: 563-568.
[2] Spielman D A, Teng S H. Nearly linear time algorithms for preconditioning and solving symmetric, diagonally dominant linear systems[J]. SIAM Journal on Matrix Analysis and Applications, 2014, 35(3): 835-885.
[3] Approximate Gaussian Elimination for Laplacians: Fast, Sparse, and Simple
[4] Improved Spectral Sparsification and Numerical Algorithms for SDD Matrices
"""
import numpy as np
import scipy
import scipy.sparse as sp
import torch
import dgl


def _check_nan(A: np.ndarray):
    """
    Check if there is any NaN in the matrix A.
    """
    if np.isnan(A).any():
        return True
    return False


def pinv(L: np.ndarray, rcond=None):
    """
    Compute the Moore-Penrose pseudo-inverse of a matrix L via SVD.

    Singular values at or below ``rcond * max(S)`` are treated as exactly zero
    instead of being inverted. This matters for graph Laplacians, which are
    always singular: inverting their zero singular value would produce inf/NaN.

    Args:
    L: np.ndarray, matrix of shape (m, n).
    rcond: float or None, relative cutoff for small singular values.
        Defaults to ``max(m, n) * eps`` of the singular-value dtype.

    Returns:
    np.ndarray of shape (n, m), the pseudo-inverse of L.
    """
    L = np.asarray(L)
    m, n = L.shape
    if m == 0 or n == 0:
        # The pseudo-inverse of an empty matrix is the empty transpose.
        return np.zeros((n, m))
    # Thin SVD: U is (m, k), S is (k,), VH is (k, n) with k = min(m, n), so the
    # reconstruction below is well-defined for non-square matrices too.
    U, S, VH = scipy.linalg.svd(L, full_matrices=False, lapack_driver='gesvd')
    if rcond is None:
        rcond = max(m, n) * np.finfo(S.dtype).eps
    keep = S > rcond * S.max()
    S_inv = np.zeros_like(S)
    S_inv[keep] = 1.0 / S[keep]
    # V @ diag(S_inv) @ U^H, using conjugate transposes so complex input works.
    return (VH.conj().T * S_inv) @ U.conj().T


def effective_resistance(L: np.ndarray):
    """
    Compute all-pairs effective resistances of a graph with Laplacian matrix L.

    Uses R[i, j] = Lp[i, i] + Lp[j, j] - 2 * Lp[i, j], where Lp is the
    Moore-Penrose pseudo-inverse of L. The current version is brute force
    (dense pseudo-inverse of the full Laplacian).

    Args:
    L: np.ndarray, square (n, n) Laplacian matrix.

    Returns:
    np.ndarray of shape (n, n) with the effective resistance of every node pair.
    """
    Lp = np.linalg.pinv(L)
    print('Pinv of Lap matrix has nan:', _check_nan(Lp))
    # Vectorized form of R[i, j] = Lp[i, i] + Lp[j, j] - 2 * Lp[i, j].
    # L is symmetric, so Lp is symmetric as well.
    d = Lp.diagonal()
    R = d[:, None] + d[None, :] - 2 * Lp
    return R


def _sparsify_fully_connnected_graph(A: np.ndarray, q:int, load_er=None):
    """
    Sparsify a fully connected graph with adjacency matrix A by sampling q edges.

    This function is based on the paper [1]. Each candidate edge (i, j) with i < j
    is drawn with probability p_ij proportional to A[i, j] * R[i, j] (weight times
    effective resistance). There are q draws with replacement, and each draw
    contributes weight A[i, j] / (q * p_ij). Repeated draws of the same edge are
    summed, so at most q distinct edges are returned.

    We use numpy instead of torch because of the bugs in `torch.pinv`. The time
    complexity of computing the effective resistances is O(n^3), where n is the
    number of nodes in the graph.

    Args:
    A: np.ndarray, (n, n) adjacency matrix of the fully connected graph.
    q: int, the number of edge samples to draw.
    load_er: optional cache key. If given, the effective-resistance matrix is
        loaded from `./cache/sparsified_graph/{load_er}_sim_cosine_ER.npz`
        (array 'ER') instead of being computed. It is also used as the file
        prefix for the saved sampling probabilities.

    Returns:
    (final_weight, row_idx, col_idx): the weights and endpoints of the coalesced
    sampled edges. Every edge satisfies row_idx < col_idx, and edges are ordered
    by row, then by column.

    Raises:
    ValueError: if no candidate edge has a positive, finite sampling weight.
    """
    import os

    assert A.shape[0] == A.shape[1]
    n = A.shape[0]
    cache_dir = './cache/sparsified_graph'
    L = np.diag(A.sum(axis=1)) - A
    print('Lap matrix has nan:', _check_nan(L))
    if load_er is None:
        R = effective_resistance(L)
    else:
        R = np.load(f'{cache_dir}/{load_er}_sim_cosine_ER.npz')['ER']
    print('Effective Resistance matrix has nan:', _check_nan(R))

    # Candidate edges are the strict upper triangle. Gather only those entries
    # rather than materialising the full n x n product R * A.
    triu_indices = np.triu_indices(n, k=1)
    sample_weight = R[triu_indices] * A[triu_indices]
    del R
    total = sample_weight.sum()
    if not (np.isfinite(total) and total > 0):
        raise ValueError(
            'Cannot sample edges: the sum of A * R over candidate edges is '
            f'{total}; it must be positive and finite.'
        )
    prob = sample_weight / total
    del sample_weight
    os.makedirs(cache_dir, exist_ok=True)
    np.savez(f'{cache_dir}/{load_er}_prob.npz', prob=prob)

    # `idx` indexes the list of upper-triangle candidates, not the flattened matrix.
    idx = np.random.choice(prob.shape[0], q, replace=True, p=prob)
    row_idx = triu_indices[0][idx]
    col_idx = triu_indices[1][idx]
    # Reweight each draw by w_e / (q * p_e) as in [1]. The weights must be
    # gathered as A[row, col]; `A.flatten()[idx]` would read the wrong entries.
    final_weight = A[row_idx, col_idx] / (q * prob[idx])

    # Coalesce multi-edges (the same edge drawn several times) by summing their
    # weights. CSR conversion sums duplicates, and sort_indices guarantees
    # (row, col) ordering of the output.
    adj = sp.coo_matrix((final_weight, (row_idx, col_idx)), shape=(n, n)).tocsr()
    adj.sum_duplicates()
    adj.sort_indices()
    adj = adj.tocoo()
    return adj.data, adj.row, adj.col


def sparsify_fully_connnected_graph(A: np.ndarray, q:int, er=None) -> dgl.DGLGraph:
    """
    Sparsify a dense adjacency matrix and return the result as a DGL graph.

    Wrapper around `_sparsify_fully_connnected_graph` that packages the
    sampled, coalesced edges into a `dgl.DGLGraph`.

    Args:
    A: np.ndarray, (n, n) adjacency matrix of the fully connected graph.
    q: int, the number of edge samples to draw.
    er: forwarded as `load_er` (optional cache key for a precomputed
        effective-resistance matrix).

    Returns:
    dgl.DGLGraph with A.shape[0] nodes and one edge per coalesced sampled edge.
    The float32 edge weights of shape (num_edges, 1) are in `g.edata['weight']`.
    """
    print('Adj matrix has nan:', _check_nan(A))
    weight, row, col = _sparsify_fully_connnected_graph(A, q, er)
    # NOTE: row_sorted/col_sorted assume the helper returns edges ordered by
    # (row, col). Keep the two in sync.
    g = dgl.graph((row, col), num_nodes=A.shape[0], row_sorted=True, col_sorted=True)
    g.edata['weight'] = torch.tensor(weight, dtype=torch.float32).reshape(-1, 1)
    return g
