from absl import app
from absl import flags
from absl import logging
import csv
import importlib
import numpy as np
import os.path as path
import random
from sklearn.model_selection import train_test_split
import time

from transformations.reader.matrix import test_argument_and_file, load_and_log
from transformations.reader.folder import get_paths_dict, test_argument_and_path, read_paths_to_matrix
import transformations.tfhub_module as tfhub_module
import transformations.torchhub_model as torchhub_model
import transformations.label_noise as label_noise
import transformations.pca as pca
import transformations.nca as nca
import transformations.random_proj as random_proj
import methods.knn as knn
import methods.ghp as ghp
import methods.kde as kde
import methods.onenn as onenn
import transformations.aug_tfhub_module as aug_tfhub_module

FLAGS = flags.FLAGS

# Input data: either pre-computed numpy matrices or image folders.
flags.DEFINE_enum("variant", "matrices", ["matrices", "folder"], "Run the tool from matrices or folder")
flags.DEFINE_string("path", ".", "Path to the matrices directory")
flags.DEFINE_string("features_train", None, "Name of the train features numpy matrix exported file (npy)")
flags.DEFINE_string("features_test", None, "Name of the test features numpy matrix exported file (npy)")
flags.DEFINE_string("labels_train", None, "Name of the train labels numpy matrix exported file (npy)")
flags.DEFINE_string("labels_test", None, "Name of the test labels numpy matrix exported file (npy)")
flags.DEFINE_string("folder_train_path", None, "Path to the train folder with the images in the subfolders")
flags.DEFINE_string("folder_test_path", None, "Path to the test folder with the images in the subfolders")

# Training-set subsampling experiments.
flags.DEFINE_list("train_subsamples", None, "Training set subsample sizes")
flags.DEFINE_integer("subsamples_runs", 0, "Number of runs for training subsamples")

# Label-noise experiments.
flags.DEFINE_list("noise_levels", None, "Run at different noise levels")
flags.DEFINE_integer("noise_runs", 0, "Number of runs for different noise levels")

# Augmentation experiments.
flags.DEFINE_integer("augment_times", 1, "Times to augment the full training dataset")
flags.DEFINE_integer("augment_runs", 0, "Number of runs for augmenting the full training dataset")

# Output.
flags.DEFINE_string("output_file", None, "File to write the output in CSV format (including headers)")
flags.DEFINE_bool("output_overwrite", True, "Writes (if True) or appends (if False) to the specified output file if any")

flags.DEFINE_enum("method", None, ["knn", "knn_loo", "ghp", "kde_knn_loo", "kde", "onenn"], "Method to estimate the bayes error (results in either 1 value or a lower and upper bound)")

flags.DEFINE_list("transformations", [], "List of transformations to apply, in that order, to the input matrices (can be empty)")


def _get_transform_fns():
    """Resolve ``FLAGS.transformations`` into ``(name, apply_fn)`` pairs.

    Each entry is matched after stripping surrounding whitespace and lowering
    its case; the order given on the command line is preserved.

    Returns:
        List of ``(normalized_name, module.apply)`` tuples.

    Raises:
        app.UsageError: if an entry does not name a known transformation.
    """
    available = {
        "tfhub_module": tfhub_module,
        "torchhub_model": torchhub_model,
        "label_noise": label_noise,
        "pca": pca,
        "nca": nca,
        "random_proj": random_proj,
    }

    resolved = []
    for raw_name in FLAGS.transformations:
        key = raw_name.strip().lower()
        module = available.get(key)
        if module is None:
            raise app.UsageError("Transformation '{}' is not valid!".format(raw_name))
        resolved.append((key, module.apply))
    return resolved


