# import json
# import pickle
# import math
# from time import time
import tensorflow as tf
import numpy as np
import random
import os

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, SimpleRNN, LSTM, concatenate, Dense, Reshape, InputLayer, Activation, Dropout, BatchNormalization, Conv2D, MaxPool2D, Flatten
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, Callback
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorboard import summary
from tensorflow.keras import mixed_precision

# import matplotlib.pyplot as plt
import csv

# TODO: Audit above imports!!!!

import sys
from collections import namedtuple
from scipy.stats import kendalltau

from embeddings_serialize import Image_Caption_Embedding, deserialize, serialize
from msgit_embeddings_serialize import MSGit_Embedding
from msgit_embeddings_serialize import deserialize as ge_deserialize
from vicr_parse import parse_model_params, parse_input_formula

# Immutable record of every command line setting; parse_args() builds one per run.
Options = namedtuple(
	'Options',
	[
		# Positional paths
		'model_path',
		'test_path',
		# Optional dataset paths and validation cadence
		'train_path',
		'val_path',
		'val_freq',
		# Training hyper-parameters
		'epochs',
		'batch_size',
		'learning_rate',
		'decay_rate',
		'decay_epochs',
		'decay_steps',
		# Output locations
		'output_path',
		'output_embedding_path',
		# Model and input description
		'input_formula',
		'new_model_params',
		'embedding_type',
		# Evaluation and miscellaneous switches
		'k_tau',
		'verbose',
		'experiment_name',
		'individual_ratings',
		'single_rater',
		'seed',
		'categorical',
		# False when the command line could not be parsed
		'is_valid',
	],
)

# Embedding formats accepted by --embedding-type; the help text advertises the first entry as the default.
embedding_types = ('vilbert', 'msgit', 'clip')

# Root directory under which train_model() places one log folder per experiment.
LOG_DIR = './logs'

def print_short_help():
	"""Print the usage line followed by a pointer to the full help (-h)."""
	script = sys.argv[0]
	usage_lines = (
		f"Usage: python {script} <model_path> <test_path> [Options]",
		"",
		f"For further information use python {script} -h",
	)
	for line in usage_lines:
		print(line)

