import sys
import numpy as np
import pickle as pck
import matplotlib.pyplot as plt
from time import time
import tensorflow as tf
import gudhi as gd
import gudhi.wasserstein as gwass
from tqdm import tqdm

from sklearn.metrics import pairwise_distances

sys.path.append('../utils/')
from DiffeomorphicTopologicalLayer import DiffeomorphicTopologicalLayer, get_deformation

from autoencoder_template import create_autoencoder_model

# Positional command-line arguments.
#   1: dataset_path   2: model_path   3: model_type   4: correct_model (0/1)
# When correct_model is set (optimization mode), the remaining arguments are
#   5: n_epochs  6: learning_rate  7: homology_dimensions ('-'-separated string)
#   8: max_edge_length  9: subsample_size  10: use_deformations (0/1)
#   11: kernel_bandwidth  12: use_oineus (0/1)  13: n_preserved
# Otherwise (replay mode) only two more are read:
#   5: kernel_bandwidth  6: learning_rate
_argv = sys.argv
dataset_path, model_path, model_type = _argv[1], _argv[2], _argv[3]
correct_model = int(_argv[4]) != 0

if not correct_model:
    kernel_bandwidth = float(_argv[5])
    learning_rate = float(_argv[6])
else:
    n_epochs = int(_argv[5])
    learning_rate = float(_argv[6])
    # Kept as a raw string here; it is split into integers later in the script.
    homology_dimensions = _argv[7]
    max_edge_length = float(_argv[8])
    subsample_size = int(_argv[9])
    use_deformations = int(_argv[10]) != 0
    kernel_bandwidth = float(_argv[11])
    use_oineus = int(_argv[12]) != 0
    n_preserved = int(_argv[13])


# Build the initial latent point cloud `latent_dataset_init`.
# If model_path starts with 'precomputed-', dataset_path already holds latent
# coordinates; otherwise the dataset is pushed through a trained encoder.
if model_path.split('-')[0] != 'precomputed':
    with open(dataset_path, 'rb') as dataset_file:
        dataset = np.load(dataset_file) # array of size n_pts x dim
    # Assumption: a model with the same architecture (i.e. also created with
    # create_autoencoder_model()) has already been trained, and its weights
    # are stored under model_path.
    if model_type.endswith('_2d'):
        # The image size is chosen from an integer embedded in dataset_path
        # (second '_'-separated token of the second '/'-separated component).
        # NOTE(review): presumably a dataset id; confirm the naming scheme.
        if int(dataset_path.split('/')[1].split('_')[1]) < 21:
            dataset = np.reshape(dataset, [-1, 128, 128, 1])
        else:
            dataset = np.reshape(dataset, [-1, 448, 416, 1])
    model = create_autoencoder_model(dataset.shape[1:], 2, 0., 0., model_type)
    if model_type.startswith('vae'):
        model.load_weights(model_path + '.weights.h5')
        # Generate latent space (only the first encoder output is used for VAEs)
        latent_dataset_init = tf.cast(model.encoder(dataset)[0], dtype=tf.float64)
    else:
        model.load_weights(model_path + '.ckpt')
        # Generate latent space
        latent_dataset_init = tf.cast(model.encoder(dataset), dtype=tf.float64)

else:
    with open(dataset_path, 'rb') as latent_file:
        latent_dataset_init = np.load(latent_file)

# Per-point labels stored next to the dataset as '<dataset stem>_labels.npy'.
with open(dataset_path[:-4] + '_labels.npy', 'rb') as labels_file:
    angles = np.load(labels_file)

# Location of the pickle holding the recorded gradients and latent snapshots.
# For 'precomputed-<prefix>' model paths the part after the first '-' is used.
if model_path.split('-')[0] == 'precomputed':
    diffeo_path = model_path.split('-')[1] + '---diffeos.pkl'
else:
    diffeo_path = model_path + '---diffeos.pkl'

