"""Utilities for effective resistance, resistance embeddings, etc."""
import math
from typing import Any, Optional

import numpy as np
from numpy import linalg
import scipy as sp
import scipy.sparse
import scipy.sparse.linalg

import bourgain_embedding


def interleave_columns_of_matrices(first: np.ndarray,
                                   second: np.ndarray) -> np.ndarray:
  """Builds a matrix whose columns alternate between the two inputs.

  Column i of `first` lands in output column 2 * i and column j of `second`
  lands in output column 2 * j + 1. When the inputs have a different number of
  columns, the output is sized for the wider one and the unmatched slots of the
  narrower one stay zero.

  Args:
    first: k-by-m matrix.
    second: k-by-n matrix.

  Returns:
    A k-by-(2 * max(m, n)) matrix of floats with the interleaved columns.
  """
  num_rows, first_cols = first.shape[0], first.shape[1]
  second_cols = second.shape[1]
  assert num_rows == second.shape[0], 'Matrices should have the same # of rows.'

  interleaved = np.zeros((num_rows, 2 * max(first_cols, second_cols)))
  # Even output columns come from `first`, odd ones from `second`.
  interleaved[:, 0:2 * first_cols:2] = first
  interleaved[:, 1:2 * second_cols:2] = second
  return interleaved


def count_edges(adj_mat: np.ndarray) -> int:
  """Counts the strictly positive entries of a 2-D adjacency matrix.

  Every (row, column) position is inspected independently, so an entry and its
  transposed counterpart are counted separately.

  Args:
    adj_mat: Matrix indexed as adj_mat[i, j].

  Returns:
    The number of entries that are greater than 0.0.
  """
  return sum(
      1
      for row in range(adj_mat.shape[0])
      for col in range(adj_mat.shape[1])
      if adj_mat[row, col] > 0.0)


def incidence_matrix(n: int, senders: np.ndarray, receivers: np.ndarray) -> Any:
  """Builds the sparse edge-node incidence matrix of an edge list.

  Row e describes edge e: it holds -1.0 in the column of the edge's sender and
  +1.0 in the column of its receiver. Entries that fall on the same position
  (e.g. a self-loop) are summed by the sparse constructor.

  The edge list is expected to be symmetric and free of isolated nodes.

  Args:
    n: The number of nodes in the graph.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.

  Returns:
    A sparse (CSC) matrix of shape (number of edges, n).
  """
  num_edges = senders.shape[0]
  edge_ids = list(range(num_edges))

  tails = senders.tolist()
  heads = receivers.tolist()
  signs = [-1.0] * num_edges + [+1.0] * num_edges

  # Each edge id appears twice: once for its tail entry, once for its head.
  return sp.sparse.csc_matrix((signs, (edge_ids + edge_ids, tails + heads)),
                              shape=(num_edges, n))


def incidence_matrix_rowcol(senders: np.ndarray, receivers: np.ndarray) -> Any:
  """Returns the row and column index lists of the incidence matrix entries.

  The indices follow the same layout used by incidence_matrix: the first half
  addresses the sender entry of every edge, the second half the receiver entry.

  Args:
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.

  Returns:
    A (rows, cols) pair of Python lists, each of length 2 * number of edges.
    `rows` holds the edge ids repeated twice; `cols` holds all senders followed
    by all receivers.
  """
  edge_ids = list(range(senders.shape[0]))
  row_indices = [*edge_ids, *edge_ids]
  col_indices = [*senders.tolist(), *receivers.tolist()]
  return row_indices, col_indices


def sqrt_conductance_matrix(senders: np.ndarray, weights: np.ndarray) -> Any:
  """Builds the diagonal matrix holding sqrt(weight / 2) for every edge.

  Args:
    senders: Per-edge array; only its length (the number of edges) is used.
    weights: The weight of each edge.

  Returns:
    A sparse (CSC) diagonal matrix of shape (number of edges, number of edges).
  """
  num_edges = senders.shape[0]
  diagonal = np.arange(num_edges)
  # NOTE(review): the halving presumably compensates for a symmetric edge list
  # listing every undirected edge twice -- confirm against callers.
  half_weight_roots = np.sqrt(weights / 2.0)
  return sp.sparse.csc_matrix((half_weight_roots, (diagonal, diagonal)),
                              shape=(num_edges, num_edges))