def print_help():
	"""Print the full command line reference for this script to standard output.

	The list of valid embedding types and the advertised default are taken
	from the module-level ``embedding_types`` tuple.
	"""
	print("Evaluates the quality of image-caption pairs from extracted embeddings, or")
	print("trains a model to do so.")
	print("")
	print(f"Usage: python {sys.argv[0]} <model_path> <test_path> [Options]")
	print("")
	print("The model path specifies where the model will be saved to. It is also where")
	print("the model will be loaded from, unless the --new flag is present.")
	print("The test path is the path to the serialized embedding data on which to")
	print("evaluate the model. This should consist of a list of Image_Caption_Embedding")
	print("objects, and it can be obtained from a list of image-caption pairs by using the")
	print("vilbert-embeddings.py script.")
	print("")
	print("The output is a csv file with columns labelled 'image', 'caption', 'rating' with")
	print("one row for each embedding in the test path.")
	print("")
	print("Options:")
	print("")
	print("-t, --train <path>")
	print("  Specifies the path of the pickle file containing the embeddings to be used")
	print("  for training the model. If this option is not present, the model will only")
	print("  evaluate the embeddings in the test path")
	print("  If path is ':auto', the test path will be used with '-test' replaced with")
	print("  '-train'")
	print("-v, --validation <path>")
	print("  Specifies the path of the pickle file containing the embeddings to be used")
	print("  for validating the model during training. Only applicable if the --train")
	print("  option is present. If this option is not present, a subset of the training")
	print("  data will be used for validation")
	print("  If path is ':auto', the test path will be used with '-test' replaced with")
	print("  '-val'")
	print("-vf, --validation-frequency <number>")
	print("  Specifies the number of epochs between validation steps. Defaults to 1")
	print("-e, --epochs <number>")
	print("  Specifies the number of epochs to train for. Only applicable if the --train")
	print("  option is present. Defaults to 100")
	print("-b, --batch-size <number>")
	print("  Specifies the batch size to use during training. Only applicable if the")
	print("  --train option is present. Defaults to 200")
	print("-l, --learning-rate <number>")
	print("  Specifies the learning rate to use during training. Only applicable if the")
	print("  --train option is present. Defaults to 0.001")
	print("-d, --decay-rate <number>")
	print("  Specifies the decay rate to use during training. Only applicable if the")
	print("  --train option is present. Defaults to 0.00001")
	print("-de, --decay-epochs <number>")
	print("  Specifies how many epochs between learning rate decay during training.")
	print("  Only applicable if the --train option is present. Conflicts with -ds")
	print("  option. Defaults to 20")
	print("-ds, --decay-steps <number>")
	print("  Specifies how often to decay during training. Only applicable if the")
	print("  --train option is present. Conflicts with the -de option.")
	print("-o, --output-path <path>")
	print("  Specifies the path to the output csv file. If not provided, the default is")
	print("  output/<test_filestem>-<model_name>.csv,")
	print("  where model_name is the base name of the model given in model path without the")
	print("  extension")
	print("-n, --new <parameters>")
	print("  Prevents the model from being loaded from the model path and instead sets up a")
	print("  new model according to the supplied parameters. The parameters should be of")
	print("  the form type(shape)[activation] separated by -> for each layer. For instance,")
	print("  the following makes a 3 layer dense network:")
	print("    d(512)[relu]->drop(0.8)->d(256)[relu]->drop(0.8)->d(1)")
	print("  Here is an example for a convolutional network:")
	print("    c2d(64,3,3)[relu]->maxp2d(2,2)->c2d(32,3,3)[relu]->flat->d(128)[relu]->d(1)")
	print("  Optionally, the input formula may be specified as the first layer (see -if)")
	print("-et, --embedding-type <type>")
	print("  Specifies what type of embeddings to load")
	print(f"  Valid embedding types are {embedding_types}")
	# The original text ended the quoted default with two apostrophes.
	print(f"  Defaults to '{embedding_types[0]}' if not given")
	print("-if, --input-formula <formula>")
	print("  Specifies the recipe for taking the image and caption embeddings and producing")
	print("  the input tensor of the correct shape for the model. The formula consists of a")
	print("  sequence of image and/or caption embeddings with optional shape parameters")
	print("  separated by operators. Elementary operands are img or cap optionally followed")
	print("  by a tuple specifying the shape, such as img(32,16,2). Operators are +, -, *,")
	print("  /, and . with . meaning concatenate. The concatenate operator allows an")
	print("  optional parameter specifying the axis along which to concatenate, which is")
	print("  written inside {}, so img(32,32,1).{2}cap(32,32,1) concatenates along the")
	print("  third axis, creating a tensor of shape (32,32,2). . is the same as .{0}.")
	print("  The operators are evaluated from left to right with no order of operations,")
	print("  but parenthesized expressions are allowed and evaluated first. Here is an")
	print("  example of a complex input formula:")
	print("    (img(32,16,2).{2}cap(32,16,2))+(img(16,16,4).cap(16,16,4))")
	print("-kt, --kendall-tau <path>")
	print("  Prints the Kendall-tau score on the dataset given in path, which must specify")
	print("  a .emb file. You may use :test, :train, or :val to refer to the corresponding")
	print("  dataset. You may provide this command line option multiple times to evaluate")
	print("  Kendall-tau on several datasets")
	print("-vb, --verbose")
	print("  Prints information to the console as the model is being trained or evaluated")
	print("-ex, --experiment-name <name>")
	print("  Sets the experiment name (default: stem of model_path)")
	print("-ir, --individual-ratings")
	print("  Uses individual ratings, rather than average ratings to train and/or evaluate")
	print("  the model")
	print("-sr, --single-rater <index>")
	print("  Uses only a single rating per image-caption pair, specified by a 0-based index.")
	print("  If an image-caption pair has fewer ratings than index, this uses the last")
	print("  rating for the pair instead")
	print("-s, --seed <int>")
	print("  Sets the seed for the random number generator. Uses keras set_random_seed().")
	print("-c, --categorical")
	print("  Makes the model use SparseCategoricalCrossentropy as a loss")
	print("-h, --help")
	print("  Flag that, if present, prints this message")

def _is_negative(number):
	"""Return True when number is below zero."""
	return number < 0


def _is_not_positive(number):
	"""Return True when number is zero or below."""
	return number <= 0


# Options whose value is stored verbatim: (accepted flags, Options field).
_TEXT_OPTIONS = (
	(('-t', '--train'), 'train_path'),
	(('-v', '--validation'), 'val_path'),
	(('-o', '--output-path'), 'output_path'),
	(('-ex', '--experiment-name'), 'experiment_name'),
)

# Options whose value is a number:
# (accepted flags, Options field, converter, wording used in the "was not ..."
# message, test that rejects a converted value or None, message printed on rejection).
_NUMERIC_OPTIONS = (
	(('-vf', '--validation-frequency'), 'val_freq', int, 'an integer', _is_negative, "Validation frequency cannot be negative"),
	(('-e', '--epochs'), 'epochs', int, 'an integer', _is_negative, "Number of epochs cannot be negative"),
	# Zero is rejected as well: main() divides by the batch size to derive the decay steps.
	(('-b', '--batch-size'), 'batch_size', int, 'an integer', _is_not_positive, "Batch size cannot be negative or zero"),
	(('-l', '--learning-rate'), 'learning_rate', float, 'a float', _is_negative, "Learning rate cannot be negative"),
	(('-d', '--decay-rate'), 'decay_rate', float, 'a float', _is_negative, "Decay rate cannot be negative"),
	(('-de', '--decay-epochs'), 'decay_epochs', int, 'an int', _is_not_positive, "Decay epochs cannot be negative or zero"),
	(('-ds', '--decay-steps'), 'decay_steps', int, 'an int', _is_not_positive, "Decay steps cannot be negative or zero"),
	(('-sr', '--single-rater'), 'single_rater', int, 'an int', _is_negative, "Index for single rater cannot be negative"),
	(('-s', '--seed'), 'seed', int, 'an int', None, ''),
)

