import itertools

import networkx as nx
import numpy as np
import scipy
import scipy.sparse
import scipy.sparse.csgraph
import scipy.sparse.linalg
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score


############################################################################################################################

def matrix_M(T,lamb,xi,c,phi,eta):
    
    ''' Function that computes the block tri-diagonal matrix M_T for a node covariance lambda and an edge weight xi.
    For temporal nodes both the node covariance and the edge weight are set equal to eta.
    
    Use : M_T = matrix_M(T,lamb,xi,c,phi,eta)
    
    Output : M_T (array of size 3Tx3T) : matrix M_T 
           
    Input  : T (scalar) : number of time frames (T = 1 gives a single 3x3 diagonal block)
           : lamb (scalar) : value of lambda
           : xi (scalar) : weight of spatial edges 
           : c (scalar) : expected average degree
           : phi (scalar) : second moment of the vector theta
           : eta (scalar) : class label persistence
   
    '''
    
    spatial = c*phi*lamb*xi # entry associated with the spatial edges
    temporal = eta**2 # entry associated with the temporal edges

    Mplus = np.array(([0,0,0],[0,0,0],[0,spatial,temporal])) # block (t, t+1)
    Mminus = np.array(([temporal,spatial,0],[0,0,0],[0,0,0])) # block (t+1, t)
    Mdiag = np.array(([0,0,0],[temporal,spatial,temporal],[0,0,0])) # block (t, t)

    M = np.zeros((3*T,3*T))

    # every time frame owns one diagonal block
    for t in range(T):
        M[3*t:3*(t+1),3*t:3*(t+1)] = Mdiag

    # consecutive time frames are coupled by the off-diagonal blocks; the
    # loop is empty for T = 1, so no out-of-range slice is ever written
    for t in range(T-1):
        M[3*t:3*(t+1),3*(t+1):3*(t+2)] = Mplus
        M[3*(t+1):3*(t+2),3*t:3*(t+1)] = Mminus

    return M

############################################################################################################################

def find_transition(T,eta):
    
    ''' This function computes the value of alpha_c(T,eta) by bisection on the interval [0,1]
    Use: alpha = find_transition(T,eta)
    
    Output : alpha (scalar) : last mid point evaluated by the bisection
    Input  : T (scalar) : number of frames considered
             eta (scalar) : label persistence

    At every step the matrix M_T = matrix_M(T,alpha,alpha,1,1,eta) is built at the mid point of the
    current interval. If the largest real part among its eigenvalues exceeds one, the upper end of the
    interval is moved to the mid point, otherwise the lower end is.
            
    '''
    
    lower = 0 # left end of the search interval
    upper = 1 # right end of the search interval
    tolerance = np.finfo(float).eps # machine precision
    half_width = 1 # half size of the search interval

    while half_width > tolerance:

        alpha = (lower+upper)/2
        spectrum = np.linalg.eigvals(matrix_M(T,alpha,alpha,1,1,eta))

        if np.max(spectrum.real) > 1:
            upper = alpha
        else:
            lower = alpha

        half_width = (upper-lower)/2

    return alpha    

############################################################################################################################

def matrix_C(c_out, c,fluctuation, fraction):
    
    ''' Function that generates the symmetric affinity matrix C
    Use :
        C_matrix = matrix_C(c_out, c,fluctuation, fraction)
    Output:
        C_matrix (array of size k x k) : affinity matrix C
    Input:
        c_out (scalar) : average value of the off diagonal terms
        c (scalar) : average degree of the desired network
        fluctuation (scalar) : the off diagonal terms are drawn as |N(c_out, c_out*fluctuation)| and then symmetrized
        fraction  (array-like of size k) : vector pi containing the fraction of nodes in each class
        
    The diagonal is chosen so that C @ fraction = c in every row. Nothing prevents a diagonal term from
    being negative if the off diagonal terms drawn are too large with respect to c.
    '''
    
    fraction = np.asarray(fraction, dtype = float) # lists and tuples are accepted as well as arrays
    n_clusters = len(fraction) # number of clusters

    C_matrix = np.abs(np.random.normal(c_out, c_out*fluctuation, (n_clusters,n_clusters))) # generate the off diagonal terms
    C_matrix = (C_matrix + C_matrix.T)/2 # symmetrize the matrix

    # contribution of the off diagonal terms to each row of C @ fraction
    off_diagonal = C_matrix.copy()
    np.fill_diagonal(off_diagonal, 0.)

    # imposing C @ fraction = c on every row fixes the diagonal terms
    np.fill_diagonal(C_matrix, (c - off_diagonal@fraction)/fraction)

    return C_matrix  