def laplacian_matrix(n: int,
                     senders: np.ndarray,
                     receivers: np.ndarray,
                     weights: np.ndarray,
                     normalized: bool = False) -> Any:
  """Builds the sparse graph Laplacian of a weighted edge list.

  The edge list is expected to be symmetric and free of isolated nodes. Degrees
  are taken as column sums of the adjacency matrix, and any degree below 1e-7
  is replaced by 1.0 before it is used.

  Args:
    n: The number of nodes in the graph.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.
    weights: The weight of each edge. If None, every edge gets weight 1.
    normalized: If true, returns I - D^{-1/2} A D^{-1/2}; otherwise D - A.

  Returns:
    A sparse n-by-n Laplacian matrix.
  """
  if weights is None:
    weights = 0 * senders + 1

  diagonal = range(n)

  # Explicit zero-weight diagonal entries are appended to the edge entries.
  src = senders.tolist() + list(diagonal)
  dst = receivers.tolist() + list(diagonal)
  vals = weights.tolist() + [0.0] * n
  adjacency = sp.sparse.csc_matrix((vals, (src, dst)), shape=(n, n))

  node_degrees = np.ravel(adjacency.sum(axis=0))
  # Clamp (near-)zero degrees so D^{-1/2} below never divides by zero.
  node_degrees[node_degrees < 1e-7] = 1.0

  if not normalized:
    degree_diag = sp.sparse.csc_matrix((node_degrees, (diagonal, diagonal)),
                                       shape=(n, n))
    return degree_diag - adjacency

  scaling = sp.sparse.csc_matrix((node_degrees**-0.5, (diagonal, diagonal)),
                                 shape=(n, n))
  return sp.sparse.eye(n) - scaling * adjacency * scaling


def effective_resistance_embedding(n: int,
                                   senders: np.ndarray,
                                   receivers: np.ndarray,
                                   weights: Optional[np.ndarray] = None,
                                   accuracy: np.double = 0.1,
                                   which_method: int = 0) -> Any:
  """Computes a vector-valued effective-resistance embedding of the nodes.

  This is the per-node embedding, as opposed to the scalar per-edge values
  produced by effective_resistances_from_embedding, which consumes the result
  of this function.

  Two methods are available: an exact one based on the dense pseudo-inverse of
  the Laplacian, and an approximate one that projects onto k random directions
  and solves one linear system per direction with conjugate gradients.

  Args:
    n: The number of nodes in the graph.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.
    weights: The weight of each edge. Defaults to all ones.
    accuracy: Target accuracy; drives the number of random projections k.
    which_method: 0 => choose the most suitable; +1 => use random projection
      (approximates effective resistances); -1 => use pseudo-inverse.

  Returns:
    The embedding, one row per node. It has one column per edge when the
    pseudo-inverse is used and k columns when random projection is used.
  """
  m = senders.shape[0]
  if weights is None:
    weights = np.ones(m)

  lap_mat = laplacian_matrix(n, senders, receivers, weights)

  # Random projection needs 8 * ln(m) / accuracy^2 dimensions. Clamp to at
  # least one: ln(1) == 0 would otherwise request an empty embedding.
  if m == 0:
    k = 1
  else:
    k = max(1, math.ceil(8 * math.log(m) / (accuracy**2)))

  b_mat = incidence_matrix(n, senders, receivers)
  c_sqrt_mat = sqrt_conductance_matrix(receivers, weights)
  # C^{1/2} B is shared by both methods, so build it once.
  scaled_incidence = c_sqrt_mat @ b_mat

  # In case of random projection we need to solve for k vectors. If
  # k = Omega(n), it is simply better to (pseudo-)invert the whole Laplacian.
  if which_method == -1 or (k >= n / 2 and which_method != +1):
    inv_lap_mat = np.linalg.pinv(lap_mat.toarray(), hermitian=True)
    embedding = (scaled_incidence @ inv_lap_mat).transpose()
  else:
    print('  doing JL.')
    # U C^{1/2} B L^{-1} same as U' L^{-1/2}
    embedding = np.zeros((n, k))
    # Row i of U C^{1/2} B, written as a column: (C^{1/2} B)^T y.
    projection = scaled_incidence.transpose()
    for i in range(k):
      y = np.random.normal(0.0, 1.0 / math.sqrt(k), m)
      # The solver's convergence flag is ignored, as before.
      embedding[:, i], _ = sp.sparse.linalg.cg(lap_mat, projection @ y)
      if i % 100 == 0 or i + 1 == k:
        print('{0}/{1} done.'.format(i, k))

  return embedding