def _run_augmentation(orig_train_features, orig_train_labels, train_features, train_labels, samples_train, eval_fn, result_full, noise):
    """Evaluate ``eval_fn`` on the training set enlarged by repeated augmentation.

    For each of ``FLAGS.augment_runs`` runs, the full-set results are recorded
    first. Then, for ``aug_time`` in ``2..FLAGS.augment_times``, the following
    happens:

    - one new augmented copy of the original (untransformed) training set is
      created with ``aug_tfhub_module.apply``;
    - the copy is passed through the configured transformations and appended to
      the copies accumulated so far in the current run;
    - ``eval_fn`` is evaluated on ``train_features`` plus all accumulated copies.

    Args:
        orig_train_features: training features before any transformation; the
            input to the augmentation module (``shape[0]`` is passed as the
            sample count, ``shape[1]`` as the dimension).
        orig_train_labels: labels matching ``orig_train_features``.
        train_features: transformed training features the augmented copies are
            appended to.
        train_labels: labels matching ``train_features``.
        samples_train: number of training samples, used for reporting only.
        eval_fn: callable ``(x_train, y_train)`` returning a mapping of variant
            name to result.
        result_full: results on the non-augmented set, re-emitted for every run.
        noise: noise level written into every CSV row.

    Returns:
        List of CSV row dicts; empty unless ``FLAGS.output_file`` is set.
    """
    # Resolve the transformation chain once rather than for every augmentation.
    transform_fns = _get_transform_fns()

    result_rows = []
    for run in range(FLAGS.augment_runs):
        logging.log(logging.DEBUG, "Start augment run {} out of {}".format(run+1, FLAGS.augment_runs))
        run_start = time.time()

        if FLAGS.output_file:
            rows = [_get_csv_row(k, run, samples_train, noise, v) for k, v in result_full.items()]
            result_rows.extend(rows)

        # Reset: every run starts without any augmented data.
        x_train = None
        y_train = None
        for aug_time in range(2, FLAGS.augment_times+1):

            logging.log(logging.DEBUG, "Start augment {} times for run {} out of {}".format(aug_time, run+1, FLAGS.augment_runs))
            sub_start = time.time()

            aug_x, dim_aug, samples_aug, aug_y = aug_tfhub_module.apply(orig_train_features, orig_train_features.shape[1], orig_train_features.shape[0], orig_train_labels)
            # Apply transformations if any. The dimension and sample count returned
            # by each step are threaded into the next step, so a dimension-changing
            # transformation (e.g. a projection) does not leave later steps with a
            # stale dimension.
            # NOTE(review): estimate_from_split_matrices passes the train features
            # as an extra argument to pca/nca when transforming the test set; here
            # they get none, so they may be fitted on the augmented data itself —
            # confirm this is intended.
            for i, (_, fn) in enumerate(transform_fns):
                logging.log(logging.DEBUG, "Apply transformation '{}' to the augmented set".format(i))
                aug_x, dim_aug, samples_aug, aug_y = fn(aug_x, dim_aug, samples_aug, aug_y)

            # Accumulate: the set at aug_time contains aug_time - 1 augmented copies.
            x_train = np.concatenate((x_train, aug_x)) if x_train is not None else aug_x
            y_train = np.concatenate((y_train, aug_y)) if y_train is not None else aug_y

            logging.log(logging.DEBUG, "Start estimation with method '{}', subsamples {}/{}, run {}/{}".format(FLAGS.method, aug_time*samples_train, samples_train, run+1, FLAGS.augment_runs))
            start = time.time()
            result = eval_fn(np.concatenate((train_features, x_train)), np.concatenate((train_labels, y_train)))
            end = time.time()
            logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, end - start))

            if FLAGS.output_file:
                rows = [_get_csv_row(k, run, aug_time*samples_train, noise, v) for k, v in result.items()]
                result_rows.extend(rows)

            logging.log(logging.INFO, "Run {}/{} - train set ({}/{}): {}".format(run+1, FLAGS.augment_runs, aug_time*samples_train, samples_train, result))

            sub_end = time.time()
            logging.log(logging.DEBUG, "Augment {} times for run {}/{} executed in {} seconds".format(aug_time, run+1, FLAGS.augment_runs, sub_end - sub_start))

        run_end = time.time()
        logging.log(logging.DEBUG, "Run {}/{} executed in {} seconds".format(run+1, FLAGS.augment_runs, run_end - run_start))

    return result_rows
        

