"""Evaluate embeddings.
"""
import argparse
import os
import pickle as pkl
import traceback
from typing import Callable, Sequence, Optional

import numpy as np
import tqdm
import sklearn
import sklearn.linear_model
import filelock

import paramrepulsor.utils.data as data_utils
import eval_helper
import config


def strings_to_integers(array):
    """Encode every element of ``array`` as an integer class index.

    Each distinct value is replaced by its position in the sorted array of
    unique values (as produced by ``np.unique``), so equal inputs always map
    to the same integer.

    Args:
        array: Array-like of labels (typically strings) of any shape.

    Returns:
        Integer ``np.ndarray`` with the same shape as ``array``. An empty
        input yields an empty integer array instead of raising.
    """
    values = np.asarray(array)
    # ``return_inverse`` gives, for each element, the index of its value in
    # the sorted unique array -- exactly the mapping we want, computed in C.
    _, codes = np.unique(values, return_inverse=True)
    # The shape of the inverse for multi-dimensional inputs differs between
    # NumPy versions, so restore the input shape explicitly.
    return codes.reshape(values.shape)


def svm_eval(embedding, label, X_orig, **kwargs):
    """Score ``embedding`` against ``label`` via ``eval_helper.faster_svm_eval``.

    ``X_orig`` and any extra keyword arguments are accepted but not used.
    """
    score = eval_helper.faster_svm_eval(embedding, label)
    return score


def knn_eval(embedding, label, X_orig, **kwargs):
    """Score ``embedding`` with ``eval_helper.knn_eval`` using 10 neighbors.

    ``X_orig`` and any extra keyword arguments are accepted but not used.
    """
    neighbor_count = 10
    return eval_helper.knn_eval(embedding, label, n_neighbors=neighbor_count)


def knn_large_eval(embedding, label, X_orig, **kwargs):
    """Score ``embedding`` with ``eval_helper.knn_eval`` using 100 neighbors.

    ``X_orig`` and any extra keyword arguments are accepted but not used.
    """
    neighbor_count = 100
    return eval_helper.knn_eval(embedding, label, n_neighbors=neighbor_count)


def nkr_eval(embedding, label, X_orig, **kwargs):
    """Return ``eval_helper.neighbor_kept_ratio_eval_new`` for the embedding.

    The original data is passed as ``X`` and the embedding as ``X_new``.
    ``label`` and any extra keyword arguments are accepted but not used.
    """
    kept_ratio = eval_helper.neighbor_kept_ratio_eval_new(
        X=X_orig,
        X_new=embedding,
    )
    return kept_ratio


def rte_eval(embedding, label, X_orig, **kwargs):
    """Return ``eval_helper.random_triplet_eval`` for the embedding.

    The original data is passed as ``X`` and the embedding as ``X_new``.
    ``label`` and any extra keyword arguments are accepted but not used.
    """
    triplet_score = eval_helper.random_triplet_eval(
        X=X_orig,
        X_new=embedding,
    )
    return triplet_score


def spc_eval(embedding, label, X_orig, **kwargs):
    """Return the correlation from ``eval_helper.spearman_correlation_eval``.

    The helper returns a 4-tuple; only its third element (the correlation)
    is returned. ``label`` and extra keyword arguments are not used.
    """
    outputs = eval_helper.spearman_correlation_eval(X=X_orig, X_new=embedding)
    _, _, correlation, _ = outputs
    return correlation