def weird_embedding(n: int,
                    senders: np.ndarray,
                    receivers: np.ndarray,
                    weights: Optional[np.ndarray] = None) -> Any:
  """Returns the dense pseudo-inverse of the graph Laplacian as an embedding.

  Args:
    n: The number of nodes in the graph.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.
    weights: The weight of each edge. Defaults to all ones.

  Returns:
    The n-by-n pseudo-inverse of the (unnormalized) Laplacian as an ndarray.

  Raises:
    NameError: If the top-left entry of the pseudo-inverse is at least 1e10,
      which is treated as a failed inversion.
  """
  edge_weights = np.ones(senders.shape[0]) if weights is None else weights

  laplacian = laplacian_matrix(n, senders, receivers, edge_weights)
  pseudo_inverse = np.asarray(
      np.linalg.pinv(laplacian.todense(), hermitian=True))

  if pseudo_inverse[0, 0] >= 1e10:
    raise NameError('pinv failed.')
  return pseudo_inverse


def effective_resistances_from_embedding(
    n: int,
    embedding: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    weights: Optional[np.ndarray] = None,
    normalize_per_node: bool = False) -> Any:
  """Computes the effective resistance of each edge from a node embedding.

  The effective resistance of edge (u, v) is the squared Euclidean distance
  between rows u and v of the embedding. The input edges do not have to be
  symmetric.

  Args:
    n: The number of nodes in the graph.
    embedding: The effective resistance embedding, one row per node.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.
    weights: Edge weights. NOT USED for the computation; accepted only for
      consistency with the other functions.
    normalize_per_node: If true, each value is divided by the sum of the values
      of all edges sharing the same receiver, so that the values of the edges
      entering each node sum to 1.

  Returns:
    A 1-D array with one effective resistance per edge.
  """
  del weights  # Unused, kept for a uniform signature.

  m = senders.shape[0]
  ers = np.zeros(m)
  for i, (u, v) in enumerate(zip(senders, receivers)):
    diff = embedding[u, :] - embedding[v, :]
    ers[i] = (diff**2).sum()
    if (i % 1000 == 0 or i + 1 == m) and m >= 10000:
      print('{0}/{1} done.'.format(i, m))

  # Skip the empty case so that empty index arrays of any dtype are accepted.
  if normalize_per_node and m > 0:
    sums = np.zeros(n)
    # Unbuffered accumulation: receivers that repeat are all added up.
    np.add.at(sums, receivers, ers)
    ers = ers / sums[receivers]

  return ers