# Flags that take no value and switch a boolean field on: (accepted flags, Options field).
_SWITCH_OPTIONS = (
	(('-vb', '--verbose'), 'verbose'),
	(('-ir', '--individual-ratings'), 'individual_ratings'),
	(('-c', '--categorical'), 'categorical'),
)

# Value-taking flags that need bespoke handling in _store_option_value().
_MODEL_FLAGS = ('-n', '--new')
_FORMULA_FLAGS = ('-if', '--input-formula')
_EMBEDDING_FLAGS = ('-et', '--embedding-type')
_KENDALL_FLAGS = ('-kt', '--kendall-tau')
_SPECIAL_VALUE_FLAGS = _MODEL_FLAGS + _FORMULA_FLAGS + _EMBEDDING_FLAGS + _KENDALL_FLAGS


def _find_option(table, flag):
	"""Return the entry of table whose flag tuple contains flag, or None when there is none."""
	for entry in table:
		if flag in entry[0]:
			return entry
	return None


def _takes_value(flag):
	"""Return True when flag is an option that consumes the following command line argument."""
	return (
		_find_option(_TEXT_OPTIONS, flag) is not None
		or _find_option(_NUMERIC_OPTIONS, flag) is not None
		or flag in _SPECIAL_VALUE_FLAGS
	)


def _store_option_value(settings, flag, value):
	"""Validate value for flag and record it in settings; return a falsy result after printing why when it is unusable."""
	text = _find_option(_TEXT_OPTIONS, flag)
	if text is not None:
		settings[text[1]] = value
		return True

	numeric = _find_option(_NUMERIC_OPTIONS, flag)
	if numeric is not None:
		_, field, convert, type_name, is_rejected, rejection = numeric
		try:
			number = convert(value)
		except ValueError:
			print(f"Argument to {flag} was not {type_name}")
			return False
		if is_rejected is not None and is_rejected(number):
			print(rejection)
			return False
		settings[field] = number
		return True

	if flag in _MODEL_FLAGS:
		# parse_model_params yields (input formula or None, layer parameters, validity flag).
		new_input_formula, settings['new_model_params'], is_valid = parse_model_params(value)
		if not is_valid:
			print(f"Unable to parse model parameters '{value}'")
		elif settings['input_formula'] is None:
			settings['input_formula'] = new_input_formula
		elif new_input_formula is not None:
			print(f"Warning: Input formula specified twice. Ignoring the one specified with the new model parameters")
		return is_valid

	if flag in _FORMULA_FLAGS:
		# The middle element of the result is not needed here.
		new_input_formula, _unused, is_valid = parse_input_formula(value)
		if is_valid:
			if settings['input_formula'] is not None:
				print(f"Warning: Input formula specified twice. Replacing with the one supplied in the --input-formula argument")
			settings['input_formula'] = new_input_formula
		else:
			print(f"Unable to parse input formula '{value}'")
		return is_valid

	if flag in _EMBEDDING_FLAGS:
		if value in embedding_types:
			settings['embedding_type'] = value
			return True
		# Parsing stops on an unknown type, so no fallback is announced.
		print(f"Invalid embedding type provided: {value}")
		print(f"  Valid embedding types are {embedding_types}")
		return False

	# Only the Kendall-tau flags remain. A value that looks like another
	# option means the dataset name was left out.
	if value.startswith('-'):
		print(f"Missing argument to {flag}")
		return False
	settings['k_tau'].append(value)
	return True


def _file_stem(path):
	"""Return the final component of path without its extension."""
	return os.path.splitext(os.path.basename(path))[0]


def _sibling_of_test_path(test_path, replacement):
	"""Return test_path with '-test' in its file name replaced by replacement, keeping the directory part."""
	directory, filename = os.path.split(test_path)
	return os.path.join(directory, filename.replace('-test', replacement))


