import sys
import numpy as np
import pickle as pck
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm


from sklearn.metrics import pairwise_distances

sys.path.append('../utils/')
from DiffeomorphicTopologicalLayer import DiffeomorphicTopologicalLayer, get_deformation

from autoencoder_template import create_autoencoder_model

def parse_latent_points(spec):
  """Parse a command-line latent-point string into an (n_points, n_coords) float array.

  The first and last characters of ``spec`` are dropped (e.g. enclosing
  brackets). Points are separated by ``__`` and the coordinates of a single
  point by ``_``, so ``'[1_2__-3_4.5]'`` -> ``[[1., 2.], [-3., 4.5]]``.
  """
  points = spec[1:-1].split('__')
  return np.array([[float(coord) for coord in point.split('_')] for point in points])


def dataset_image_shape(dataset_path):
  """Return the (height, width) that images of the dataset at ``dataset_path`` are reshaped to.

  The dataset index is the second ``_``-separated field of the second
  ``/``-separated path component (``'data/set_12/x.npy'`` -> 12). Indices
  below 21 use 128x128 images; all others use 448x416.
  """
  dataset_index = int(dataset_path.split('/')[1].split('_')[1])
  return (128, 128) if dataset_index < 21 else (448, 416)


def load_dataset_and_model(dataset_path, model_path, model_type):
  """Load the image dataset and the trained autoencoder weights.

  Only 2-D image models (``model_type`` ending in ``'_2d'``) are supported.
  VAE models (``model_type`` starting with ``'vae'``) load weights from
  ``<model_path>.weights.h5``; all other models load from ``<model_path>.ckpt``.

  Returns:
    ``(dataset, model, latent_dataset)``. ``dataset`` has shape
    (n_images, height, width, 1). ``latent_dataset`` is the encoder output
    cast to float64.

  Raises:
    ValueError: if ``model_type`` does not end in ``'_2d'``. No model can be
      built for such types, so nothing could be decoded afterwards.
  """
  if not model_type.endswith('_2d'):
    raise ValueError("unsupported model_type %r: only '*_2d' models are handled" % model_type)

  # Assumption: the weights at model_path were trained with a model of the same
  # architecture, i.e. also created by create_autoencoder_model() with these arguments.
  # np.load on a path opens and closes the file itself.
  dataset = np.load(dataset_path)  # array of size n_pts x dim
  dataset = np.reshape(dataset, [-1, *dataset_image_shape(dataset_path), 1])
  model = create_autoencoder_model(dataset.shape[1:], 2, 0., 0., model_type)
  if model_type.startswith('vae'):
    model.load_weights(model_path + '.weights.h5')
    # The first element of the VAE encoder output is used as the latent code
    # (presumably the mean -- confirm against autoencoder_template).
    latent_dataset = tf.cast(model.encoder(dataset)[0], dtype=tf.float64)
  else:
    model.load_weights(model_path + '.ckpt')
    latent_dataset = tf.cast(model.encoder(dataset), dtype=tf.float64)
  return dataset, model, latent_dataset


def show_decoded_images(new_data, image_shape):
  """Show each decoded sample in ``new_data`` as a grayscale image of ``image_shape``, one figure each."""
  for sample in np.asarray(new_data):
    plt.figure()
    plt.imshow(np.reshape(sample, image_shape), cmap='gist_gray')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def gaussian_kernel(X, Y, bandwidth):
  """Gaussian kernel matrix between the rows of ``X`` and ``Y`` (``Y=None`` means X vs X).

  NOTE(review): the prefactor is 1/(2*pi*h), not the 2-D normal's
  1/(2*pi*h**2). It is kept as-is because it presumably has to match the
  kernel used when the diffeomorphisms were fitted -- confirm.
  """
  sq_dist = pairwise_distances(X, Y) ** 2
  return (1 / (2 * np.pi * bandwidth)) * np.exp(-sq_dist / (2 * bandwidth ** 2))