############################################################################################################################


def label_function(eta, T, n_clusters, n, fraction):
    
    ''' Function that generates the sequence of label vectors
    Use : Label = label_function(eta, T, n_clusters, n, fraction)
    
    Output : Label (integer array of size T x n) : Label[t] is the label vector at time t
    
    Input  : eta (scalar) : class label persistence
           : T (scalar) : number of time frames
           : n_clusters (scalar) : number of communities
           : n (scalar) : number of nodes 
           : fraction (array-like of size k) : fraction[i] is the fraction of nodes with label i;
                                               it must sum to one (np.random.choice requires it)

    At time zero the nodes are assigned to the classes in contiguous blocks. At each following time every
    node keeps its label with probability eta, otherwise a new label is drawn according to fraction.
           
    '''
    
    fraction = np.asarray(fraction, dtype = float)

    # block boundaries of the initial assignment. The last boundary is set to n
    # explicitly: rounding errors in the cumulative sum (e.g. ten times 0.1) would
    # otherwise truncate it to n-1 and leave the last node with label 0
    edges = np.concatenate(([0], (n*np.cumsum(fraction[:n_clusters])).astype(int)))
    edges[-1] = n

    label = np.zeros(n, dtype = int)
    for i in range(n_clusters):
        label[edges[i]:edges[i+1]] = i

    # allocated as integers from the start, so that the output type does not depend on T
    Label = np.zeros((T,n), dtype = int)
    Label[0] = label
    
    for t in range(1,T):

        select = np.random.binomial(1,1-eta,n) # select which nodes will have their label reassigned
        n_selected = int(np.sum(select))
        if n_selected > 0:
            # one vectorized draw for all the selected nodes
            label[select == 1] = np.random.choice(n_clusters, size = n_selected, p = fraction)
        Label[t] = label # store the labels
        
    return Label

############################################################################################################################

def adjacency_matrix_DCSBM(C_matrix,c, label, theta):
    
    ''' Function that generates the adjacency matrix A with n nodes and k communities
    Use:
        A = adjacency_matrix_DCSBM(C_matrix,c, label, theta)
        
    Output:
        A (sparse csr matrix of size n x n) : symmetric adjacency matrix with entries in {0,1} and zero diagonal
    
    Input:
        C_matrix (array of size k x k) : affinity matrix of the network C
        c (scalar) : average connectivity of the network
        label (integer array of size n) : vector containing the label of each node, with values in 0,...,k-1
        theta  (array of size n) : vector with the intrinsic probability connection of each node; theta/n is
                                   used as a probability vector, hence theta must sum to n
    
    The number of communities is read from C_matrix and not from the labels that are present: a class may be
    empty in a given time frame, and counting the distinct labels would then leave out the classes with the
    largest indices.
    '''

    C_matrix = np.asarray(C_matrix)
    label = np.asarray(label)
    theta = np.asarray(theta, dtype = float)

    n = len(theta)
    k = len(C_matrix) # number of communities
    c_v = C_matrix[label].T # (k x n) matrix where we store the value of the affinity wrt a given label for each node
    first = np.random.choice(n,int(n*c),p = theta/n) # we choose the nodes that should get connected wp = theta_i/n: the number of times the node appears equals to the number of connection it will have

    fs = list()
    ss = list()

    for i in range(k): 
        first_selected = first[label[first] == i] # among the nodes of first, select those with label i
        if len(first_selected) == 0:
            continue # no node of class i was drawn: nothing to connect
        v = theta*c_v[i]
        second_selected = np.random.choice(n,len(first_selected), p = v/np.sum(v)) # choose the nodes to connect to the first_selected
        fs.append(first_selected)
        ss.append(second_selected)

    if fs:
        edge_list = np.column_stack((np.concatenate(fs),np.concatenate(ss))) # create the edge list from the connection defined earlier
        edge_list = np.unique(edge_list, axis = 0) # remove edges appearing more then once
        edge_list = edge_list[edge_list[:,0] > edge_list[:,1]] # keep only the edges such that A_{ij} = 1 and i > j
    else:
        edge_list = np.empty((0,2), dtype = int)

    # the strictly lower triangular part is assembled directly in sparse form and then
    # symmetrized: nodes without any edge simply correspond to empty rows and columns
    rows, cols = edge_list[:,0], edge_list[:,1]
    lower = scipy.sparse.csr_matrix((np.ones(len(rows), dtype = int), (rows, cols)), shape = (n,n))
    A = (lower + lower.T).tocsr()

    return A

