import random
import numpy as np
import tensorflow as tf
import gudhi as gd
from sklearn.metrics import pairwise_distances
from gudhi.tensorflow import RipsLayer, LowerStarSimplexTreeLayer

# Oineus is an optional dependency: in this file it is only used by the
# Oineus-backed topological layer, so a failed import is reported but not fatal.
try:
    import sys
    from utils.pto import path_to_oineus
    sys.path.append(path_to_oineus)
    import oineus as oin
except ImportError:
    # ImportError also covers ModuleNotFoundError (missing path helper or missing Oineus).
    # A bare `except:` would additionally swallow KeyboardInterrupt / SystemExit.
    print("Oineus not found.")


def RBFKernel(X, sigma=1.0, normalized=True):
    """Return the Gaussian (RBF) kernel matrix between the rows of ``X``.

    Args:
        X: 2-D array; each row is one point. ``X.shape[1]`` is used as the
            ambient dimension ``d`` in the normalisation constant.
        sigma: bandwidth of the Gaussian.
        normalized: if true, divide the kernel by ``(2 * pi * sigma**2) ** (d / 2)``.

    Returns:
        Square array ``K`` with ``K[i, j] = exp(-dist(X[i], X[j])**2 / (2 * sigma**2))``,
        where ``dist`` is whatever ``pairwise_distances`` returns with its default
        metric, optionally divided by the normalisation constant.
    """
    d = X.shape[1]
    K = np.exp(-pairwise_distances(X) ** 2 / (2 * sigma ** 2))
    if not normalized:
        return K
    return K / (2 * np.pi * sigma ** 2) ** (d / 2)

def interp(K, A):
    """Return ``alpha`` such that ``K.dot(alpha) == A`` (i.e. ``K^{-1} A``).

    Args:
        K: square, invertible matrix.
        A: right-hand side; a vector of length ``K.shape[0]`` or a matrix with
            ``K.shape[0]`` rows.

    Returns:
        Array with the same shape as ``A``.

    Raises:
        numpy.linalg.LinAlgError: if ``K`` is singular.
    """
    # Solving the linear system avoids forming the explicit inverse, which is
    # both slower and numerically less accurate than a direct solve.
    return np.linalg.solve(K, A)

def get_warps(K):
    """Return the eigenvectors of the inverse of ``K``.

    The eigenvectors are the second output of ``np.linalg.eig`` applied to
    ``np.linalg.inv(K)``; the eigenvalues are discarded.
    """
    K_inverse = np.linalg.inv(K)
    eigenvectors = np.linalg.eig(K_inverse)[1]
    return eigenvectors

def extend(M, k=2):
    """Expand every entry of a 2-D matrix into a ``k x k`` scaled identity block.

    The result is the Kronecker product ``M (x) I_k``: block ``(i, j)`` of the
    output equals ``M[i, j] * np.eye(k)``.

    Args:
        M: 2-D array of shape ``(n, m)``.
        k: size of the identity block each entry is expanded into.

    Returns:
        Array of shape ``(n * k, m * k)``.

    Raises:
        ValueError: if ``M`` is not 2-D.
    """
    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError("extend expects a 2-D matrix, got an array with {} dimension(s).".format(M.ndim))
    # np.kron produces exactly the block layout the previous element-by-element
    # construction built, without a Python-level double loop.
    return np.kron(M, np.eye(k))