def parse_args():
	"""Parse ``sys.argv`` into an ``Options`` record.

	Flags are matched against the option tables above. The first two
	arguments that are not recognised flags become ``model_path`` and
	``test_path``; any further one is reported as unrecognised. Parsing stops
	at the first error, and the short usage text is printed whenever the
	command line is invalid or either positional path is missing.

	After parsing, the derived settings are filled in:

	* ``output_path`` defaults to ``output/<test stem>-<model stem>.csv``.
	* ``output_embedding_path`` is the output path with its extension replaced
	  by ``-<embedding type>.emb``, also when ``-o`` was given explicitly.
	* ``:auto`` train/validation paths are derived from the test path.
	* ``experiment_name`` defaults to the model file stem.
	* A ``--single-rater`` index implies ``individual_ratings``.

	Returns:
		Options: the parsed settings; ``is_valid`` is False when the caller
		should not proceed.
	"""
	settings = {
		'model_path': '',
		'test_path': '',
		'train_path': '',
		'val_path': '',
		'val_freq': 1,
		'epochs': 100,
		'batch_size': 200,
		'learning_rate': 0.001,
		'decay_rate': 0.00001,
		'decay_epochs': 20,
		'decay_steps': None,
		'output_path': '',
		'output_embedding_path': '',
		'input_formula': None,
		'new_model_params': None,
		'embedding_type': 'vilbert',
		'k_tau': [],
		'verbose': False,
		'experiment_name': None,
		'individual_ratings': False,
		'single_rater': -1,
		'seed': None,
		'categorical': False,
	}
	is_valid = True

	argv = sys.argv
	position = 1
	while position < len(argv) and is_valid:
		arg = argv[position]
		switch = _find_option(_SWITCH_OPTIONS, arg)
		if switch is not None:
			settings[switch[1]] = True
		elif arg in ('-h', '--help'):
			print_help()
		elif _takes_value(arg):
			position += 1
			if position < len(argv):
				is_valid = _store_option_value(settings, arg, argv[position])
			else:
				print(f"Missing argument to {arg}")
				is_valid = False
		elif settings['model_path'] == '':
			settings['model_path'] = arg
		elif settings['test_path'] == '':
			settings['test_path'] = arg
		else:
			print(f"Unrecognized option: {arg}")
			is_valid = False
		position += 1

	if not is_valid or settings['model_path'] == '' or settings['test_path'] == '':
		print_short_help()
		is_valid = False

	if is_valid:
		# splitext keeps the whole name when there is no extension, so a
		# model path such as 'models/run1' yields 'run1' rather than ''.
		model_stem = _file_stem(settings['model_path'])
		test_stem = _file_stem(settings['test_path'])

		if settings['output_path'] == '':
			settings['output_path'] = 'output/' + test_stem + '-' + model_stem + '.csv'
		# Derived from the final output path so that it is never left empty
		# when -o is used; for the default output path this is
		# output/<test stem>-<model stem>-<embedding type>.emb.
		settings['output_embedding_path'] = os.path.splitext(settings['output_path'])[0] + '-' + settings['embedding_type'] + '.emb'

		# os.path handles a test path without a directory part, which would
		# otherwise turn into a path under the filesystem root.
		if settings['train_path'] == ':auto':
			settings['train_path'] = _sibling_of_test_path(settings['test_path'], '-train')
		if settings['val_path'] == ':auto':
			settings['val_path'] = _sibling_of_test_path(settings['test_path'], '-val')

		if settings['experiment_name'] is None:
			settings['experiment_name'] = model_stem

		if settings['single_rater'] >= 0:
			settings['individual_ratings'] = True

	return Options(is_valid=is_valid, **settings)

def validate_options(options):
	"""Return True when the parsed options are flagged as valid, False otherwise."""
	return True if options.is_valid else False

def load_embeddings(input_path, e_type='vilbert'):
	"""Read the serialized embeddings stored at input_path.

	'vilbert' and 'clip' files are read with ``deserialize``; 'msgit' files
	are read with ``ge_deserialize``. Any other e_type is reported on standard
	output and yields None. The file is opened in binary mode before the type
	is inspected, so a missing file raises regardless of e_type.
	"""
	with open(input_path, 'rb') as stream:
		if e_type == 'msgit':
			return ge_deserialize(stream)
		if e_type in ('vilbert', 'clip'):
			return deserialize(stream)
		print(f"Unsupported embedding type: {e_type}")
		return None

def _apply_operand(operand, img_emb, cap_emb, emb):
	"""Resolve a one- or two-element operand list to the named embedding, reshaped when a shape tuple is present."""
	name = operand[0]
	shape = None
	if len(operand) == 2:
		shape = operand[1]
		if not isinstance(shape, tuple):
			print(f"Error applying input formula: Expected a tuple, got {shape}")
			return None

	sources = {'img': img_emb, 'cap': cap_emb, 'emb': emb}
	# The isinstance test keeps unhashable values (e.g. a nested list) away from the dict lookup.
	if not isinstance(name, str) or name not in sources:
		shown = operand if shape is None else name
		print(f"Error applying input formula: Expected 'img', 'cap', or 'emb', got {shown}")
		return None

	value = sources[name]
	if value is None:
		# For example 'emb' requested while only image/caption embeddings were passed in.
		print(f"Error applying input formula: No '{name}' embedding is available")
		return None
	return value if shape is None else value.reshape(shape)