def ckc_eval(embedding, label, X_orig, **kwargs):
    """Return the correlation from ``eval_helper.centroid_corr_eval``.

    ``k`` is derived from the number of distinct labels as
    ``(num_categories + 2) // 4`` and capped at 10. The helper returns a
    4-tuple; only its third element (the correlation) is returned.
    """
    n_classes = np.unique(label).size
    # Roughly a quarter of the categories, but never more than 10.
    neighbors = min((n_classes + 2) // 4, 10)
    outputs = eval_helper.centroid_corr_eval(
        X=X_orig,
        X_new=embedding,
        y=label,
        k=neighbors,
    )
    _, _, correlation, _ = outputs
    return correlation


def ckn_eval(embedding, label, X_orig, **kwargs):
    """Return ``eval_helper.centroid_knn_eval`` for the embedding.

    ``k`` is derived from the number of distinct labels as
    ``(num_categories + 2) // 4`` and capped at 10.
    """
    n_classes = np.unique(label).size
    # Roughly a quarter of the categories, but never more than 10.
    neighbors = min((n_classes + 2) // 4, 10)
    score = eval_helper.centroid_knn_eval(
        X=X_orig,
        X_new=embedding,
        y=label,
        k=neighbors,
    )
    return score


def mre_eval(
    embedding: np.ndarray,
    label: np.ndarray,
    X_orig: np.ndarray,
    **kwargs
):
    """Min Reconstruction Error score proposed in Amir et al.

    Fits a linear regression that predicts ``X_orig`` from ``embedding`` and
    returns ``np.linalg.norm`` of the residual between ``X_orig`` and the
    regression's prediction. ``label`` and extra keyword arguments are not
    used.
    """
    model = sklearn.linear_model.LinearRegression()
    model.fit(embedding, X_orig)
    residual = X_orig - model.predict(embedding)
    return np.linalg.norm(residual)


def fill_result(metric_name: str,
                reducer_name: str,
                func: Callable,
                result_path: str,
                embeddings: Sequence[np.ndarray],
                force_update: bool=False,
                *args,
                **kwargs):
    """Evaluate ``func`` on every embedding and store the scores on disk.

    Scores are stored in the pickled dictionary at ``result_path`` under the
    key ``f"{reducer_name}_{metric_name}"`` as a list with one entry per
    embedding. Unless ``force_update`` is set, scores already present under
    that key are kept and only the embeddings beyond them are evaluated.

    Args:
        metric_name: Name of the metric; part of the result key.
        reducer_name: Name of the reducer; part of the result key.
        func: Metric callable, invoked as
            ``func(*args, embedding=embedding, **kwargs)``.
        result_path: Path of the pickle file holding the result dictionary.
            A sibling ``<result_path>.lock`` file guards access to it.
        embeddings: Sequence of embeddings, or a single 2-D array which is
            treated as one embedding.
        force_update: Discard stored scores for this key and re-evaluate
            every embedding.
        *args: Extra positional arguments forwarded to ``func``.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Returns:
        The full result dictionary as written to ``result_path``.
    """
    # In case there is just one embedding, fix its shape. Plain sequences
    # (no ``shape`` attribute) are already a collection of embeddings.
    if len(getattr(embeddings, "shape", ())) == 2:
        embeddings = [embeddings]
    num_embeddings = len(embeddings)
    result_name = f"{reducer_name}_{metric_name}"
    lock = filelock.FileLock(f"{result_path}.lock")

    # Read under the lock so we never load a file that another process is
    # in the middle of rewriting.
    with lock:
        stored = get_result(result_path=result_path)

    # If there is a force update, then have to evaluate all embeddings again
    # otherwise, continue after the scores that already exist.
    if force_update or result_name not in stored:
        category_results = []
    else:
        category_results = list(stored[result_name])

    for i in range(len(category_results), num_embeddings):
        category_results.append(func(*args, embedding=embeddings[i], **kwargs))

    with lock:
        # Re-read inside the lock: other processes may have added their own
        # keys while the metric was being computed.
        result = get_result(result_path=result_path)
        result[result_name] = category_results
        with open(result_path, "wb") as result_file:
            pkl.dump(result, result_file)
    return result


def get_result(result_path):
    """Load the pickled result dictionary stored at ``result_path``.

    Args:
        result_path: Path of the pickle file.

    Returns:
        The unpickled object, or an empty dict when the file does not exist.
    """
    if not os.path.exists(result_path):
        return {}
    # NOTE: unpickling can execute arbitrary code; only load result files
    # that were produced by this script.
    with open(result_path, "rb") as result_file:
        return pkl.load(result_file)


def eval_dataset(reducer: str,
                 dataset: str,
                 update_metric_name: Optional[str],
                 force: bool):
    """Evaluate the embeddings one reducer produced for one dataset.

    Embeddings are loaded from ``./eval_outputs/<reducer>/<dataset>.npy`` and
    each metric's scores are stored by ``fill_result`` in
    ``./results/<dataset>.pkl``.

    Args:
        reducer: Name of the reducer whose embeddings are evaluated.
        dataset: Name of the dataset, passed to ``data_utils.data_prep``.
        update_metric_name: When ``None``, every metric registered in
            ``METRIC_NAME_FUNC_DICT`` is evaluated. Otherwise only this
            metric is evaluated, and it is always recomputed from scratch.
        force: Recompute existing scores when evaluating all metrics.

    A failure in one metric is reported (message and traceback) and does not
    stop the remaining metrics.
    """
    X, y = data_utils.data_prep(dataset)
    y = strings_to_integers(y)

    # Load generated embedding
    embedding_path = os.path.join(
        "./eval_outputs",
        f"{reducer}",
        f"{dataset}.npy")
    embeddings = np.load(embedding_path, allow_pickle=True)

    # Results are organized by the dataset.
    result_path = os.path.join(
        "./results",
        f"{dataset}.pkl")
    # Make sure the results directory exists before any metric is computed,
    # so the scores can actually be written.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    if update_metric_name is None:
        metric_names = tqdm.tqdm(list(METRIC_NAME_FUNC_DICT))
        force_update = force
    else:
        # Updating a single metric always recomputes it.
        metric_names = [update_metric_name]
        force_update = True

    kwargs = {"label": y, "X_orig": X}
    for metric_name in metric_names:
        print(f"{reducer}, {dataset}, {metric_name}", flush=True)
        try:
            # ``fill_result`` persists the scores itself while holding the
            # file lock. Writing the file again here without the lock could
            # overwrite results stored concurrently by other runs, so no
            # additional dump is performed.
            fill_result(
                metric_name,
                reducer,
                METRIC_NAME_FUNC_DICT[metric_name],
                result_path,
                embeddings,
                force_update=force_update,
                **kwargs
            )
        except Exception as e:
            print(f"An error occurred during handling of {dataset}|{reducer}: {e}")
            traceback.print_exc()
    print(f"{dataset}|{reducer} finished, results saved at {result_path}")


# Registry of the metrics that are evaluated, keyed by the metric name used
# in result keys and accepted as a single metric to update.
METRIC_NAME_FUNC_DICT = dict(
    svm_acc=svm_eval,
    knn_acc=knn_eval,
    random_triplet_acc=rte_eval,
    dist_spearman_corr=spc_eval,
    centroid_corr=ckc_eval,
)


def get_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Options: ``--dataset`` (required), ``--updatemetric`` (optional, default
    ``None``), ``--reducer`` (required, type of reducer being evaluated) and
    the ``--force`` flag.
    """
    option_specs = (
        ("--dataset", {"type": str, "required": True}),
        ("--updatemetric", {"type": str, "default": None}),
        # Type of reducer being evaluated
        ("--reducer", {"type": str, "required": True}),
        ("--force", {"help": "Force an update", "action": "store_true"}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main(args):
    """Run ``eval_dataset`` for the dataset given in ``args``.

    Every dataset is attempted even if an earlier one fails; failures are
    printed together with their traceback, and a ``ValueError`` is raised at
    the end if any dataset failed.
    """
    datasets = [args.dataset]   # Substitute this with multiple datasets
    failed_datasets = []
    for name in datasets:
        try:
            eval_dataset(args.reducer, name, args.updatemetric, args.force)
        except Exception as err:
            print(f"An error occurred during handling of {name}|{args.reducer}: {err}")
            failed_datasets.append(name)
            traceback.print_exc()
    if failed_datasets:
        raise ValueError("An error has occurred during the experiment, check the log for more details.")


# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    main(get_args())