def _run_subsamples(train_features, train_labels, samples_train, eval_fn, result_full, noise):
    """Evaluate ``eval_fn`` on nested, stratified subsamples of the training set.

    For each of ``FLAGS.subsamples_runs`` runs, the full-set results are
    recorded first. Then the sizes from ``FLAGS.train_subsamples`` are processed
    largest first. Each subsample is drawn with ``train_test_split`` (stratified
    by label) from the previous, larger subsample, so the sets within a run are
    nested.

    Args:
        train_features: full training features.
        train_labels: labels matching ``train_features`` (used for stratification).
        samples_train: number of samples in the full training set.
        eval_fn: callable ``(x_train, y_train)`` returning a mapping of variant
            name to result.
        result_full: results on the full training set, re-emitted for every run.
        noise: noise level written into every CSV row.

    Returns:
        List of CSV row dicts; empty unless ``FLAGS.output_file`` is set.

    Raises:
        AttributeError: if a subsample size is not in ``[1, samples_train)``.
    """
    # Validate and order the sizes once, before any (possibly expensive) evaluation.
    # Duplicates are dropped: drawing the same size twice from a nested subset asks
    # train_test_split for a train_size equal to the number of samples, which it rejects.
    subsample_sizes = sorted({int(x) for x in FLAGS.train_subsamples}, reverse=True)
    for sub_train in subsample_sizes:
        if sub_train < 1 or sub_train >= samples_train:
            raise AttributeError("Subsample {} has to be positive and smaller than the full number of samples!".format(sub_train))

    result_rows = []

    for run in range(FLAGS.subsamples_runs):
        logging.log(logging.DEBUG, "Start subsample run {} out of {}".format(run+1, FLAGS.subsamples_runs))
        run_start = time.time()

        if FLAGS.output_file:
            rows = [_get_csv_row(k, run, samples_train, noise, v) for k, v in result_full.items()]
            result_rows.extend(rows)

        # Reset: every run starts again from the full training set.
        x_train, y_train = train_features, train_labels
        for sub_train in subsample_sizes:
            logging.log(logging.DEBUG, "Start subsample {} for run {} out of {}".format(sub_train, run+1, FLAGS.subsamples_runs))
            sub_start = time.time()
            x_train, _, y_train, _ = train_test_split(x_train,
                                                      y_train,
                                                      test_size=None,
                                                      train_size=sub_train,
                                                      stratify=y_train)

            logging.log(logging.DEBUG, "Start estimation with method '{}', subsamples {}/{}, run {}/{}".format(FLAGS.method, sub_train, samples_train, run+1, FLAGS.subsamples_runs))
            start = time.time()
            result = eval_fn(x_train, y_train)
            end = time.time()
            logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, end - start))

            if FLAGS.output_file:
                rows = [_get_csv_row(k, run, sub_train, noise, v) for k, v in result.items()]
                result_rows.extend(rows)

            logging.log(logging.INFO, "Run {}/{} - train set ({}/{}): {}".format(run+1, FLAGS.subsamples_runs, sub_train, samples_train, result))

            sub_end = time.time()
            logging.log(logging.DEBUG, "Subsample {} for run {}/{} executed in {} seconds".format(sub_train, run+1, FLAGS.subsamples_runs, sub_end - sub_start))

        run_end = time.time()
        logging.log(logging.DEBUG, "Run {}/{} executed in {} seconds".format(run+1, FLAGS.subsamples_runs, run_end - run_start))

    return result_rows