def apply_input_formula(input_formula, img_emb, cap_emb, emb=None):
	"""Evaluate a parsed input formula against the given embeddings.

	The formula is a list in one of these forms, as handled below:

	* ``[name]`` - the embedding called name ('img', 'cap' or 'emb');
	* ``[name, shape]`` - that embedding reshaped to the tuple shape;
	* ``[operand, op, operand, op, operand, ...]`` with three or more
	  elements - each operand is itself a formula list and each op is a list
	  whose first element is one of '+', '-', '*', '/', '.'. Operators are
	  applied strictly left to right. '.' concatenates with
	  ``np.concatenate``; an optional second element of the op list is a dict
	  of keyword arguments forwarded to it.

	Args:
		input_formula: the parsed formula list.
		img_emb: value substituted for 'img'.
		cap_emb: value substituted for 'cap'.
		emb: value substituted for 'emb' (optional).

	Returns:
		The resulting array, or None after printing a message when the
		formula is malformed or refers to an embedding that was not supplied.
	"""
	if not isinstance(input_formula, list):
		print(f"Error applying input formula: Expected a list, got {input_formula}")
		return None

	if len(input_formula) == 0:
		print(f"Error applying input formula: Got empty list")
		return None

	if len(input_formula) <= 2:
		return _apply_operand(input_formula, img_emb, cap_emb, emb)

	result = apply_input_formula(input_formula[0], img_emb, cap_emb, emb)
	if result is None:
		return None

	pos = 1
	while pos < len(input_formula) - 1:
		op = input_formula[pos]
		if not isinstance(op, list) or len(op) == 0:
			print(f"Error applying input formula: Expected a list containing an operator, got {op}")
			return None
		rhs = apply_input_formula(input_formula[pos + 1], img_emb, cap_emb, emb)
		if rhs is None:
			return None

		symbol = op[0]
		if symbol == '+':
			result = result + rhs
		elif symbol == '-':
			result = result - rhs
		elif symbol == '*':
			result = result * rhs
		elif symbol == '/':
			result = result / rhs
		elif symbol == '.':
			if len(op) == 1:
				result = np.concatenate((result, rhs))
			elif len(op) == 2 and isinstance(op[1], dict):
				result = np.concatenate((result, rhs), **op[1])
			else:
				# Without this branch a malformed '.' silently dropped the right-hand operand.
				print(f"Error applying input formula: Invalid concatenation parameters in {op}")
				return None
		else:
			print(f"Error applying input formula: Expected an operator, got {symbol}")
			return None
		pos += 2

	# A well-formed expression alternates operand/operator and ends on an
	# operand, which leaves pos exactly one past the end.
	if pos != len(input_formula):
		print(f"Error applying input formula: Invalid number of elements in list: {len(input_formula)}")
		return None
	return result

def extract_inputs(embeddings, e_type='vilbert', input_formula=None, individual_ratings=False, single_rater=-1):
	"""Build the model input array for a sequence of embedding objects.

	For 'vilbert' and 'clip' the row of an object is its image embedding
	concatenated with its caption embedding; for 'msgit' it is the object's
	single ``embedding``. When input_formula is given, the row is produced by
	``apply_input_formula`` instead. With individual_ratings set and no single
	rater selected (single_rater < 0) each row is repeated once per entry in
	the object's ``ratings``.

	Returns a numpy array of the rows, or None (after printing a message) for
	an unsupported e_type.
	"""
	if e_type in ('vilbert', 'clip'):
		def row_for(item):
			if input_formula is None:
				return np.concatenate((item.image_embedding, item.caption_embedding))
			return apply_input_formula(input_formula, item.image_embedding, item.caption_embedding)
	elif e_type == 'msgit':
		def row_for(item):
			if input_formula is None:
				return item.embedding
			return apply_input_formula(input_formula, None, None, item.embedding)
	else:
		print(f"Unsupported embedding type: {e_type}")
		return None

	rows = [row_for(item).tolist() for item in embeddings]

	repeat_per_rating = individual_ratings and single_rater < 0
	if repeat_per_rating:
		rows = [row for row, item in zip(rows, embeddings) for _ in item.ratings]
	return np.array(rows)
	
def extract_labels(embeddings, individual_ratings=False, single_rater=-1, categorical=False):
	"""Build the label array for a sequence of embedding objects.

	Args:
		embeddings: objects exposing a ``ratings`` sequence.
		individual_ratings: when True (and no single rater is selected), emit
			one label per individual rating instead of one per object.
		single_rater: when >= 0, emit one label per object taken from that
			position in its ratings; the last rating is used when the index is
			past the end.
		categorical: when True, convert each label to a class index by
			rounding, subtracting 1 and clamping to the range 0..4.

	An object without ratings gets the label 0 both when averaging and when a
	single rater is selected, so the number of labels always matches the
	number of objects in those modes.

	Returns:
		numpy.ndarray: the labels.
	"""
	if single_rater >= 0:
		ratings_list = []
		for obj in embeddings:
			if len(obj.ratings) == 0:
				# Same fallback as the averaging branch; indexing would raise here.
				ratings_list.append(0)
			elif single_rater < len(obj.ratings):
				ratings_list.append(obj.ratings[single_rater])
			else:
				# Take the last rater if single_rater index is out of bounds
				ratings_list.append(obj.ratings[-1])
	elif individual_ratings:
		ratings_list = [rating for obj in embeddings for rating in obj.ratings]
	else:
		ratings_list = [sum(obj.ratings) / len(obj.ratings) if len(obj.ratings) > 0 else 0 for obj in embeddings]

	if categorical:
		# TODO: If our ratings ever want to fall outside of 1 to 5, we need to update this!
		ratings_list = [min(max(round(r) - 1, 0), 4) for r in ratings_list]
	return np.array(ratings_list)