############################################################################################################################


def adjacency_matrix_series_DDCSBM(T, C, c, eta, theta, fraction):
    
    ''' Function that generates a sequence of T adjacency matrices together with the corresponding labels
    
    Use : AT, Label = adjacency_matrix_series_DDCSBM(T, C, c, eta, theta,fraction)
    
    Output : AT (list of length T) : AT[t] is the adjacency matrix returned by adjacency_matrix_DCSBM for the labels Label[t]
             Label (set of arrays) : Label[t] is the label vector of G_t, as returned by label_function
             
    Input  : T (scalar) : number of time frames
           : C (array) : class affinity matrix; its length gives the number of clusters
           : c (scalar) : expected average degree
           : eta (scalar) : label persistence
           : theta (array) : its length gives the number of nodes and the same object is passed to every time frame
           : fraction (array of size k) : fraction[i] is the fraction of nodes with label i
    '''
    
    n_nodes = len(theta) # size of the network
    k = len(C) # number of clusters

    # the labels of all the time frames are generated first, then one graph per time frame
    Label = label_function(eta, T, k, n_nodes, fraction)
    AT = [adjacency_matrix_DCSBM(C, c, Label[t], theta) for t in range(T)]
        
    return AT, Label


############################################################################################################################


def B_spec(AT,n,xi,h, n_eigs):
    
    ''' Function that computes eigenvalues of the (4nT x 4nT) matrix B built from the spatial and temporal adjacency matrices
    Use: v = B_spec(AT,n,xi,h, n_eigs)
    
    Output: v (complex array) : eigenvalues of B
    Input :  AT (set of sparse matrices) : AT[t] is the adjacency matrix of G_t
          :  n (scalar) : size of the graph
          :  xi (scalar) : the weights among nodes at the same time are equal to xi
          :  h (scalar) : the weights among nodes at different times are equal to h
          :  n_eigs ('all' or scalar) : if 'all' all the eigenvalues of B will be computed using its dense representation; otherwise a number equal to n_eigs of eigenvalues will be computed using the sparse representation

    NOTE : in the sparse case the eigenvalues are requested with which = 'LM', i.e. the ones of largest magnitude.
    '''

    T = len(AT)
    IT = scipy.sparse.identity(T*n)

    # spatial part: block diagonal adjacency and degree matrices.
    # asarray(...).ravel() gives a 1-D degree vector for any sparse or dense input type
    degrees = [np.asarray(A.sum(axis = 0)).ravel() for A in AT]
    As = scipy.sparse.block_diag(AT)
    Ds = scipy.sparse.diags(np.concatenate(degrees), offsets = 0)

    # temporal part: each node is connected to its own copies at times t-1 and t+1.
    # The Kronecker product with the adjacency matrix of a chain of length T gives
    # an all zero matrix for T = 1
    chain = np.eye(T, k = 1) + np.eye(T, k = -1)
    At = scipy.sparse.kron(chain, scipy.sparse.identity(n))
    Dt = scipy.sparse.diags(np.asarray(At.sum(axis = 0)).ravel(), offsets = 0)

    B = scipy.sparse.bmat([[As*xi,-IT*xi,As*xi,None],
                          [(Ds-IT)*xi, None, Ds*xi, None],
                          [h*At, None, h*At, -h*IT],
                          [h*Dt, None, h*(Dt-IT), None]])

    if n_eigs == 'all':
        v = np.linalg.eigvals(B.toarray()) # toarray is available on every sparse type
    else:
        v, X = scipy.sparse.linalg.eigs(B, k = n_eigs, which = 'LM')

    return v

############################################################################################################################