# Optimize latent space...
if correct_model:

    homology_dimensions = [int(h) for h in homology_dimensions.split('-')]

    start = time()

    X = tf.Variable(initial_value=latent_dataset_init, trainable=True, dtype=tf.float64)
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

    # NOTE(review): `D` was used by the oineus loss below but never defined,
    # which raised NameError. It is assumed here to be the dimension of the
    # latent points (number of columns of X) -- confirm against the output
    # layout of DiffeomorphicTopologicalLayer when use_oineus is set.
    D = int(X.shape[1])

    # Histories of accepted steps (state *before* each applied update).
    list_X, list_loss, list_out, list_grads = [], [], [], []
    for epoch in tqdm(range(n_epochs)):

        X_np = X.numpy()
        if use_deformations:
            # Gaussian kernel between all current latent points.
            kernel = (1/(2*np.pi*kernel_bandwidth)) * np.exp(-pairwise_distances(X_np)**2/(2*kernel_bandwidth**2))
        else:
            kernel = np.zeros([len(X_np), len(X_np)])

        with tf.GradientTape(persistent=True) as tape:

            DTL = DiffeomorphicTopologicalLayer(filtration_type='Rips', homology_dimensions=homology_dimensions, max_edge_length=max_edge_length,
                                                use_deformations=use_deformations, kernel=kernel, subsample_size=subsample_size,
                                                use_oineus=use_oineus, n_preserved=n_preserved, verbose=False)
            out = DTL(X)

            if use_oineus:
                loss = tf.math.reduce_sum(tf.norm(out[:,0:D] - out[:,D:2*D], axis=1) - out[:,2*D])
            else:
                # Keep only the most persistent finite point of the first
                # diagram and maximize its (half) persistence squared.
                finite_dgm = out[1][0]
                pers_perm = np.argsort((finite_dgm[:, 1] - finite_dgm[:, 0]).numpy()).ravel()[::-1]
                finite_dgm = tf.gather(finite_dgm, pers_perm[:1])
                loss = -tf.math.reduce_sum(tf.square(0.5 * (finite_dgm[:, 1] - finite_dgm[:, 0])))

        # The tape is persistent because two gradients are taken from it.
        vanilla_gradients = tape.gradient(loss, [DTL.X_flow])
        gradients = tape.gradient(loss, [X])
        # Release the resources held by the persistent tape.
        del tape

        loss_value = loss.numpy()
        prev_loss = loss_value if epoch == 0 else list_loss[-1]
        if (np.abs(loss_value - prev_loss) <= 3.):
            list_out.append(out)
            list_loss.append(loss_value)
            list_grads.append(tf.convert_to_tensor(vanilla_gradients[0]).numpy())
            list_X.append(X_np)
            optimizer.apply_gradients(zip(gradients, [X]))
        else:
            # The loss jumped by more than 3 (or is NaN): drop the last
            # accepted step, which produced the jump, and stop optimizing.
            for history in (list_out, list_X, list_loss, list_grads):
                history.pop(-1)
            break

    end = time()

    with open(diffeo_path, 'wb') as diffeo_file:
        pck.dump({'diffeos': list_grads, 'latents': list_X}, diffeo_file)

    # The history is empty when n_epochs == 0 or when the only accepted step
    # was discarded above; fall back to the unmodified initial latents.
    X_plot = list_X[-1] if list_X else np.array(latent_dataset_init)

    plt.figure()
    plt.plot(list_loss)
    plt.savefig(model_path + '---topo_loss')

# ...Or load and apply diffeo
else:

    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by the optimization branch of this script.
    with open(diffeo_path, 'rb') as diffeo_file:
        diffeo = pck.load(diffeo_file)
    list_grads, list_X = diffeo['diffeos'], diffeo['latents']

    # Overlay the first and last recorded latent snapshots.
    plt.figure()
    if list_X:
        plt.scatter(list_X[0][:,0], list_X[0][:,1])
        plt.scatter(list_X[-1][:,0], list_X[-1][:,1])
    plt.savefig(model_path + '---topofix_target')

    X_plot = np.copy(latent_dataset_init)

    scalar = False
    # Replay every recorded step on the current points.
    for idx, G in enumerate(tqdm(list_grads)):

        # kernelA: recorded points vs. themselves; kernelB: current points vs. recorded points.
        kernelA = (1/(2*np.pi*kernel_bandwidth)) * np.exp(-pairwise_distances(list_X[idx])**2/(2*kernel_bandwidth**2))
        kernelB = (1/(2*np.pi*kernel_bandwidth)) * np.exp(-pairwise_distances(X_plot, list_X[idx])**2/(2*kernel_bandwidth**2))

        # Rows of G whose norm is non-negligible.
        idxs = np.argwhere(np.linalg.norm(G, axis=1) > 1e-4).ravel()
        if len(idxs) > 0:
            grads = get_deformation(scalar, G, idxs, kernelA, kernelB)
        else:
            grads = G

        X_plot = X_plot - learning_rate * grads

# Common stem of the two output files (array and figure).
output_stem = dataset_path[:-4] + '---' + model_path.split('/')[1] + '---after_topofix'

# Persist the final latent coordinates.
with open(output_stem + '.npy', 'wb') as output_file:
    np.save(output_file, X_plot)


def _cosines_to_first_point(points):
    """Return, per row, the cosine between the centered row and the centered first row."""
    centered = points - np.mean(points, axis=0)
    directions = centered / np.linalg.norm(centered, axis=1)[:, None]
    return np.sum(np.multiply(directions, directions[0:1, :]), axis=1)


units = _cosines_to_first_point(np.asarray(X_plot))
units_init = _cosines_to_first_point(np.asarray(latent_dataset_init))

# Correlation between the labels and the angle to the first point, after and
# before the correction. The cosines are clipped slightly inside [-1, 1]
# before arccos, as in the original computation.
corr_circle = np.corrcoef(angles, np.arccos(np.clip(units, -0.99999, 0.99999)))
corr_circle_init = np.corrcoef(angles, np.arccos(np.clip(units_init, -0.99999, 0.99999)))

# Marker sizes grow with the angle to the first point. The cosines are clipped
# to [-1, 1] so rounding error (e.g. the first point against itself) cannot
# make arccos return NaN.
sizes = 10 + 30 * np.arccos(np.clip(units, -1., 1.))
sizes_init = 10 + 30 * np.arccos(np.clip(units_init, -1., 1.))

latent_init = np.asarray(latent_dataset_init)
plt.figure()
plt.scatter(latent_init[:,0], latent_init[:,1], c=angles, s=sizes_init, marker="o")
plt.scatter(X_plot[:,0], X_plot[:,1], c=angles, s=sizes, marker="D")
plt.axis('equal')
plt.title(str(corr_circle[0,1]) + ' vs. ' + str(corr_circle_init[0,1]))
plt.savefig(output_stem)
