"""
Like extract_exemplars.py, but this extracts reference images from the top and query
images from the given percentile range.
"""

import argparse
import os
import pickle

import numpy as np
import pandas as pd
from mis_utils import (get_available_similarity_functions,
                       prepare_machine_interpretability_score)
from sg_utils import (accuracies, extract_stimuli_for_layer_units,
                      get_default_device, get_model_layers, load_model)
from tqdm import tqdm


def process(args):
    """
    Computes the machine interpretability score (MIS) for every unit of every
    layer of the requested model and stores all scores in one pickle file.

    The pickle is written into ``args.output_dir`` and contains a dict with the
    keys ``"scores"`` (a DataFrame with one row per unit and the columns
    ``layer``, ``unit``, ``mis``, ``mis_confidence`` and ``layer_type``) and
    ``"layer_types"`` (a dict mapping layer class names to the integer ids used
    in the ``layer_type`` column). If the output file already exists, nothing is
    computed.

    Side effect: ``args.model`` is overwritten with a sanitized variant of the
    model name (``/`` and ``:`` replaced by ``_``).

    :param args: the CLI-arguments
    """

    # set random seed
    np.random.seed(args.seed)

    # Keep the name exactly as passed on the command line for loading the model;
    # args.model itself is replaced by the sanitized name right below.
    requested_model = args.model
    model_name = args.model = (
        requested_model.replace("/", "_").replace(":", "_").replace("__", "_")
    )

    if args.model_checkpoint:
        model_checkpoint = args.model_checkpoint.replace("//", "/")
        # NOTE(review): ".pth" is stripped before ".pth.tar", so a "*.pth.tar"
        # checkpoint keeps a ".tar" suffix in the name. This is left untouched
        # on purpose, as the same name is handed to
        # extract_stimuli_for_layer_units below - presumably to locate stored
        # activations; confirm how those were named before changing the order.
        model_name_with_checkpoint = (
            model_name
            + "_"
            + "_".join(model_checkpoint.split("/")[-2:])
            .replace(".pth", "")
            .replace(".pth.tar", "")
        )
    else:
        model_name_with_checkpoint = model_name

    output_fn = os.path.join(
        args.output_dir,
        f"machine_interpretability_{args.similarity_function}_natural_{model_name_with_checkpoint}.pkl",
    )

    # Check for existing results before the model gets loaded, so that skipping
    # is cheap.
    if os.path.exists(output_fn):
        print(f"Skipping as output file ({output_fn}) already exists.")
        return

    # Create the output directory up front; otherwise a missing directory would
    # only be noticed when writing the results, after everything was computed.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    device = get_default_device()

    # get all the layers we need and map them to their respective units
    model = load_model(requested_model, device=device)
    layers = get_model_layers(model, get_modules=True)

    compute_machine_interpretability_score = prepare_machine_interpretability_score(
        args.similarity_function, args.similarity_function_args
    )

    # Maps a layer's class name to an integer id, assigned in order of first
    # appearance among the layers that yield at least one unit.
    unique_layer_types = dict()
    # One dict per unit; turned into a single DataFrame after the loop instead
    # of concatenating a new DataFrame for every layer.
    all_rows = []
    layer_pbar = tqdm(layers)
    for layer in layer_pbar:
        layer_pbar.set_description(f"Layer {layer}")
        unit_pbar = tqdm(
            extract_stimuli_for_layer_units(
                model_name_with_checkpoint,
                layer,
                args.num_batches,
                args.start_min,
                args.stop_min,
                args.start_max,
                args.stop_max,
                args.csv,
                args.activations_root,
            ),
            leave=False,
            position=1,
        )
        layer_type = type(layers[layer]).__name__
        # Running extrema for the progress bar only.
        # NOTE(review): the initial values assume scores within [0, 1] - confirm.
        best_mis, best_mis_confidence = 0, 0
        worst_mis, worst_mis_confidence = 1, 1
        for unit, batch_filenames in unit_pbar:
            mis, mis_confidence = compute_machine_interpretability_score(
                batch_filenames
            )
            if mis > best_mis:
                best_mis = mis
                best_mis_confidence = mis_confidence
            if mis < worst_mis:
                worst_mis = mis
                worst_mis_confidence = mis_confidence

            if layer_type not in unique_layer_types:
                unique_layer_types[layer_type] = len(unique_layer_types)

            all_rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    "mis": mis,
                    "mis_confidence": mis_confidence,
                    "layer_type": unique_layer_types[layer_type],
                }
            )
            unit_pbar.set_description(
                f"Best: {best_mis, best_mis_confidence}. "
                f"Worst: {worst_mis, worst_mis_confidence}"
            )

    # Passing the columns explicitly keeps the schema identical even if no unit
    # was processed at all.
    results = pd.DataFrame(
        all_rows, columns=["layer", "unit", "mis", "mis_confidence", "layer_type"]
    )
    with open(output_fn, "wb") as f:
        pickle.dump({"scores": results, "layer_types": unique_layer_types}, f)