def BH_matrix(AT,n,xi,h):
    
    ''' This function assembles the block tri-diagonal Bethe-Hessian matrix
    Use : H = BH_matrix(AT,n,xi,h)
    
    Output : H (sparse matrix of size nT x nT, csr format) : Bethe-Hessian matrix
    Input  : AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t
           : n (scalar) : size of the graph
           : xi (scalar) : parameter xi
           : h (scalar) : parameter h

    The diagonal block at time t is (xi^2 D_t - xi A_t)/(1-xi^2) plus a multiple of the identity, equal to
    1/(1-h^2) for the first and the last time frame and to (1+h^2)/(1-h^2) for all the others. Consecutive
    time frames are coupled by the blocks -h/(1-h^2) I.
           
    '''
    n_frames = len(AT) # number of time frames
    eye = scipy.sparse.diags(np.ones(n), offsets = 0) # sparse identity matrix

    interior_shift = (1+h**2)/(1-h**2) # identity coefficient for the frames with two temporal neighbours
    boundary_shift = 1/(1-h**2) # identity coefficient for the first and the last frame
    coupling = -h/(1-h**2)*eye # off diagonal block

    blocks = [[None]*n_frames for _ in range(n_frames)]
    for t, A_t in enumerate(AT):
        degree = np.array(np.sum(A_t, axis = 0))[0] # degree vector at time t
        D_t = scipy.sparse.diags(degree, offsets = 0) # degree matrix at time t
        shift = boundary_shift if t in (0, n_frames-1) else interior_shift
        blocks[t][t] = (xi**2*D_t-xi*A_t)/(1-xi**2) + shift*eye
        if t+1 < n_frames:
            blocks[t][t+1] = coupling
            blocks[t+1][t] = coupling

    return scipy.sparse.bmat(blocks, format='csr')

############################################################################################################################


def overlap(real_classes, classes):

    '''Computes the overlap in networks (with n nodes) with more than two classes, after matching the estimated labels to the true ones

    Use : 
        ov = overlap(real_classes, classes)
        
    Output : 
        ov (scalar) : overlap, equal to (accuracy - 1/k)/(1 - 1/k) where k is the number of true classes
    
    Input : 
        real_classes (integer array of size n) : vector with the true labels
        classes (integer array of size n) : vector of the estimated labels

    Both label vectors must take values in 0,...,m-1, where m is the larger of the two numbers of distinct
    labels. Each estimated label is mapped to the true label it co-occurs with most often; two estimated
    labels may be mapped to the same true label. The input vectors are left untouched.
    
    '''
    real = np.asarray(real_classes, dtype = int)
    estimated = np.asarray(classes, dtype = int)

    n_classes = len(np.unique(real)) # number of true classes
    values = max(n_classes, len(np.unique(estimated))) # size of the confusion matrix
    n = len(estimated) # size of the network

    # confusion matrix: entry (i,j) counts the nodes with estimated label i and true label j.
    # add.at accumulates correctly over repeated index pairs
    matrix = np.zeros((values,values))
    np.add.at(matrix, (estimated, real), 1)

    positions = np.argmax(matrix, axis = 1) # true label most frequently associated to each estimated label

    # the relabeling is applied to a new array, so that the vector of the caller is not modified
    matched = positions[estimated]

    ov = (np.sum(matched == real)/n - 1/n_classes)/(1-1/n_classes) # compute the overlap

    return ov
    
############################################################################################################################


def compute_modularity(A, estimated_labels):
    
    '''Function to compute the modularity of a given partition on a network with n nodes
    Use: 
        mod =  compute_modularity(A, estimated_labels)
        
    Output:
        mod (scalar) : modularity of the assignment
    
    Input:
        A (sparse matrix of size n x n) : adjacency matrix of the network
        estimated_labels (array of size n) : vector containing the assignment of the labels; the labels can be
                                             any set of values, not necessarily 0,...,k-1
    '''
    
    estimated_labels = np.asarray(estimated_labels)
    d = np.asarray(A.sum(axis = 0)).ravel() # degree vector
    m = d.sum() # 2|E|
    mod = 0

    # the loop runs over the labels that are actually present: with range(k) a partition
    # labelled e.g. {0,2} would silently lose the contribution of class 2
    for community in np.unique(estimated_labels):
        I_c = (estimated_labels == community)*1 # indicator vector : the entry j equal to 1 if node j belongs to the class
        mod += I_c@(A@I_c) - (d@I_c)**2/m 
        
    return mod/m

############################################################################################################################

def informative_eigs(H):
    
    ''' Function that computes the informative eigenvectors of H
    
    Use : info, X = informative_eigs(H)
    
    Output : info (scalar) : number of non-positive eigenvalues of H
           : X (array) : in the columns of X the eigenvectors of H with non-positive eigenvalues are stored,
                         in increasing order of eigenvalue, except the first one
           
    Input  : H (sparse symmetric matrix) : matrix H

    The smallest eigenvalues are requested in growing number (starting from two) until a strictly positive
    one is found. The search fails inside the eigensolver if the number of requested eigenvalues reaches the
    size of H.
    
    '''
    
    n_requested = 2
    while True:
        v, X = scipy.sparse.linalg.eigsh(H, k = n_requested, which = 'SA') # smallest eigenvalues and their eigenvectors
        if v.max() > 0:
            break
        n_requested += 1

    idx = np.argsort(v)
    v = v[idx]
    X = X[:,idx]

    # the count is taken on the eigenvalues themselves: when the search stops at the first
    # step the number of requested eigenvalues says nothing on how many are non-positive
    info = int(np.sum(v <= 0))

    return info, X[:,1:info]

