import sys
from autoencoder_template import create_autoencoder_model
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Command-line interface:
#   <script> DATASET_PATH MODEL_PATH N_EPOCHS LEARNING_RATE MODEL_TYPE
# DATASET_PATH   .npy array of training samples. A sibling file named
#                '<DATASET_PATH minus its 4-char extension>_labels.npy' is also
#                loaded; in this script its values only colour the latent plots.
# MODEL_PATH     output prefix for checkpoints and figures; its second
#                '/'-separated component also encodes the KL weight (see below).
# N_EPOCHS       number of training epochs (int).
# LEARNING_RATE  learning rate (float) passed to create_autoencoder_model.
# MODEL_TYPE     model variant: names ending in '_1d' keep the dataset's stored
#                shape; names starting with 'vae' select the VAE code paths.
if len(sys.argv) < 6:
  sys.exit(f'usage: {sys.argv[0]} DATASET_PATH MODEL_PATH N_EPOCHS LEARNING_RATE MODEL_TYPE')

dataset_path = sys.argv[1]
model_path = sys.argv[2]
n_epochs = int(sys.argv[3])
learning_rate = float(sys.argv[4])
model_type = sys.argv[5]

# np.load opens and closes the file itself when given a path; the previous
# np.load(open(...)) left the handle open.
dataset = np.load(dataset_path)

# The KL weight is the text after the first '-' (up to the next '-') in the part
# before the first '_' of MODEL_PATH's second path component, e.g.
# 'models/kl-0.5_conv' -> 0.5.
# NOTE(review): this breaks for paths without '/' and for weights written with a
# hyphen (e.g. '1e-3' parses as '1e'); confirm the naming convention.
kl_weight = float(model_path.split('/')[1].split('_')[0].split('-')[1])

if not model_type.endswith('_1d'):
  # Non-1d models take single-channel images. The image size is selected from
  # the integer in the second '_'-separated field of the dataset file name:
  # values below 21 mean 128x128, otherwise 448x416.
  # NOTE(review): the meaning of that field is not visible here; confirm.
  size_key = int(dataset_path.split('/')[1].split('_')[1])
  height, width = (128, 128) if size_key < 21 else (448, 416)
  dataset = np.reshape(dataset, [-1, height, width, 1])

# The literal 2 is presumably the latent dimensionality; the scatter plots below
# read latent columns 0 and 1. Confirm against autoencoder_template.
model = create_autoencoder_model(dataset.shape[1:], 2, learning_rate, kl_weight, model_type)

# dataset_path[:-4] strips a 4-character extension (assumed to be '.npy').
labels = np.load(dataset_path[:-4] + '_labels.npy')

# Scatter plot of the latent space before training, coloured by label.
# It is saved to MODEL_PATH + '---initial_latent'; matplotlib appends its default
# image extension because none is given.
latent_space = model.encoder(dataset)
# VAE encoders return several outputs and the first one is plotted (presumably
# the latent mean; confirm against autoencoder_template). Other encoders return
# the latent codes directly.
latent_points = latent_space[0] if model_type.startswith('vae') else latent_space
plt.figure()
plt.scatter(latent_points[:, 0], latent_points[:, 1], c=labels)
# Bug fix: the non-VAE branch used plt.show(path), which writes no file. It now
# saves like the VAE branch and the post-training plot do.
plt.savefig(model_path + '---initial_latent')

# Train for n_epochs with batch size 128 and shuffling. A ModelCheckpoint
# callback saves the weights only (verbose=1); the returned History object is
# kept in `history` for the loss plot.
if model_type.startswith('vae'):
  # The VAE is fit on inputs only (no y). Presumably its custom training step
  # derives the reconstruction target from x; confirm in autoencoder_template.
  cp_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=model_path + '.weights.h5',
      save_weights_only=True,
      verbose=1,
  )
  history = model.fit(
      x=dataset,
      batch_size=128,
      epochs=n_epochs,
      shuffle=True,
      callbacks=[cp_callback],
  )
else:
  # Plain autoencoder: each input is its own reconstruction target.
  # NOTE(review): validation_data is the training set itself, so the extra
  # validation pass each epoch does not measure generalisation.
  # NOTE(review): Keras 3's ModelCheckpoint requires save_weights_only paths to
  # end in '.weights.h5' (as the VAE branch does); '.ckpt' works only with
  # legacy tf.keras. Confirm the installed version.
  cp_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=model_path + '.ckpt',
      save_weights_only=True,
      verbose=1,
  )
  history = model.fit(
      x=dataset,
      y=dataset,
      batch_size=128,
      epochs=n_epochs,
      shuffle=True,
      validation_data=(dataset, dataset),
      callbacks=[cp_callback],
  )

# Scatter plot of the latent space after training, coloured by label, saved to
# MODEL_PATH + '---optimized_latent'.
latent_space = model.encoder(dataset)
# For VAE variants only the encoder's first output is plotted.
is_vae_model = model_type.startswith('vae')
points = latent_space[0] if is_vae_model else latent_space
plt.figure()
plt.scatter(points[:, 0], points[:, 1], c=labels)
plt.savefig(model_path + '---optimized_latent')

# Plot the per-epoch training losses recorded in history.history and save the
# figure to MODEL_PATH + '---train_loss'. VAE variants log three loss terms;
# other models log a single 'loss'.
if model_type.startswith('vae'):
  loss_curves = [
      ('total_loss', 'Training Total Loss'),
      ('kl_loss', 'Training KL Loss'),
      ('reconstruction_loss', 'Training Reconstruction Loss'),
  ]
else:
  loss_curves = [('loss', 'Training Loss')]

plt.figure()
for history_key, curve_label in loss_curves:
  plt.plot(history.history[history_key], label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig(model_path + '---train_loss')