def get_deformation(scalar, gradients, idxs, kernelA, kernelB):
    """Spread the values of ``gradients`` at rows ``idxs`` to all points via kernel interpolation.

    Computes ``kernelB[:, idxs] . kernelA[idxs, idxs]^{-1} . gradients[idxs]``.

    This is mathematically identical to solving against the block-expanded
    kernels ``extend(kernelA[idxs][:, idxs], d)`` / ``extend(kernelB[:, idxs], d)``
    on the flattened values: because ``(K (x) I_d)^{-1} vec(A) = vec(K^{-1} A)``,
    each of the ``d`` coordinates can be solved against the same ``q x q``
    system, which avoids building and inverting a ``(q*d) x (q*d)`` matrix.

    Args:
        scalar: if true, the selected values are treated as one scalar per
            reference point (flattened into a single column); otherwise
            ``gradients`` must be 2-D with one row per point.
        gradients: array holding one value (or one row) per point.
        idxs: indices of the ``q`` reference points (rows of ``gradients`` to interpolate).
        kernelA: square kernel matrix, indexed by ``idxs`` on both axes.
        kernelB: kernel matrix between all points (rows) and candidate
            reference points (columns); its columns are indexed by ``idxs``.

    Returns:
        Array of shape ``(kernelB.shape[0], d)`` with ``d = gradients.shape[1]``
        in the non-scalar case, or ``(kernelB.shape[0], 1)`` in the scalar case.

    Raises:
        ValueError: if ``scalar`` is false and ``gradients`` is not 2-D.
        numpy.linalg.LinAlgError: if the restricted kernel is singular.
    """
    if not scalar and gradients.ndim != 2:
        raise ValueError("Non-scalar deformation expects 2-D gradients, got {} dimension(s).".format(gradients.ndim))

    # Values to interpolate at the reference points: shape (q, d), or (q, 1) if scalar.
    A = gradients[idxs]
    if scalar:
        A = A.reshape([-1, 1])

    # Kernel restricted to the reference points: shape (q, q).
    K = kernelA[idxs, :][:, idxs]
    # Interpolation coefficients K^{-1} A, obtained without forming the inverse.
    alpha = np.linalg.solve(K, A)

    # Kernel between every point and the reference points: shape (n, q).
    KG = kernelB[:, idxs]
    # Evaluating the interpolant everywhere gives the global movement.
    MV = KG.dot(alpha)

    return MV

# Layer whose forward pass is the identity and whose backward pass replaces the incoming gradient
# by its kernel-based deformation. Calling tape.gradient with respect to the layer's output bypasses
# the kernel and therefore corresponds to the vanilla gradient.
def _build_TopologicalFlow(kernel: np.array, filtration_type: str):
    """Build an identity op whose backward pass spreads the gradient with ``kernel``.

    Args:
        kernel: kernel matrix, used both as the restricted kernel and as the
            full kernel when calling ``get_deformation``.
        filtration_type: ``'LowerStar'`` selects the scalar mode of
            ``get_deformation``; any other value selects the vector mode.

    Returns:
        A ``tf.custom_gradient`` function that returns its input unchanged.
    """
    @tf.custom_gradient
    def TopologicalFlow_(X):
        def grad(G):
            # Densify the upstream gradient and move it to NumPy.
            dense_G = tf.convert_to_tensor(G).numpy()
            # Rows whose norm exceeds the threshold act as reference points.
            row_norms = np.linalg.norm(dense_G, axis=1)
            idxs = np.argwhere(row_norms > 1e-6).ravel()
            if len(idxs) == 0:
                # Nothing to interpolate from: pass the gradient through untouched.
                smoothed = dense_G
            else:
                is_scalar = bool(filtration_type == 'LowerStar')
                smoothed = get_deformation(is_scalar, dense_G, idxs, kernel, kernel)
            return tf.constant(smoothed, dtype=tf.float64)
        return X, grad
    return TopologicalFlow_