############################################################################################################################



def estimate_k(X, info, n,A):
    
    ''' Function that estimates the number of communities using the silhouettes method
    
    Use : n_clusters = estimate_k(X, info, n, A)
    
    Output : n_clusters (scalar) : number of clusters detected
    
    Input : X (array) : array whose rows are the embedding of the nodes of A
          : info (scalar) : maximal number of communities; it must be at least 2
          : n (scalar) : number of nodes (not used, kept for compatibility)
          : A (sparse matrix) : adjacency matrix referring to the embedding X

    Only the nodes of the largest connected component of A are clustered. For every k between 2 and info
    k-means is run and the average silhouette is computed; the k with the largest score is returned. A run
    in which k-means produces less than k distinct clusters gets a score equal to zero.
          
    '''

    # largest connected component, computed directly on the sparse matrix. The node
    # indices are in increasing order and ties are resolved in favour of the component
    # that contains the node with the smallest index
    _, component = scipy.sparse.csgraph.connected_components(A, directed = False)
    giant = np.argmax(np.bincount(component))
    nodes = np.flatnonzero(component == giant)

    range_n_clusters = np.arange(2,info+1)
    silhouette = np.zeros(len(range_n_clusters))
    Y = X[nodes]

    for i, k in enumerate(range_n_clusters):
        cluster_labels = KMeans(n_clusters=k).fit_predict(Y)
        if len(np.unique(cluster_labels)) == k:
            silhouette[i] = silhouette_score(Y, cluster_labels)
        
    return range_n_clusters[np.argmax(silhouette)]


##################################################################################################################


def JL_poly_approx_eigs(H):
    
    ''' 
    Function to compute Y, the polynomial approximation of X@X.T@R where X concatenate all the eigenvectors of H associated to negative eigenvalues and R is a random projection matrix with O(log(nT)) columns
    
    Use : Y = JL_poly_approx_eigs(H)
    
    Output : Y (array of size nT x r) : matrix Y, with r = ceil(10 log(nT))
    
    Input : H (sparse symmetric array of size nT x nT) : matrix H

    Raises : ValueError if the smallest eigenvalue of H is strictly positive
    
    ''' 
    p = 50 # order of the polynomial approximation
    nT = np.shape(H)[0] # extract nT
    n_RF = int(np.ceil(10 * np.log(nT))) # dimension of random projection

    mu_min = scipy.sparse.linalg.eigsh(H, k = 1, which = 'SA', return_eigenvectors=False)[0] # smallest eigenvalue
    if mu_min > 0:
        raise ValueError('H does not contain any eigenvalue below 0!')

    # 'LA' returns the algebraically largest eigenvalue. 'LM' returns the one of largest
    # magnitude, which is mu_min itself whenever |mu_min| exceeds the largest eigenvalue:
    # the spectral range mu_max - mu_min would then collapse to zero
    mu_max = scipy.sparse.linalg.eigsh(H, k = 1, which = 'LA', return_eigenvectors=False)[0] # largest eigenvalue

    H_translated = H - mu_min * scipy.sparse.diags(np.ones(nT), 0) # shift the spectrum of H to [0, mu_max - mu_min]
    
    R = np.random.normal(0,1/np.sqrt(n_RF),(nT,n_RF)) # random matrix such that E[R @ R.T] = I
    jch = compute_jackson_cheby_coeff([0, np.abs(mu_min)], [0, mu_max-mu_min], p) # compute coefficients of Jackson-Chebychev polynomial approximation of the step function (=1 between 0 and abs(mu_min) and =0 afterwards)
    Y = cheby_op(H_translated, mu_max-mu_min, jch, R) # Compute the (polynomial) approximation of X @ X.T @ R where X concatenates all the eigenvectors of H_translated that are between 0 and abs(mu_min)
    
    return Y

##################################################################################################################

