import tensorflow as tf
import numpy as np
import random


'''
Convention: loss functions should take as input a single diagram and a point cloud X, even if one is not used. 
It makes things more easily interchangeable. 
'''


def death_killer(dgm, X=None):
    """
    Sum of the squared death times of a diagram.

    Akin to total persistence, but only the death coordinate (column 1 of
    `dgm`) is penalized; the birth coordinate (column 0) is ignored. The
    original author's stated intent: more collapse, less spread in cycles.

    X is accepted only to match the common loss signature and is unused.
    """
    deaths = dgm[:, 1]
    return tf.reduce_sum(tf.square(deaths))


def loss_oineus(dgm, X=None):
    """
    Sum of squared (death - birth) where the birth is treated as a constant.

    The birth column (column 0 of `dgm`) contributes to the value of the
    loss, but no gradient flows through it: only the death column (column 1)
    is differentiated.

    X is accepted only to match the common loss signature and is unused.

    Returns a scalar tensor.
    """
    # stop_gradient freezes the births without allocating a fresh tf.Variable
    # on every call (which also breaks under tf.function) and without the
    # GradientTape that was opened here but never queried.
    births = tf.stop_gradient(dgm[:, 0])
    return tf.math.reduce_sum(tf.square(dgm[:, 1] - births))


def oineus_helper_loss(out_oineus, X):
    """
    Compute the equivalent of the death-killer loss for oineus.

    Recall that for oineus, a target value is assigned to pairs of points of
    X (dimension D). This is encoded by `out_oineus`, an array with 2D + 1
    columns: each row holds a pair (x_1, x_2) of points (columns [0, D) and
    [D, 2D)) followed by the target value (column 2D).

    The loss is the sum over rows of | ||x_1 - x_2||^2 - target^2 |.

    X is only used to read the dimension D from its shape; it must be 2-D.

    Returns a scalar tensor.
    """
    _, D = X.shape  # unpacking also enforces that X is a 2-D point cloud
    first = out_oineus[:, 0:D]
    second = out_oineus[:, D:2 * D]
    target = out_oineus[:, 2 * D]
    # Squared Euclidean distance = sum of squared coordinate differences.
    # Taking tf.norm of the *squared* difference would instead give
    # sqrt(sum d_i^4), which only matches the squared distance when D == 1.
    squared_dist = tf.math.reduce_sum(tf.square(first - second), axis=1)
    return tf.math.reduce_sum(tf.abs(squared_dist - tf.square(target)))


def total_persistence(dgm, X=None):
    """
    Penalize the total persistence of a diagram.

    Computes the sum over diagram points of (0.5 * (death - birth))^2, where
    birth is column 0 and death is column 1 of `dgm`.

    X is unused; it defaults to None like in the other losses that ignore
    the point cloud, so the function can be called with the diagram alone.

    Returns a scalar tensor.
    """
    half_persistence = 0.5 * (dgm[:, 1] - dgm[:, 0])
    return tf.math.reduce_sum(tf.square(half_persistence))


def collapser(dgm, X=None):
    """
    Penalize both birth and death time.

    Computes the sum over diagram points of (death + birth)^2, where birth is
    column 0 and death is column 1 of `dgm`. According to the original
    author, things should collapse even faster than with death killer.

    X is unused; it defaults to None like in the other losses that ignore
    the point cloud, so the function can be called with the diagram alone.

    Returns a scalar tensor.
    """
    birth_plus_death = dgm[:, 1] + dgm[:, 0]
    return tf.math.reduce_sum(tf.square(birth_plus_death))


def maxpers(dgm, X):
    """
    Maximize persistence while keeping the point cloud in a bounded box.

    Taken from the optimization paper by Carriere et al. Both the diagram
    and X are used:
      - the persistence term is minus the sum of (0.5 * (death - birth))^2,
        so minimizing the loss increases persistence;
      - the regularization term sums max(|x| - 1, 0) over every entry of X,
        i.e. it is zero as long as all coordinates lie in [-1, 1].

    Returns a scalar tensor.
    """
    half_lifetimes = .5 * (dgm[:, 1] - dgm[:, 0])
    # Original author's note: works better without the birth term.
    negated_persistence = -tf.reduce_sum(tf.square(half_lifetimes))
    # Only the part of each coordinate exceeding 1 in absolute value counts.
    overshoot = tf.maximum(tf.abs(X) - 1., 0)
    return negated_persistence + tf.reduce_sum(overshoot)


def validation_loss(X, layer, loss_function, subsample_size, n_repeat=None):
    """
    Average the loss over several random subsamples of X.

    This mitigates the variance due to subsampling when minimizing the loss.
    It is not an evaluation of the true loss, but a better estimate of the
    expected loss of a subsample.

    We use the "batch" approach of SGD (even though this is a bit unjustified
    I think). That is, if X has N points and we subsample n points, we do
    n_repeat = N // n repetitions by default.

    Parameters:
        X: point cloud; X.shape[0] is taken as the number of points N.
        layer: callable mapping a subsampled point cloud to a diagram.
        loss_function: callable (dgm, X) -> scalar tensor exposing .numpy().
        subsample_size: number of indices drawn for each subsample.
        n_repeat: number of subsamples to average over. When None, it is
            N // subsample_size, but at least 1 whenever X is non-empty.

    Returns:
        The mean of the per-subsample losses (nan if no repetition was run).
    """
    N = X.shape[0]
    if n_repeat is None:
        n_repeat = N // subsample_size
        if n_repeat == 0 and N > 0:
            # N < subsample_size: still evaluate once rather than averaging
            # an empty list, which would silently return nan.
            n_repeat = 1
    res = []
    for _ in range(n_repeat):
        # Indices are drawn with replacement, so a point may appear twice.
        ech = random.choices(range(N), k=subsample_size)
        # build a subsample point cloud Xech
        Xech = tf.gather(X, ech)
        dgm = layer(Xech)
        loss = loss_function(dgm, Xech)
        res.append(float(loss.numpy()))
    return np.mean(res)


def bunny_loss(dgm, X):
    """
    Same objective as `maxpers`.

    That is, minus the sum of (0.5 * (death - birth))^2 over the diagram,
    plus the sum of max(|x| - 1, 0) over every entry of X. Both the diagram
    and X are used.

    Returns a scalar tensor.
    """
    # Delegate instead of duplicating the formula so that the two losses
    # cannot silently drift apart.
    return maxpers(dgm, X)