def _run_subsamples_folder_knn(train_paths, test_paths, samples_train, result_full, noise):
    """Run folder-based kNN on nested, class-balanced subsamples of the train paths.

    For each of ``FLAGS.subsamples_runs`` runs, the full-set results are
    recorded first. Then the sizes from ``FLAGS.train_subsamples`` are processed
    largest first. For each size, ``size // num_classes`` paths are drawn per
    class from the previous (larger) subsample, and
    ``knn.eval_from_folder`` is evaluated on them.

    Args:
        train_paths: dict mapping each class key to its training file paths.
        test_paths: test paths as returned by ``get_paths_dict``.
        samples_train: total number of training paths over all classes.
        result_full: results on the full training set, re-emitted for every run.
        noise: noise level written into every CSV row.

    Returns:
        List of CSV row dicts; empty unless ``FLAGS.output_file`` is set.

    Raises:
        AttributeError: if a size is not in ``[1, samples_train)`` or is not a
            positive multiple of the number of classes.
        ValueError: if a class has fewer paths left than must be drawn per class.
    """
    # Dict keys are unique, so the number of classes is the number of keys.
    num_classes = len(train_paths)

    # Validate all sizes once, before any (possibly expensive) evaluation.
    subsample_sizes = sorted([int(x) for x in FLAGS.train_subsamples], reverse=True)
    for sub_train in subsample_sizes:
        if sub_train < 1 or sub_train >= samples_train:
            raise AttributeError("Subsample {} has to be positive and smaller than the full number of samples!".format(sub_train))
        if sub_train < num_classes or sub_train % num_classes > 0:
            raise AttributeError("Subsample {} has to be devidable by the number of classes {}!".format(sub_train, num_classes))

    result_rows = []

    for run in range(FLAGS.subsamples_runs):
        logging.log(logging.DEBUG, "Start subsample run {} out of {}".format(run+1, FLAGS.subsamples_runs))
        run_start = time.time()

        if FLAGS.output_file:
            rows = [_get_csv_row(k, run, samples_train, noise, v) for k, v in result_full.items()]
            result_rows.extend(rows)

        # Reset: shallow copy so the caller's per-class lists are never replaced.
        x_train = dict(train_paths)
        for sub_train in subsample_sizes:
            per_class = sub_train // num_classes

            logging.log(logging.DEBUG, "Start subsample {} for run {} out of {}".format(sub_train, run+1, FLAGS.subsamples_runs))
            sub_start = time.time()

            # Draw from the previous subsample so the sets within a run are nested.
            for k in train_paths:
                pool = x_train[k]
                if per_class > len(pool):
                    raise ValueError("Class '{}' has only {} samples, cannot draw {} per class for subsample {}!".format(k, len(pool), per_class, sub_train))
                # list(): random.sample requires a sequence (sets are rejected on Python 3.11+).
                x_train[k] = random.sample(list(pool), per_class)

            logging.log(logging.DEBUG, "Start estimation with method '{}', subsamples {}/{}, run {}/{}".format(FLAGS.method, sub_train, samples_train, run+1, FLAGS.subsamples_runs))
            start = time.time()
            result = knn.eval_from_folder(x_train, test_paths, read_paths_to_matrix)
            end = time.time()
            logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, end - start))

            if FLAGS.output_file:
                rows = [_get_csv_row(k, run, sub_train, noise, v) for k, v in result.items()]
                result_rows.extend(rows)

            logging.log(logging.INFO, "Run {}/{} - train set ({}/{}): {}".format(run+1, FLAGS.subsamples_runs, sub_train, samples_train, result))

            sub_end = time.time()
            logging.log(logging.DEBUG, "Subsample {} for run {}/{} executed in {} seconds".format(sub_train, run+1, FLAGS.subsamples_runs, sub_end - sub_start))

        run_end = time.time()
        logging.log(logging.DEBUG, "Run {}/{} executed in {} seconds".format(run+1, FLAGS.subsamples_runs, run_end - run_start))

    return result_rows


def _get_csv_row(variant, run, samples, noise, results):
    """Build one CSV output row, tagging it with the selected ``FLAGS.method``.

    Args:
        variant: result key (e.g. a bound or estimate name) reported by the method.
        run: zero-based run index.
        samples: number of training samples used for this result.
        noise: label-noise level used for this result.
        results: the value reported for ``variant``.

    Returns:
        dict with the keys ``method``, ``variant``, ``run``, ``samples``,
        ``noise`` and ``results``.
    """
    return dict(
        method=FLAGS.method,
        variant=variant,
        run=run,
        samples=samples,
        noise=noise,
        results=results,
    )