def compute_jackson_cheby_coeff(filter_bounds, delta_lambda, m):
    '''
    Function that computes the m+1 coefficients of the Jackson-Chebyshev polynomial approximation of an ideal band-pass between filter_bounds[0] and filter_bounds[1], for a range of values between delta_lambda[0] and delta_lambda[1]
    
    Use : jch = compute_jackson_cheby_coeff(filter_bounds, delta_lambda, m)

    Output : jch (array of size m+1) : vector containing the JC coefficients
    
    Input : filter_bounds (sequence of size 2) : filter_bounds[0] and filter_bounds[1] are the edges of the filter; the sequence is not modified
          : delta_lambda (sequence of size 2) : the polynomial approximation is valid between delta_lambda[0] and delta_lambda[1]
          : m (scalar) : m+1 is the number of computed coefficients

    Raises : ValueError if the filter is not contained in the range delta_lambda or if the range is reversed
    '''
   
    # Parameters check
    if delta_lambda[0] > filter_bounds[0] or delta_lambda[1] < filter_bounds[1]:
        raise ValueError("Bounds of the filter are out of the lambda values")
    elif delta_lambda[0] > delta_lambda[1]:
        raise ValueError("lambda_min is greater than lambda_max")

    # Scaling and translating to standard cheby interval
    a1 = (delta_lambda[1]-delta_lambda[0])/2
    a2 = (delta_lambda[1]+delta_lambda[0])/2

    # Bounds of the band pass mapped to [-1, 1]. They are stored in local variables,
    # so that the sequence of the caller is left untouched; the clipping only removes
    # rounding errors, since the check above guarantees that the bounds are in range
    lower = np.clip((filter_bounds[0]-a2)/a1, -1., 1.)
    upper = np.clip((filter_bounds[1]-a2)/a1, -1., 1.)
    angle_lower = np.arccos(lower)
    angle_upper = np.arccos(upper)

    # First compute cheby coeffs
    orders = np.arange(1, m+1)
    ch = np.empty(m+1)
    ch[0] = (2/np.pi)*(angle_lower-angle_upper)
    ch[1:] = (2/(np.pi*orders))*(np.sin(orders*angle_lower)-np.sin(orders*angle_upper))

    # Then compute jackson coeffs
    alpha = np.pi/(m+2)
    i = np.arange(m+1)
    jackson = ((1 - i/(m+2))*np.sin(alpha)*np.cos(i*alpha) +
               (1/(m+2))*np.cos(alpha)*np.sin(i*alpha))/np.sin(alpha)

    # Combine jackson and cheby coeffs
    return ch*jackson

##################################################################################################################

def cheby_op(H_translated, lmax, jch, signal):
    '''
    Function that applies to the columns of 'signal' the Chebyshev expansion of H_translated defined by the coefficients jch, with the spectrum assumed to lie in [0, lmax].
    
    Use : r = cheby_op(H_translated, lmax, jch, signal)

    Output : r (array) : the result of the filtering. If jch contains several rows of coefficients, the results are stacked vertically, one block of nT rows per set of coefficients

    Input : H_translated (sparse matrix of size nT x nT) : matrix to which the polynomial is applied
          : lmax (scalar) : upper end of the interval [0, lmax] that is mapped to [-1, 1]
          : jch (array) : coefficients of the expansion; one row per filter, at least two coefficients per row
          : signal (array with nT rows) : vectors to be filtered, stacked in the columns

    Raises : TypeError if less than two coefficients per filter are provided
     
    '''
    size = np.shape(H_translated)[0]
    coeffs = np.atleast_2d(jch)
    n_filters, n_coeffs = coeffs.shape

    n_signals = np.shape(signal)[1]
    out = np.zeros((size * n_filters, n_signals))

    if n_coeffs < 2:
        raise TypeError("The coefficients have an invalid shape")

    # affine map sending [0, lmax] to [-1, 1]
    half_width = float(lmax - 0) / 2.
    centre = float(lmax + 0) / 2.

    # first two terms of the recurrence
    previous = signal
    current = (H_translated.dot(signal) - centre * signal) / half_width

    for s in range(n_filters):
        out[s*size:(s+1)*size] = 0.5 * coeffs[s, 0] * previous + coeffs[s, 1] * current

    # three terms recurrence T_k = 2 x T_{k-1} - T_{k-2} for the remaining terms
    step = 2/half_width * (H_translated - centre * scipy.sparse.diags(np.ones(size), 0))
    for order in range(2, n_coeffs):
        following = step.dot(current) - previous
        for s in range(n_filters):
            out[s*size:(s+1)*size] += coeffs[s, order] * following

        previous, current = current, following

    return out


