"""
Compute per-unit activation statistics (min, max, median, mean and standard
deviation) for every layer of a model whose activations have already been
extracted, and store them as a pickled "unit metadata" file.
"""

import argparse
import os
import pickle
from typing import Any, Optional

import numpy as np
import pandas as pd
from sg_utils import (accuracies, get_default_device, get_model_layers,
                      load_model, read_activations_file)
from tqdm import tqdm


# Columns of the "scores" DataFrame, in order. Used explicitly so that even an
# empty result (no layer available) carries the correct schema.
_UNIT_METADATA_COLUMNS = [
    "layer",
    "unit",
    "min_activation",
    "max_activation",
    "median_activation",
    "mean_activation",
    "std_activation",
    "layer_type",
]


def _model_name_with_checkpoint(model_name: str, model_checkpoint: Optional[str]) -> str:
    """Return the identifier naming this model's activations directory and output file.

    Without a checkpoint this is just ``model_name``. With one, the last two
    components of the checkpoint path are appended, joined by ``_`` and with the
    ``.pth`` extension removed (e.g. ``runs//exp1/best.pth`` -> ``<model>_exp1_best``).
    """
    if not model_checkpoint:
        return model_name
    tail = "_".join(model_checkpoint.replace("//", "/").split("/")[-2:])
    # NOTE(review): ".pth" is stripped before ".pth.tar", so "best.pth.tar" becomes
    # "best.tar" and the second replace can never match. Left unchanged on purpose:
    # the result must match the directory name under --activations-root, which is
    # presumably produced by the extraction script with the same rule -- fix both
    # places together.
    tail = tail.replace(".pth", "").replace(".pth.tar", "")
    return f"{model_name}_{tail}"


def _activation_statistics(activations: np.ndarray) -> dict:
    """Return summary statistics of one unit's activation values."""
    return {
        "min_activation": np.min(activations),
        "max_activation": np.max(activations),
        "median_activation": np.median(activations),
        "mean_activation": np.mean(activations),
        # Accumulate in float64 for a stable std, then store it as float32.
        "std_activation": np.std(activations.astype(np.float64)).astype(np.float32),
    }


def process(args: argparse.Namespace, device: Optional[Any] = None) -> None:
    """Compute per-unit activation statistics for a model and pickle them.

    For every layer returned by ``get_model_layers`` whose activations file can be
    read from ``<activations_root>/<model id>``, the min, max, median, mean and
    standard deviation of each unit's activations are recorded.

    Args:
        args: Parsed CLI options; reads ``model``, ``model_checkpoint``,
            ``activations_root`` and ``output_dir``. ``args.model`` is
            overwritten with its sanitized (``/``, ``:`` -> ``_``) form.
        device: Passed through to ``load_model``.

    The result is written to ``<output_dir>/unit_metadata_<model id>.pkl`` as a
    dict with ``"scores"`` (a DataFrame with one row per unit) and
    ``"layer_types"`` (layer class name -> integer id used in the ``layer_type``
    column). If that file already exists, nothing is computed and the model is
    not loaded.
    """
    raw_model = args.model
    model_name = raw_model.replace("/", "_").replace(":", "_").replace("__", "_")
    # Kept for backward compatibility: the original overwrote args.model too.
    args.model = model_name
    model_id = _model_name_with_checkpoint(model_name, args.model_checkpoint)

    output_fn = os.path.join(args.output_dir, f"unit_metadata_{model_id}.pkl")
    # Check before loading the model so finished runs are skipped cheaply.
    if os.path.exists(output_fn):
        print(f"Skipping as output file ({output_fn}) already exists.")
        return
    # Create the output directory up front so the final write cannot fail after
    # all layers have been processed.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # load_model receives the unsanitized name, as given on the command line.
    model = load_model(raw_model, device=device)
    layers = get_model_layers(model, get_modules=True)
    activations_dir = os.path.join(args.activations_root, model_id)

    layer_type_ids = {}
    rows = []
    layer_pbar = tqdm(layers)
    for layer in layer_pbar:
        layer_pbar.set_description(f"Layer {layer}")
        try:
            layer_df = read_activations_file(activations_dir, layer)
        except AssertionError:
            # An AssertionError from read_activations_file is treated as
            # "activations for this layer were not extracted".
            print(f"Skipping layer {layer} as it is not available.")
            continue

        # Ids are assigned in order of first appearance among available layers.
        layer_type = type(layers[layer]).__name__
        type_id = layer_type_ids.setdefault(layer_type, len(layer_type_ids))

        # Every non-"path" column is a unit index stored as a string.
        units = [int(c) for c in layer_df.columns if "path" not in c]
        for unit in units:
            activations = layer_df[str(unit)].values
            rows.append(
                {
                    "layer": layer,
                    "unit": unit,
                    **_activation_statistics(activations),
                    "layer_type": type_id,
                }
            )

    # Build the frame once (no repeated concat) and always with the real schema.
    results = pd.DataFrame(rows, columns=_UNIT_METADATA_COLUMNS)
    with open(output_fn, "wb") as f:
        pickle.dump({"scores": results, "layer_types": layer_type_ids}, f)


def main():
    """Command-line entry point: parse the options and compute the unit metadata."""
    cli = argparse.ArgumentParser(
        description="Compute unit metadata (about activations) for a given model."
    )
    supported_models = list(accuracies.keys())
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help=(
            "Which model to use. Supported models are "
            f"{supported_models} or all of timm."
        ),
    )
    cli.add_argument(
        "--model-checkpoint",
        type=str,
        required=False,
        default=None,
        help="Path to model checkpoint.",
    )
    cli.add_argument("--activations-root", type=str, required=True)
    cli.add_argument(
        "--output-dir", type=str, required=True, help="Where to store the results"
    )
    options = cli.parse_args()

    # Options are parsed before the device is queried, then both go to process().
    process(options, get_default_device())


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