def _build_OineusTopologicalLayer(filtration_type: str,
                                  homology_dimension: int,
                                  max_edge_length: float,            # if filtration_type == 'Rips'
                                  simplex_tree: gd.SimplexTree,      # if filtration_type == 'LowerStar'
                                  subsample_size: int,
                                  n_preserved: int):
    """Build a custom-gradient op that uses Oineus to pick critical Rips edges and target values.

    Only ``filtration_type == 'Rips'`` is implemented. ``simplex_tree`` and
    ``subsample_size`` are accepted but not used by the current implementation.

    Args:
        filtration_type: ``'Rips'`` (implemented) or ``'LowerStar'`` (raises).
        homology_dimension: homology dimension to simplify; the Rips filtration
            is built up to ``homology_dimension + 1``.
        max_edge_length: passed to Oineus as ``max_radius``.
        simplex_tree: unused for now.
        subsample_size: unused for now.
        n_preserved: passed to ``get_nth_persistence`` to obtain the
            simplification threshold.

    Returns:
        A function mapping a point cloud ``X`` of shape ``(N, D)`` to an array of
        shape ``(n, 2 * D + 1)`` (see the comments in the body), whose backward
        pass scatters the upstream gradient back onto the rows of ``X``.
    """

    @tf.custom_gradient
    def OineusTopologicalLayer_(X):
        if filtration_type == 'LowerStar':
            raise ValueError('LowerStar with Oineus has not been implemented yet.')
        if filtration_type != 'Rips':
            raise ValueError("Unknown filtration_type: {!r}".format(filtration_type))

        # Construction with Oineus' API
        X = tf.cast(X, dtype=tf.float64)
        fil, longest_edges = oin.get_vr_filtration_and_critical_edges(np.array(X.numpy(), dtype=np.float64), max_dim=homology_dimension+1, max_radius=max_edge_length, n_threads=1)
        top_opt = oin.TopologyOptimizer(fil)
        eps = top_opt.get_nth_persistence(homology_dimension, n_preserved)
        # We can potentially use other losses from Oineus.
        # The birth-birth loss is equivalent to "ul.death_killer": all (but n_preserved-1 = 0) points with
        # coordinates (b,d) want to be matched to (b,b).
        # Below, indices: simplices we want to update, and values: values we want to assign to them.
        indices, values = top_opt.simplify(eps, oin.DenoiseStrategy.BirthBirth, homology_dimension)
        # Now, to each simplex to be updated, we assign a critical set (list of indices) to which we may want to
        # assign the same value. The following is thus a list of pairs [(values, indices) ...].
        critical_sets = top_opt.singletons(indices, values)
        # As a given index could appear twice (or more), we need to choose which value to actually assign.
        # The heuristic is to take the maximum. In the following, we eventually store the list of indices to be
        # updated, and their corresponding values.
        crit_indices, crit_values = top_opt.combine_loss(critical_sets, oin.ConflictStrategy.Max)

        # Convert filtration values of simplices into their Rips edges. Indeed, recall that indices correspond to
        # (critical) simplices, i.e. introduction of edges (creating or killing circles).
        crit_indices = np.array(crit_indices, dtype=np.int32)
        crit_edges = longest_edges[crit_indices, :]
        # Two lists of vertex indices (x, y): the endpoints of each critical edge.
        crit_edges_x, crit_edges_y = crit_edges[:, 0], crit_edges[:, 1]

        # Now we store all the values we need in what we call the "finite dgm", but beware, it's not a dgm actually.
        # It's an array of size (n) x (2 D + 1) where D is the dimension of the point cloud and n the number of
        # critical edges.
        # Columns 0:D hold the endpoint x of each edge, columns D:2D hold its other endpoint y,
        # and the last column (index 2D) is the value we want to assign to this pair.
        finite_dgm = tf.concat([tf.gather(X, crit_edges_x), tf.gather(X, crit_edges_y), tf.Variable(np.array(crit_values)[:, None], dtype=tf.float64, trainable=False)], axis=1)

        def grad(dd):
            gradient = np.zeros(shape=X.numpy().shape)
            input_dimension = X.shape[1]
            dd_numpy = tf.convert_to_tensor(dd).numpy()
            # A vertex can be an endpoint of several critical edges. `gradient[idx] += ...` with
            # repeated indices keeps only one of the contributions, so use the unbuffered
            # np.add.at to accumulate all of them.
            np.add.at(gradient, crit_edges_x, dd_numpy[:, 0:input_dimension])
            np.add.at(gradient, crit_edges_y, dd_numpy[:, input_dimension:2 * input_dimension])
            # The last column of dd (target values) carries no gradient with respect to X.
            return tf.constant(gradient, dtype=tf.float64)

        return finite_dgm, grad

    return OineusTopologicalLayer_