def _write_result(rows):
    """Write result rows to ``FLAGS.output_file`` as CSV.

    The file is truncated when ``FLAGS.output_overwrite`` is set and appended to
    otherwise. A header line is written when overwriting, and also when appending
    to a file that does not exist yet or is empty, so the file always starts with
    a header.

    Args:
        rows: iterable of dicts with the keys produced by ``_get_csv_row``.
    """
    fieldnames = ['method', 'variant', 'run', 'samples', 'noise', 'results']
    writeheader = (FLAGS.output_overwrite
                   or not path.exists(FLAGS.output_file)
                   or path.getsize(FLAGS.output_file) == 0)
    # newline='' lets the csv module control line endings itself; without it,
    # rows end in '\r\r\n' on Windows.
    with open(FLAGS.output_file, mode='w' if FLAGS.output_overwrite else 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if writeheader:
            writer.writeheader()
        writer.writerows(rows)


def estimate_from_split_matrices(eval_fn):
    """Estimate the Bayes error from separate train/test feature and label matrices.

    The steps are:

    - load ``features_train``, ``features_test``, ``labels_train`` and
      ``labels_test`` from ``FLAGS.path`` and check that their shapes agree;
    - apply the configured transformations to both splits (pca/nca are given
      the train data to transform the test set);
    - evaluate ``eval_fn`` on the full split;
    - depending on the flags, run training-subsample, augmentation or
      label-noise experiments.

    Results are written to ``FLAGS.output_file`` when it is set.

    Args:
        eval_fn: callable ``(train_x, test_x, train_y, test_y)`` returning a
            mapping of variant name to result.

    Raises:
        AttributeError: if features/labels have inconsistent dimensions or sample
            counts, or a noise level is outside ``(0, 1]``.
    """
    test_argument_and_file(FLAGS.path, "features_train")
    test_argument_and_file(FLAGS.path, "features_test")
    test_argument_and_file(FLAGS.path, "labels_train")
    test_argument_and_file(FLAGS.path, "labels_test")

    train_features, dim_train, samples_train = load_and_log(FLAGS.path, "features_train")
    test_features, dim_test, samples_test = load_and_log(FLAGS.path, "features_test")
    if dim_test != dim_train:
        raise AttributeError("Train and test features do not have the same dimension!")
    train_labels, dim_train_labels, samples_train_labels = load_and_log(FLAGS.path, "labels_train")
    if dim_train_labels != 1:
        raise AttributeError("Train labels file does not point to a vector!")
    if samples_train_labels != samples_train:
        raise AttributeError("Train features and labels files does not have the same amount of samples!")
    test_labels, dim_test_labels, samples_test_labels = load_and_log(FLAGS.path, "labels_test")
    # Check the dimension of the *test* labels (not the train labels again).
    if dim_test_labels != 1:
        raise AttributeError("Test labels file does not point to a vector!")
    if samples_test_labels != samples_test:
        raise AttributeError("Test features and labels files does not have the same amount of samples!")

    # Keep the untransformed training data; augmentation starts from it.
    orig_train_features = train_features
    orig_train_labels = train_labels

    for i, (val, fn) in enumerate(_get_transform_fns()):
        logging.log(logging.DEBUG, "Apply transformation '{}' to the train set".format(i))
        new_train_features, new_dim_train, new_samples_train, new_train_labels = fn(train_features, dim_train, samples_train, train_labels)
        logging.log(logging.DEBUG, "Apply transformation '{}' to the test set".format(i))
        # pca/nca receive the train data of this step (not yet replaced) as the
        # reference for transforming the test set.
        if val == "pca":
            test_features, dim_test, samples_test, test_labels = fn(test_features, dim_test, samples_test, test_labels, train_features)
        elif val == "nca":
            test_features, dim_test, samples_test, test_labels = fn(test_features, dim_test, samples_test, test_labels, train_features, train_labels)
        else:
            test_features, dim_test, samples_test, test_labels = fn(test_features, dim_test, samples_test, test_labels)
        train_features = new_train_features
        dim_train = new_dim_train
        samples_train = new_samples_train
        train_labels = new_train_labels

    logging.log(logging.DEBUG, "Start full estimation with method '{}'".format(FLAGS.method))
    start = time.time()
    result_full = eval_fn(train_features, test_features, train_labels, test_labels)
    end = time.time()
    logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, end - start))
    logging.log(logging.INFO, "Full train and test set: {}".format(result_full))

    if FLAGS.train_subsamples and FLAGS.subsamples_runs > 0:
        rows = _run_subsamples(train_features, train_labels, samples_train, lambda x_train, y_train: eval_fn(x_train, test_features, y_train, test_labels), result_full, 0.0)
        if FLAGS.output_file:
            _write_result(rows)

    elif FLAGS.augment_times > 1 and FLAGS.augment_runs > 0:
        rows = _run_augmentation(orig_train_features, orig_train_labels, train_features, train_labels, samples_train, lambda x_train, y_train: eval_fn(x_train, test_features, y_train, test_labels), result_full, 0.0)
        if FLAGS.output_file:
            _write_result(rows)

    elif FLAGS.noise_levels and FLAGS.noise_runs > 0:
        result_rows = []
        for run in range(FLAGS.noise_runs):

            if FLAGS.output_file:
                rows = [_get_csv_row(k, run, samples_train, 0.0, v) for k, v in result_full.items()]
                result_rows.extend(rows)

            logging.log(logging.DEBUG, "Start noisy run {} out of {}".format(run+1, FLAGS.noise_runs))
            run_start = time.time()
            for noise_level in [float(x) for x in FLAGS.noise_levels]:
                if noise_level > 1.0 or noise_level <= 0.0:
                    raise AttributeError("Noise level {} has to be bigger than 0 and not larger than 1!".format(noise_level))
                logging.log(logging.DEBUG, "Start noise level {} for run {} out of {}".format(noise_level, run+1, FLAGS.noise_runs))
                noise_start = time.time()

                # Flip labels of both the test and the train set.
                flipped_train_labels = label_noise.random_flip(train_labels, samples_train, noise_level, copy=True)
                flipped_test_labels = label_noise.random_flip(test_labels, samples_test, noise_level, copy=True)

                # Run the method on the noisy labels.
                logging.log(logging.DEBUG, "Start full estimation with method '{}', noise level {}, run {}/{}".format(FLAGS.method, noise_level, run+1, FLAGS.noise_runs))
                start = time.time()
                result = eval_fn(train_features, test_features, flipped_train_labels, flipped_test_labels)
                end = time.time()
                logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, end - start))
                logging.log(logging.INFO, "Run {}/{} - noise level {}: {}".format(run+1, FLAGS.noise_runs, noise_level, result))

                if FLAGS.output_file:
                    rows = [_get_csv_row(k, run, samples_train, noise_level, v) for k, v in result.items()]
                    result_rows.extend(rows)

                noise_end = time.time()
                logging.log(logging.DEBUG, "Noise level {} for run {}/{} executed in {} seconds".format(noise_level, run+1, FLAGS.noise_runs, noise_end - noise_start))
            run_end = time.time()
            logging.log(logging.DEBUG, "Run {}/{} executed in {} seconds".format(run+1, FLAGS.noise_runs, run_end - run_start))

        if FLAGS.output_file:
            _write_result(result_rows)

    elif FLAGS.output_file:
        rows = [_get_csv_row(k, 0, samples_train, 0.0, v) for k, v in result_full.items()]
        _write_result(rows)