def main():
    """
    Parses the command line arguments, derives the index windows for selecting
    the least and most activating images and runs the evaluation via process().

    From the CLI values, ``stop_min`` is derived as ``start_min + num_batches``
    and ``start_max`` as ``stop_max - num_batches``.
    """
    # Images are sorted in ascending order of activation for each neuron, so
    # e.g. the index range 0-100 covers the 100 least activating images.
    parser = argparse.ArgumentParser(
        description="Finding the most/least interpretable units in a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Which model to use. Supported models are {0} or all of timm.".format(
            list(accuracies.keys())
        ),
    )
    parser.add_argument(
        "--model-checkpoint",
        type=str,
        required=False,
        default=None,
        help="Path to model checkpoint.",
    )
    parser.add_argument(
        "--num-batches",
        type=int,
        default=20,
        help="The number of batches to collect for each unit.",
    )
    # start / stopping points are indices in a list sorted in ascending order
    parser.add_argument(
        "--start-min",
        type=int,
        help="Starting point for minima selection",
        default=180,
    )
    parser.add_argument(
        "--stop-max",
        type=int,
        help="Stopping point for maxima selection",
        default=-181,
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="Random Seed for pseudorandom numbers used in shuffling of buckets",
        default=42,
    )
    parser.add_argument(
        "--csv",
        action="store_true",
        help="Whether the old method of reading from CSV files should be used.",
    )
    parser.add_argument(
        "--similarity-function",
        type=str,
        required=True,
        choices=get_available_similarity_functions(),
    )
    parser.add_argument("--activations-root", type=str, required=True)
    # May be given several times; every occurrence contributes one list of values.
    parser.add_argument(
        "--similarity-function-args", nargs="+", action="append", default=[]
    )
    parser.add_argument(
        "--output-dir", type=str, required=True, help="Where to store the results"
    )
    args = parser.parse_args()

    # Flatten the list of lists of similarity function arguments
    args.similarity_function_args = [
        value for group in args.similarity_function_args for value in group
    ]
    print(args)

    # Both windows span num_batches indices: the minima window starts at
    # start_min, the maxima window ends at stop_max.
    args.stop_min = args.start_min + args.num_batches
    args.start_max = args.stop_max - args.num_batches

    # To allow indexing in the style of [-20:]. Without this, a stop index of 0
    # would be interpreted as [-20:0] which equals an empty list. The same holds
    # for the maxima window when --stop-max 0 is passed, so both stop indices are
    # handled (start_max has already been derived from the numeric value above).
    # NOTE(review): assumes extract_stimuli_for_layer_units applies both windows
    # as Python slices - confirm.
    if args.stop_min == 0:
        args.stop_min = None
    if args.stop_max == 0:
        args.stop_max = None

    process(args)


# Only run the command line interface when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