def apply_diffeomorphism(latent_try, list_grads, list_X, kernel_bandwidth, learning_rate):
  """Transport ``latent_try`` through the stored deformation steps (newest first) and plot the flow.

  Step ``idx`` uses deformation ``list_grads[-1-idx]``, defined on latent
  snapshot ``list_X[-1-idx]``. Every point moves by
  ``learning_rate * get_deformation(...)``. Returns the transported points.
  """
  scalar = False  # forwarded unchanged to get_deformation
  plt.figure(figsize=(50, 50))
  plt.scatter(list_X[-1][:, 0], list_X[-1][:, 1], label='Optimized LS')

  trajs = []
  for idx in tqdm(range(len(list_grads))):
    X = list_X[-1 - idx]
    G = list_grads[-1 - idx]
    kernelA = gaussian_kernel(X, None, kernel_bandwidth)
    kernelB = gaussian_kernel(latent_try, X, kernel_bandwidth)

    # Indices whose deformation vector has norm > 1e-4; if there are none,
    # this step leaves the points where they are.
    idxs = np.argwhere(np.linalg.norm(G, axis=1) > 1e-4).ravel()
    if len(idxs) > 0:
      grads = get_deformation(scalar, G, idxs, kernelA, kernelB)
    else:
      grads = np.zeros(shape=latent_try.shape)

    trajs.append(latent_try)
    step = learning_rate * grads
    for p in range(len(latent_try)):
      plt.arrow(latent_try[p, 0], latent_try[p, 1], step[p, 0], step[p, 1])
    latent_try = latent_try + step

  # np.vstack([]) raises, so only plot trajectories when at least one step ran.
  if trajs:
    trajs = np.vstack(trajs)
    plt.scatter(trajs[:, 0], trajs[:, 1], c='green', s=5, label='Trajectories')
  plt.scatter(list_X[0][:, 0], list_X[0][:, 1], label='Initial LS')
  plt.grid()
  plt.legend()
  plt.title('Generating new data by applying diffeomorphism')
  plt.show()
  return latent_try


def main():
  """Command-line entry point: decode latent points, optionally after applying the stored diffeomorphism.

  Arguments, in order: LATENT_PTS DATASET_PATH MODEL_PATH MODEL_TYPE
  CORRECT_MODEL(0|1) KERNEL_BANDWIDTH LEARNING_RATE.
  """
  if len(sys.argv) < 8:
    sys.exit('usage: %s LATENT_PTS DATASET_PATH MODEL_PATH MODEL_TYPE '
             'CORRECT_MODEL(0|1) KERNEL_BANDWIDTH LEARNING_RATE' % sys.argv[0])
  latent_data_pts = sys.argv[1]
  dataset_path = sys.argv[2]
  model_path = sys.argv[3]
  model_type = sys.argv[4]
  correct_model = bool(int(sys.argv[5]))
  kernel_bandwidth = float(sys.argv[6])
  learning_rate = float(sys.argv[7])

  latent_try = parse_latent_points(latent_data_pts)
  # NOTE(review): the encoded dataset is not used below; the encoder pass is
  # kept in case it is needed to build/restore the encoder -- confirm and drop.
  dataset, model, _latent_dataset = load_dataset_and_model(dataset_path, model_path, model_type)
  # Decode to the dataset's real image size (was hard-coded to 128x128).
  image_shape = dataset.shape[1:3]
  is_vae = model_type.startswith('vae')

  if correct_model:
    new_data = model.decoder(latent_try)
    if is_vae:
      show_decoded_images(new_data, image_shape)

    print('applying diffeo...')
    # SECURITY: pickle can execute arbitrary code on load; only use diffeo files you produced yourself.
    with open(model_path + '---diffeos.pkl', 'rb') as f:
      diffeo = pck.load(f)
    latent_try = apply_diffeomorphism(latent_try, diffeo['diffeos'], diffeo['latents'],
                                      kernel_bandwidth, learning_rate)

  new_data = model.decoder(latent_try)
  if is_vae:
    show_decoded_images(new_data, image_shape)


if __name__ == '__main__':
  main()