def estimate_from_single_matrix(eval_fn):
    """Estimate the Bayes error from a single training feature/label matrix pair.

    The steps are:

    - load ``features_train`` and ``labels_train`` from ``FLAGS.path`` and
      validate them;
    - apply the configured transformations;
    - evaluate ``eval_fn(features, labels)`` on the full set;
    - optionally run a training-subsample or label-noise experiment.

    In the label-noise experiment only the training labels are flipped.
    Augmentation flags are not consulted here. Rows are written to
    ``FLAGS.output_file`` when it is set.

    Args:
        eval_fn: callable ``(features, labels)`` returning a mapping of variant
            name to result.

    Raises:
        AttributeError: on a non-vector label file, mismatching sample counts, or
            a noise level outside ``(0, 1]``.
    """
    for flag_name in ("features_train", "labels_train"):
        test_argument_and_file(FLAGS.path, flag_name)

    features, n_dims, n_samples = load_and_log(FLAGS.path, "features_train")
    labels, label_dims, n_labels = load_and_log(FLAGS.path, "labels_train")
    if label_dims != 1:
        raise AttributeError("Train labels file does not point to a vector!")
    if n_labels != n_samples:
        raise AttributeError("Train features and labels files does not have the same amount of samples!")

    for step, (_, transform) in enumerate(_get_transform_fns()):
        logging.log(logging.DEBUG, "Apply transformation '{}' to the train set".format(step))
        features, n_dims, n_samples, labels = transform(features, n_dims, n_samples, labels)

    logging.log(logging.DEBUG, "Start full estimation with method '{}'".format(FLAGS.method))
    t_begin = time.time()
    result_full = eval_fn(features, labels)
    t_done = time.time()
    logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, t_done - t_begin))
    logging.log(logging.INFO, "Full train set: {}".format(result_full))

    def full_set_rows(run_idx):
        # One CSV row per variant of the noise-free, full-set result.
        return [_get_csv_row(variant, run_idx, n_samples, 0.0, value) for variant, value in result_full.items()]

    if FLAGS.train_subsamples and FLAGS.subsamples_runs > 0:
        subsample_rows = _run_subsamples(features, labels, n_samples, eval_fn, result_full, 0.0)
        if FLAGS.output_file:
            _write_result(subsample_rows)
        return

    if FLAGS.noise_levels and FLAGS.noise_runs > 0:
        collected = []
        for run_idx in range(FLAGS.noise_runs):
            if FLAGS.output_file:
                collected.extend(full_set_rows(run_idx))

            logging.log(logging.DEBUG, "Start noisy run {} out of {}".format(run_idx+1, FLAGS.noise_runs))
            run_t0 = time.time()
            for level in [float(x) for x in FLAGS.noise_levels]:
                if level > 1.0 or level <= 0.0:
                    raise AttributeError("Noise level {} has to be bigger than 0 and not larger than 1!".format(level))
                logging.log(logging.DEBUG, "Start noise level {} for run {} out of {}".format(level, run_idx+1, FLAGS.noise_runs))
                level_t0 = time.time()

                # Only the training labels are corrupted; there is no test set here.
                noisy_labels = label_noise.random_flip(labels, n_samples, level, copy=True)

                logging.log(logging.DEBUG, "Start full estimation with method '{}', noise level {}, run {}/{}".format(FLAGS.method, level, run_idx+1, FLAGS.noise_runs))
                eval_t0 = time.time()
                noisy_result = eval_fn(features, noisy_labels)
                eval_t1 = time.time()
                logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, eval_t1 - eval_t0))
                logging.log(logging.INFO, "Run {}/{} - noise level {}: {}".format(run_idx+1, FLAGS.noise_runs, level, noisy_result))

                if FLAGS.output_file:
                    collected.extend(_get_csv_row(variant, run_idx, n_samples, level, value) for variant, value in noisy_result.items())

                level_t1 = time.time()
                logging.log(logging.DEBUG, "Noise level {} for run {}/{} executed in {} seconds".format(level, run_idx+1, FLAGS.noise_runs, level_t1 - level_t0))
            run_t1 = time.time()
            logging.log(logging.DEBUG, "Run {}/{} executed in {} seconds".format(run_idx+1, FLAGS.noise_runs, run_t1 - run_t0))

        if FLAGS.output_file:
            _write_result(collected)
        return

    if FLAGS.output_file:
        _write_result(full_set_rows(0))