def _build_layer(layer_params):
	"""Create the Keras layer described by one parsed layer tuple, or return None for an unknown layer type."""
	layer_type = layer_params[0]
	if layer_type == 'Dense':
		# A third element, when present, names the activation.
		activation = layer_params[2] if len(layer_params) == 3 else None
		return Dense(layer_params[1], activation=activation)
	if layer_type == 'Reshape':
		return Reshape(layer_params[1])
	if layer_type == 'BatchNormalization':
		return BatchNormalization()
	if layer_type == 'Conv2D':
		activation = layer_params[2] if len(layer_params) == 3 else None
		conv_args = layer_params[1]
		# conv_args[0] is the filter count; the kernel size is either the
		# single following value or, with three values, the following pair.
		kernel = tuple(conv_args[1:]) if len(conv_args) == 3 else conv_args[1]
		return Conv2D(conv_args[0], kernel, activation=activation)
	if layer_type == 'Dropout':
		return Dropout(layer_params[1])
	if layer_type == 'MaxPool2D':
		pool = layer_params[1]
		# A one-element size is unwrapped to a bare value.
		if len(pool) == 1:
			pool = pool[0]
		return MaxPool2D(pool)
	if layer_type == 'Flatten':
		return Flatten()
	if layer_type == 'Activation':
		return Activation(layer_params[1])
	return None


def build_model(model_params, input_shape, epochs, learning_rate, decay_rate, decay_steps, categorical=False, verbose=True):
	"""Assemble and compile a Sequential model from parsed layer parameters.

	Args:
		model_params: iterable of layer descriptions; element 0 of each is the
			layer type name and the remaining elements are read as shown in
			``_build_layer``.
		input_shape: shape passed to the model's InputLayer.
		epochs: accepted for interface compatibility; not used here.
		learning_rate: initial learning rate for the Adam optimizer.
		decay_rate: decay rate for ExponentialDecay; decay is only enabled
			when this and decay_steps are both positive.
		decay_steps: number of steps between decays, or None/0 to disable
			the schedule.
		categorical: compile with SparseCategoricalCrossentropy and an
			accuracy metric when True, otherwise with mean squared error and
			the mse/mae metrics.
		verbose: print the model summary when True.

	Returns:
		The compiled model, or None (after printing a message) when a layer
		type is not supported.
	"""
	model = Sequential()
	model.add(InputLayer(input_shape=input_shape))
	for layer_params in model_params:
		layer = _build_layer(layer_params)
		if layer is None:
			print(f"Error building model: Unsupported layer type '{layer_params[0]}'")
			return None
		model.add(layer)

	# decay_steps may legitimately be None (no --decay-steps and nothing to
	# derive it from); comparing None with 0 would raise a TypeError.
	use_decay = decay_rate > 0 and decay_steps is not None and decay_steps > 0
	if use_decay:
		schedule = ExponentialDecay(learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, staircase=True)
		opt = Adam(learning_rate=schedule)
	else:
		opt = Adam(learning_rate=learning_rate)

	if categorical:
		model.compile(optimizer=opt, loss=SparseCategoricalCrossentropy(), metrics=['accuracy'])
	else:
		model.compile(optimizer=opt, loss='mse', metrics=['mse', 'mae'])

	if verbose:
		# summary() prints by itself and returns None, so it is not wrapped in print().
		model.summary()
	return model

class KendallTauMetric(Callback):
	"""Keras callback that logs Kendall-tau scores for a set of datasets.

	After every epoch whose logs contain ``val_loss`` (i.e. an epoch on which
	validation ran), the model predicts on each dataset, the Kendall-tau
	correlation with the ratings is computed and written as a scalar summary
	named ``kendall_tau_<name>``.
	"""

	def __init__(self, k_tau, writer, individual_ratings=False, verbose=False):
		"""Store the evaluation configuration.

		Args:
			k_tau: iterable of (name, embeddings, inputs, labels) tuples.
			writer: summary writer whose ``as_default()`` context receives
				the scalars.
			individual_ratings: when True, predictions are compared directly
				with labels; otherwise each prediction is compared against
				every rating of its embedding via ``kendalltau_individual``.
			verbose: also print each score when True.
		"""
		super().__init__()
		self.k_tau = k_tau
		self.writer = writer
		self.individual_ratings = individual_ratings
		self.verbose = verbose

	def on_epoch_end(self, epoch, logs=None):
		"""Compute and log the Kendall-tau score of every dataset for this epoch."""
		# logs may be None or lack validation results; nothing is logged then.
		if not logs or 'val_loss' not in logs:
			return

		for name, embeddings, inputs, labels in self.k_tau:
			# NOTE(review): predictions are used exactly as returned by
			# predict(); presumably one value per row - confirm for
			# categorical models.
			predictions = self.model.predict(inputs)

			if self.individual_ratings:
				kt = kendalltau(predictions, labels, variant='c')
			else:
				kt = kendalltau_individual(embeddings, predictions)

			if self.verbose:
				print(f"Kendall-Tau({name}):", kt)

			with self.writer.as_default():
				summary.scalar(f"kendall_tau_{name}", kt.correlation, step=epoch)