def laplacian_eigenv(n: int,
                     senders: np.ndarray,
                     receivers: np.ndarray,
                     weights: Optional[np.ndarray] = None) -> np.ndarray:
  """Computes Laplacian eigenvector differences between senders and receivers.

  Eigenvectors whose eigenvalue is numerically zero are discarded; the rest are
  ordered by ascending eigenvalue, and only their real part is used.

  Args:
    n: The number of nodes in the graph.
    senders: The sender node of each edge.
    receivers: The receiver node of each edge.
    weights: The weight of each edge. Defaults to all ones.

  Returns:
    Eigenvector differences as edge features: one row per edge and one column
    per retained eigenvector, with singleton dimensions squeezed away.
  """
  m = senders.shape[0]
  if weights is None:
    weights = np.ones(m)

  lap_mat = laplacian_matrix(n, senders, receivers, weights)
  # rows of eigenv correspond to graph nodes, cols correspond to eigenvalues
  eigenval, eigenv = linalg.eig(lap_mat.toarray())

  # Remove eigenvectors that correspond to zero eigenvalues. Computed
  # eigenvalues are rarely exactly 0.0, so compare against a tolerance relative
  # to the largest magnitude instead of testing for exact zero.
  magnitudes = np.abs(eigenval)
  tolerance = 1e-10 * max(1.0, magnitudes.max()) if magnitudes.size else 0.0
  nonzero_idx = np.nonzero(magnitudes > tolerance)[0]
  eigenval = eigenval[nonzero_idx]
  eigenv = eigenv[:, nonzero_idx]

  # sort eigenvectors in ascending order of eigenvalues
  sorted_idx = np.argsort(eigenval)
  real_eigenv = np.real(eigenv[:, sorted_idx])

  features = np.array([
      real_eigenv[s, :] - real_eigenv[r, :] for s, r in zip(senders, receivers)
  ])
  return np.squeeze(features)


def get_edges_and_features(adj_mat, node_offset):
  """Extracts the edge list of a graph and computes its edge/node features.

  Every strictly positive entry adj_mat[i, j] becomes a directed edge i -> j
  with that entry as weight. The graph must be symmetric: the reverse entry of
  every edge has to be positive as well.

  Args:
    adj_mat: 2-D adjacency matrix.
    node_offset: Value added to every node index in the returned `senders` and
      `receivers` lists. All features are computed on the un-offset indices.

  Returns:
    A tuple (senders, receivers, eff_res_embedding, eff_res, hitting_times,
    eigv_diffs, incidence_rows, incidence_cols, in_degs, out_degs,
    distance_embedding). `hitting_times` has two columns per edge: the
    normalized hitting time and the effective resistance minus that value.
  """
  assert len(adj_mat.shape) == 2
  senders, receivers, weights = [], [], []
  local_senders, local_receivers = [], []

  # Per-node counts of incoming and outgoing edges.
  in_degs = [0] * adj_mat.shape[0]
  out_degs = [0] * adj_mat.shape[1]

  for i in range(adj_mat.shape[0]):
    for j in range(adj_mat.shape[1]):
      if not (adj_mat[i, j] > 0.0):
        continue
      assert adj_mat[j, i] > 0.0  # make sure we have symmetric graphs for now
      out_degs[i] += 1
      in_degs[j] += 1
      senders.append(i + node_offset)
      receivers.append(j + node_offset)
      weights.append(adj_mat[i, j])
      local_senders.append(i)
      local_receivers.append(j)

  num_nodes = adj_mat.shape[0]

  # The external helper gets its own freshly built arrays.
  distance_embedding = bourgain_embedding.compute_bourgain_embedding(
      num_nodes, np.array(local_senders), np.array(local_receivers))

  # The helpers below only read their inputs, so they can share these arrays.
  sender_arr = np.array(local_senders)
  receiver_arr = np.array(local_receivers)
  weight_arr = np.array(weights)

  incidence_rows, incidence_cols = incidence_matrix_rowcol(
      sender_arr, receiver_arr)
  eff_res_embedding = effective_resistance_embedding(
      num_nodes, sender_arr, receiver_arr, weights=weight_arr)
  eff_res = effective_resistances_from_embedding(
      num_nodes, eff_res_embedding, sender_arr, receiver_arr,
      weights=weight_arr)
  hitting_time = hitting_times_from_embedding(
      num_nodes,
      eff_res_embedding,
      sender_arr,
      receiver_arr,
      weights=weight_arr,
      normalize=True)

  hitting_time_rev = eff_res - hitting_time
  hitting_times = np.stack((hitting_time, hitting_time_rev), axis=1)

  # Eigenvector features are computed with unit weights.
  eigv_diffs = laplacian_eigenv(num_nodes, sender_arr, receiver_arr)

  return (senders, receivers, eff_res_embedding, eff_res, hitting_times,
          eigv_diffs, incidence_rows, incidence_cols, in_degs, out_degs,
          distance_embedding)