def eval_from_folder_knn():
    """Run the kNN estimator directly on image folders.

    Collects the train and test file paths with ``get_paths_dict``, evaluates
    ``knn.eval_from_folder`` on the full sets and optionally runs the
    class-balanced training-subsample experiment. Label-noise experiments are
    not supported in this mode. When they are requested, the full-set result is
    still written (if ``FLAGS.output_file`` is set) before
    ``NotImplementedError`` is raised.

    Raises:
        NotImplementedError: if noise levels and noise runs are requested.
    """
    for flag_name in ("folder_train_path", "folder_test_path"):
        test_argument_and_path(flag_name)

    # Train paths grouped per class, needed for class-balanced subsampling.
    train_paths = get_paths_dict(FLAGS.folder_train_path)
    samples_train = sum(len(paths) for paths in train_paths.values())

    # Test paths, collected with the same helper as the train paths.
    test_paths = get_paths_dict(FLAGS.folder_test_path)

    logging.log(logging.DEBUG, "Start full estimation with method '{}'".format(FLAGS.method))
    t_begin = time.time()
    result_full = knn.eval_from_folder(train_paths, test_paths, read_paths_to_matrix)
    t_done = time.time()
    logging.log(logging.DEBUG, "Method '{}' executed in {} seconds".format(FLAGS.method, t_done - t_begin))
    logging.log(logging.INFO, "Full train set: {}".format(result_full))

    if FLAGS.train_subsamples and FLAGS.subsamples_runs > 0:
        subsample_rows = _run_subsamples_folder_knn(train_paths, test_paths, samples_train, result_full, 0.0)
        if FLAGS.output_file:
            _write_result(subsample_rows)
        return

    noise_requested = FLAGS.noise_levels and FLAGS.noise_runs > 0

    # Both remaining cases persist the full-set result first.
    if FLAGS.output_file:
        _write_result([_get_csv_row(variant, 0, samples_train, 0.0, value) for variant, value in result_full.items()])

    if noise_requested:
        raise NotImplementedError("Noise levels not implemented for folder knn! No noise output written!")