def train_model(model, epochs, batch_size, inputs, labels, experiment_name, val_data=None, val_freq=1, k_tau=None, individual_ratings=False, verbose=False):
	"""Fit model on the training data with TensorBoard logging.

	Args:
		model: compiled Keras model to train.
		epochs: number of epochs passed to ``fit``.
		batch_size: batch size passed to ``fit``.
		inputs: training inputs.
		labels: training labels.
		experiment_name: name of the log folder created under ``LOG_DIR``.
		val_data: optional validation data passed to ``fit`` as
			``validation_data``; when None no validation arguments are passed.
		val_freq: passed to ``fit`` as ``validation_freq``; only used
			together with val_data.
		k_tau: optional list of (name, embeddings, inputs, labels) datasets.
			When it is non-empty and val_data is given, a KendallTauMetric
			callback is attached.
		individual_ratings: forwarded to KendallTauMetric.
		verbose: forwarded to KendallTauMetric.

	Returns:
		The history object returned by ``model.fit``.
	"""
	if k_tau is None:
		k_tau = []

	experiment_log_dir = LOG_DIR + "/" + experiment_name
	try:
		# makedirs also creates LOG_DIR itself on a first run and accepts an
		# already existing folder.
		os.makedirs(experiment_log_dir, exist_ok=True)
	except OSError as error:
		# Creating the folder up front is best effort; training still proceeds.
		print(f"Warning: could not create log directory '{experiment_log_dir}': {error}")

	tensorboard = TensorBoard(log_dir=experiment_log_dir, histogram_freq=1)
	tensorboard.set_model(model)

	callbacks = [tensorboard]
	validation_args = {}
	if val_data is not None:
		validation_args['validation_data'] = val_data
		validation_args['validation_freq'] = val_freq
		if len(k_tau) > 0:
			# The Kendall-tau callback only reports on epochs with validation
			# results, so it (and its writer) is set up only in this case.
			kt_writer = tf.summary.create_file_writer(experiment_log_dir + "/validation")
			callbacks.append(KendallTauMetric(k_tau, kt_writer, individual_ratings, verbose))

	return model.fit(inputs, labels, batch_size=batch_size, epochs=epochs, callbacks=callbacks, **validation_args)

def kendalltau_individual(vilbert_embeddings, predictions, variant='c'):
	"""Return the Kendall-tau result between predictions and every individual rating.

	Each prediction is paired with its embedding object (by position) and
	repeated once per entry in that object's ``ratings``, so that the two
	sequences handed to ``kendalltau`` line up. The ``variant`` keyword is
	only forwarded when it differs from 'b'.
	"""
	# One (prediction, rating) pair per individual rating, in input order.
	pairs = [
		(prediction, rating)
		for item, prediction in zip(vilbert_embeddings, predictions)
		for rating in item.ratings
	]
	repeated_predictions = [prediction for prediction, _ in pairs]
	flat_ratings = [rating for _, rating in pairs]

	extra = {} if variant == 'b' else {'variant': variant}
	return kendalltau(repeated_predictions, flat_ratings, **extra)



def _load_dataset(path, options):
	"""Load the embeddings at path and derive (embeddings, inputs, labels) from them using the run options."""
	embeddings = load_embeddings(path, options.embedding_type)
	inputs = extract_inputs(embeddings, options.embedding_type, options.input_formula, options.individual_ratings, options.single_rater)
	labels = extract_labels(embeddings, options.individual_ratings, options.single_rater, options.categorical)
	return embeddings, inputs, labels


def _ensure_parent_dir(path):
	"""Create the directory that will contain path when it does not exist yet."""
	parent = os.path.dirname(path)
	if parent:
		os.makedirs(parent, exist_ok=True)