class DiffeomorphicTopologicalLayer(tf.keras.layers.Layer):
    """Keras layer computing a topological descriptor, optionally with kernel-smoothed gradients.

    The forward pass optionally routes the input through an identity op with a
    custom (kernel-based) backward pass, optionally subsamples it, and feeds it
    to a topological layer (Gudhi ``RipsLayer`` / ``LowerStarSimplexTreeLayer``,
    or the Oineus-backed Rips layer).

    Args:
        filtration_type: ``'Rips'`` or ``'LowerStar'``.
        homology_dimensions: sequence of homology dimensions; only the first
            entry is used when ``use_oineus`` is set with ``'Rips'``.
        max_edge_length: maximum edge length of the Rips filtration (``'Rips'`` only).
        simplex_tree: simplex tree of the lower-star filtration (``'LowerStar'`` only).
        use_deformations: if true, gradients are spread with ``kernel`` in the backward pass.
        kernel: kernel matrix; required when ``use_deformations`` is true.
        subsample_size: number of rows drawn (with replacement) at each call;
            ``None`` or ``-1`` disables subsampling.
        use_oineus: use the Oineus-backed layer (``'Rips'`` only; ``'LowerStar'``
            falls back to the Gudhi layer).
        n_preserved: forwarded to the Oineus-backed layer.
        verbose: currently unused.

    Raises:
        ValueError: on an unknown ``filtration_type``, or when
            ``use_deformations`` is true but no ``kernel`` is given.
    """

    def __init__(self,
                 filtration_type: str,
                 homology_dimensions: list,
                 max_edge_length: float = 10.,             # if filtration_type == 'Rips'
                 simplex_tree: gd.SimplexTree = None,      # if filtration_type == 'LowerStar'
                 use_deformations: bool = False,
                 kernel: np.array = None,
                 subsample_size: int = None,
                 use_oineus: bool = False,
                 n_preserved: int = 1,
                 verbose: bool = False):
        super().__init__()

        self.filtration_type = filtration_type
        self.homology_dimensions = homology_dimensions
        self.use_deformations = use_deformations
        self.use_oineus = use_oineus
        self.subsample_size = subsample_size
        # Stored unconditionally: the Rips + Oineus builder below also receives it, and reading
        # an attribute that was only set in the LowerStar branch raised AttributeError.
        self.simplex_tree = simplex_tree
        self.n_preserved = n_preserved

        if self.filtration_type == 'Rips':
            self.max_edge_length = max_edge_length
            if self.use_oineus:
                self.topological_layer = _build_OineusTopologicalLayer(self.filtration_type, self.homology_dimensions[0], self.max_edge_length, self.simplex_tree, self.subsample_size, self.n_preserved)
            else:
                self.topological_layer = RipsLayer(maximum_edge_length=self.max_edge_length, homology_dimensions=self.homology_dimensions)
        elif self.filtration_type == 'LowerStar':
            if self.use_oineus:
                print('Not implemented yet---using vanilla gradient instead')
            self.topological_layer = LowerStarSimplexTreeLayer(simplextree=self.simplex_tree, homology_dimensions=self.homology_dimensions)
        else:
            # Without this, the layer would be built without `topological_layer` and only fail in call().
            raise ValueError("Unknown filtration_type: {!r} (expected 'Rips' or 'LowerStar').".format(self.filtration_type))

        if self.use_deformations:
            if kernel is None:
                raise ValueError('A kernel is required when use_deformations is True.')
            self.kernel = kernel
            self.topological_flow = _build_TopologicalFlow(kernel=self.kernel, filtration_type=self.filtration_type)

    def call(self, X):
        """Compute the topological descriptor of ``X`` (optionally deformed and subsampled).

        Raises:
            ValueError: if subsampling is requested with ``'LowerStar'``.
        """
        # Identity in the forward pass; only the backward pass differs.
        self.X_flow = self.topological_flow(X) if self.use_deformations else X

        # Both None (the constructor default) and -1 mean "use every row".
        if self.subsample_size is None or self.subsample_size == -1:
            self.X_sub_flow = self.X_flow
        else:
            if self.filtration_type == 'LowerStar':
                raise ValueError('Subsample of simplex tree not implemented yet.')
            N = self.X_flow.shape[0]
            # Draw row indices in [0, N); random.choices samples with replacement,
            # so the same row may be picked more than once.
            ech = random.choices(range(N), k=self.subsample_size)
            # Build the subsampled point cloud.
            self.X_sub_flow = tf.gather(self.X_flow, ech)

        dgm = self.topological_layer(tf.cast(self.X_sub_flow, tf.float32))

        return dgm