def main(argv):
    """absl entry point: validate flag combinations and dispatch to an estimator.

    Args:
        argv: remaining command-line arguments (unused).

    Raises:
        app.UsageError: if training subsamples and noise levels are both
            requested, or no method is given.
        NotImplementedError: for an unsupported variant or variant/method pair.
    """
    wants_subsamples = FLAGS.train_subsamples and FLAGS.subsamples_runs > 0
    wants_noise = FLAGS.noise_levels and FLAGS.noise_runs > 0
    if wants_subsamples and wants_noise:
        raise app.UsageError("Either use training subsample or different noise levels. If you need both, run tool with a label noise transformation instead.")

    if FLAGS.method is None:
        raise app.UsageError("You have to specify the method!")

    # Entries are lambdas so a method's module attribute is only looked up when
    # that method is actually selected.
    dispatch = {
        "matrices": {
            "knn": lambda: estimate_from_split_matrices(knn.eval_from_matrices),
            "knn_loo": lambda: estimate_from_single_matrix(knn.eval_from_matrix_loo),
            "ghp": lambda: estimate_from_single_matrix(ghp.eval_from_matrix),
            "kde_knn_loo": lambda: estimate_from_single_matrix(kde.eval_from_matrix_knn_loo),
            "onenn": lambda: estimate_from_single_matrix(onenn.eval_from_matrix_onenn),
            "kde": lambda: estimate_from_single_matrix(kde.eval_from_matrix_kde),
        },
        "folder": {
            "knn": lambda: eval_from_folder_knn(),
        },
    }

    methods = dispatch.get(FLAGS.variant)
    if methods is None:
        raise NotImplementedError("Variant '{}' not yet implemented! Come back later :)".format(FLAGS.variant))

    runner = methods.get(FLAGS.method)
    if runner is None:
        raise NotImplementedError("Method module for '{}' not yet implemented!".format(FLAGS.variant))
    runner()


if __name__ == "__main__":
    app.run(main)