def main():
	"""Run the command line workflow: load data, build or load the model, optionally train, then evaluate and write results.

	The steps are:

	1. Parse and validate the command line; seed the random generators when
	   a seed was given.
	2. Load the test set and, when requested, the training and validation sets.
	3. Build a new model from ``--new`` parameters or load it from the model path.
	4. Train when a training path was given, then save the model.
	5. Predict on the test set, print Kendall-tau for every requested dataset,
	   and write the predictions to the CSV file (plus an embedding file for
	   'vilbert' embeddings).
	"""
	options = parse_args()

	if not validate_options(options):
		return

	if options.verbose:
		print("Input Formula:")
		print(options.input_formula)
		print()
		print("New Model parameters:")
		print(options.new_model_params)
		print()

	if options.seed is not None:
		if options.verbose:
			print("Setting random seed to ", str(options.seed))
		random.seed(options.seed)
		np.random.seed(options.seed)
		tf.random.set_seed(options.seed)

	# Datasets that the -kt option can refer to by keyword.
	datasets = {':test': _load_dataset(options.test_path, options)}
	has_train = options.train_path != ''
	has_val = has_train and options.val_path != ''
	if has_train:
		datasets[':train'] = _load_dataset(options.train_path, options)
	if has_val:
		datasets[':val'] = _load_dataset(options.val_path, options)
	test_embeddings, test_inputs, test_labels = datasets[':test']

	if options.new_model_params is not None:
		decay_steps = options.decay_steps
		if decay_steps is None:
			if has_train:
				train_label_count = len(datasets[':train'][2])
				# Ceiling division: number of batches needed to cover the training set once.
				steps_per_epoch = (train_label_count + options.batch_size - 1) // options.batch_size
				decay_steps = options.decay_epochs * steps_per_epoch
			else:
				# No training data to derive a schedule from; 0 makes
				# build_model fall back to a constant learning rate.
				decay_steps = 0
		vicr_model = build_model(options.new_model_params, test_inputs[0].shape, options.epochs, options.learning_rate, options.decay_rate, decay_steps, options.categorical, options.verbose)
	else:
		vicr_model = tf.keras.models.load_model(options.model_path)

	if vicr_model is None:
		print("Error building or loading model. Aborting")
		return

	k_tau_datasets = []
	for path in options.k_tau:
		if path in (':test', ':train', ':val'):
			if path not in datasets:
				print(f"Kendall-tau dataset '{path}' was requested, but no such dataset was loaded. Aborting")
				return
			name = path
			embeddings, inputs, labels = datasets[path]
		else:
			name = '.'.join(path.split('/')[-1].split('.')[:-1])
			embeddings, inputs, labels = _load_dataset(path, options)
		k_tau_datasets.append((name, embeddings, inputs, labels))

	if has_train:
		_, train_inputs, train_labels = datasets[':train']
		val_data = None
		if has_val:
			_, val_inputs, val_labels = datasets[':val']
			val_data = (val_inputs, val_labels)
		train_model(vicr_model, options.epochs, options.batch_size, train_inputs, train_labels, options.experiment_name, val_data=val_data, val_freq=options.val_freq, k_tau=k_tau_datasets, individual_ratings=options.individual_ratings, verbose=options.verbose)

	vicr_model.save(options.model_path)

	predictions = vicr_model.predict(test_inputs)
	if options.categorical:
		predictions = np.argmax(predictions, axis=1)

	if options.verbose:
		print(predictions)
		print(test_labels)

	for name, embeddings, inputs, labels in k_tau_datasets:
		kt_predictions = vicr_model.predict(inputs)
		if options.categorical:
			# Reduce the per-class outputs to a class index, as is done for
			# the test predictions, so sizes match the ratings.
			kt_predictions = np.argmax(kt_predictions, axis=1)
		if options.individual_ratings:
			kt = kendalltau(kt_predictions, labels, variant='c')
		else:
			kt = kendalltau_individual(embeddings, kt_predictions)
		print(f"Kendall-Tau({name}):", kt)

	# The output files hold one row per image-caption pair. With individual
	# ratings (and no single rater) test_inputs holds one row per rating, so
	# predictions for the output are made on per-pair inputs instead.
	if options.individual_ratings and options.single_rater < 0:
		pair_inputs = extract_inputs(test_embeddings, options.embedding_type, options.input_formula)
		pair_predictions = vicr_model.predict(pair_inputs)
		if options.categorical:
			pair_predictions = np.argmax(pair_predictions, axis=1)
	else:
		pair_predictions = predictions

	_ensure_parent_dir(options.output_path)
	with open(options.output_path, 'w', newline='', encoding='utf-8') as outcsvfile:
		writer = csv.writer(outcsvfile)
		writer.writerow(['image', 'caption', 'rating'])
		for em, prediction in zip(test_embeddings, pair_predictions):
			# Categorical predictions are already scalars; regression
			# predictions are indexed to get the single output value.
			rating = prediction if options.categorical else prediction[0]
			writer.writerow([em.image, em.caption, rating])

	if options.embedding_type == 'vilbert':
		# also write a new embedding file with the predictions
		result = [
			Image_Caption_Embedding(em.image, em.caption, em.image_embedding, em.caption_embedding, [prediction] if options.categorical else prediction)
			for em, prediction in zip(test_embeddings, pair_predictions)
		]

		# Fall back to a name next to the CSV file when no embedding output
		# path was derived during argument parsing.
		embedding_output_path = options.output_embedding_path
		if embedding_output_path == '':
			embedding_output_path = os.path.splitext(options.output_path)[0] + '-' + options.embedding_type + '.emb'

		_ensure_parent_dir(embedding_output_path)
		with open(embedding_output_path, 'wb') as out_file:
			serialize(result, out_file)
# Run the command line workflow only when this file is executed as a script.
if __name__ == "__main__":
	main()