def hitting_times_from_embedding(n: int,
                                 embedding: np.ndarray,
                                 senders: np.ndarray,
                                 receivers: np.ndarray,
                                 weights: Optional[np.ndarray] = None,
                                 normalize: bool = True) -> Any:
  """Computes the hitting time of each edge from a resistance embedding.

  Args:
    n: The number of nodes.
    embedding: Embedding with one row per node (obtained from
      effective_resistance_embedding).
    senders: Senders (tails of each edge).
    receivers: Receivers (heads of each edge).
    weights: Optional edge weights. Defaults to all ones.
    normalize: If true, each hitting time is divided by the total edge weight.

  Returns:
    A 1-D array with one hitting time per edge.
  """
  num_edges = senders.shape[0]
  edge_weights = np.ones(num_edges) if weights is None else weights

  # Weighted degree of every node, accumulated over the edges it sends.
  weighted_degree = np.zeros(n)
  for idx, tail in enumerate(senders):
    weighted_degree[tail] += edge_weights[idx]

  total_weight = np.sum(edge_weights)
  sq_norms = np.sum(embedding**2, axis=1)
  # Degree-weighted sum of all node embeddings.
  degree_weighted_sum = np.matmul(np.transpose(embedding), weighted_degree)

  ht = np.zeros(num_edges)
  for idx, (u, v) in enumerate(zip(senders, receivers)):
    delta = embedding[v, :] - embedding[u, :]
    stretch = total_weight * np.sum(delta**2)
    norm_gap = total_weight * (sq_norms[v] - sq_norms[u])
    pull = 2 * np.sum(delta * degree_weighted_sum)
    ht[idx] = 0.5 * (stretch + norm_gap - pull)

  return ht / total_weight if normalize else ht


def hitting_times(n: int,
                  senders: np.ndarray,
                  receivers: np.ndarray,
                  weights: Optional[np.ndarray] = None,
                  accuracy: float = 0.1,
                  normalize: bool = True) -> Any:
  """Computes the hitting time of each edge.

  The effective-resistance embedding is always built with the random
  projection method and then handed to hitting_times_from_embedding.

  Args:
    n: The number of nodes.
    senders: Senders (tails of each edge).
    receivers: Receivers (heads of each edge).
    weights: Optional edge weights.
    accuracy: Target accuracy of the embedding.
    normalize: If true, each hitting time is divided by the total edge weight.

  Returns:
    The hitting time of each edge as an np.ndarray.
  """
  resistance_embedding = effective_resistance_embedding(
      n,
      senders,
      receivers,
      weights=weights,
      accuracy=accuracy,
      which_method=1)
  return hitting_times_from_embedding(
      n,
      resistance_embedding,
      senders,
      receivers,
      weights=weights,
      normalize=normalize)


def print_matrix_for_python(mat: np.ndarray):
  """Prints a matrix as a Python nested-list assignment named `mat`.

  Rows are separated by '], ' followed by a line break. A matrix without rows
  is printed as 'mat = []' so that the output is always valid Python.

  Args:
    mat: 2-D matrix indexed as mat[i, j].
  """
  row_texts = [
      '[' + ', '.join(str(mat[i, j]) for j in range(mat.shape[1])) + ']'
      for i in range(mat.shape[0])
  ]
  print('mat = [' + ', \n'.join(row_texts) + ']')
