import sklearn.metrics as metrics
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix
from typing import List, Dict
import numpy as np
from matplotlib import pyplot as plt
import argparse
import os
from typing import Tuple
import pickle


def compute_accuracy(data: Dict) -> torch.Tensor:
    """Build a tensor of accuracies from nested prediction records.

    ``data`` is indexed as ``data[seed][epoch][k]`` (naming follows the loop
    variables). Each leaf is a mapping holding a ``"predicted_label"`` sequence
    and a ``"gt_label"`` sequence. The epoch keys and k keys are read from the
    first seed in insertion order, and every seed is expected to contain them.

    Returns:
        A float tensor of shape ``(n_epochs, n_k, n_seeds)``. Each axis is
        ordered by its sorted keys. Each cell is the fraction of positions
        where the predicted label equals the ground-truth label. Pairs are
        formed with ``zip``, so the shorter sequence bounds the comparison.

    Raises:
        IndexError: if ``data`` (or the first seed's first epoch) is empty.
        ZeroDivisionError: if a leaf has no comparable label pairs.
    """
    seed_order = list(data)
    reference = data[seed_order[0]]
    epochs = sorted(reference)
    ks = sorted(reference[epochs[0]])
    seeds = sorted(seed_order)

    result = torch.zeros(len(epochs), len(ks), len(seeds))
    for e_idx, epoch in enumerate(epochs):
        for s_idx, seed in enumerate(seeds):
            per_epoch = data[seed][epoch]
            for k_idx, k in enumerate(ks):
                record = per_epoch[k]
                hits = [
                    pred == gold
                    for pred, gold in zip(record["predicted_label"], record["gt_label"])
                ]
                result[e_idx, k_idx, s_idx] = sum(hits) / len(hits)

    return result


def plot_generic(
    ft_outputs: torch.Tensor = None,
    k_range: List[int] = None,
    dataset: str = "SMS",
    metric: str = "accuracy",
    path_to_save: str = "./accuracies",
    prefix: str = "",
    epoch_ids: List[int] = None,
):
    """Plot a metric against the number of in-context examples and save it as PNG.

    Args:
        ft_outputs: Tensor whose first axis is treated as epoch and second as k.
            It is averaged over its last axis before plotting. With the output of
            ``compute_accuracy`` (shape ``(n_epochs, n_k, n_seeds)``) this
            averages over seeds. If ``None``, nothing is plotted but the
            (empty) figure is still saved.
        k_range: x-axis values, one per entry along the k axis. Defaults to
            ``[18, 32, 48, 64]``.
        dataset: Dataset name used in the title and the output file name.
        metric: Metric name used in the title, the y-label and the output file
            name.
        path_to_save: Directory the PNG is written into.
        prefix: Suffix appended to the output file name.
        epoch_ids: Epoch number shown in the legend for each row of
            ``ft_outputs``. Defaults to ``[5, 10, 15, 20, 50]``. Rows beyond
            this list are labelled with their row index.
    """
    # None defaults avoid sharing one mutable list across calls.
    if k_range is None:
        k_range = [18, 32, 48, 64]
    if epoch_ids is None:
        epoch_ids = [5, 10, 15, 20, 50]

    inputs = ""
    y_values = []
    if ft_outputs is not None:
        inputs += "FT"
        # Greyscale shades: lightest for the first row, black for the fifth.
        colors = ["0.7", "0.6", "0.4", "0.2", "0.0"]
        mean_outputs = torch.mean(ft_outputs, dim=-1)
        for epoch in range(mean_outputs.shape[0]):
            epoch_id = epoch_ids[epoch] if epoch < len(epoch_ids) else epoch
            # None lets matplotlib pick from its colour cycle once shades run out.
            color = colors[epoch] if epoch < len(colors) else None
            y_values.append(
                (
                    mean_outputs[epoch, :].tolist(),
                    f"Finetuned Model (Epoch {epoch_id})",
                    color,
                )
            )
    for values, label, color in y_values:
        plt.plot(k_range, values, label=label, color=color)

    plt.xlabel("Number of In-Context Examples")
    plt.title(f"{metric} Rate for {dataset} (2-class) Dataset")
    plt.ylabel(f"{metric}")
    plt.legend()

    # Save BEFORE show(): with interactive backends the figure is destroyed
    # when the window closes, which would otherwise produce a blank image.
    plt.savefig(os.path.join(path_to_save, f"{dataset}{inputs}_{metric}_{prefix}.png"))
    plt.show()
    plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Compute accuracies from a pickled predictions file, save them as a "
            "pickle and optionally plot them."
        )
    )
    parser.add_argument(
        "--path_to_file",
        type=str,
        required=True,
        help="Pickled predictions dict, loaded with pandas.read_pickle.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="EleutherAI/gpt-neo-125m",
    )
    parser.add_argument(
        "--method",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./accuracies",
    )
    parser.add_argument(
        "--file_name",
        type=str,
        default="accuracies.pkl",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="hatespeech_18",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="",
    )
    parser.add_argument(
        "--do_plot",
        action="store_true",
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
    )

    args = parser.parse_args()
    model_name_split = args.model_name.split("/")[-1]

    # These methods get extra run_id / checkpoint levels in the output path.
    if args.method in ["adaptor", "lrsolver"]:
        path_to_save = f"{args.data_dir}/{args.dataset}/{model_name_split}/{args.method}/{args.run_id}/{args.checkpoint}/"
    else:
        path_to_save = (
            f"{args.data_dir}/{args.dataset}/{model_name_split}/{args.method}/"
        )
    print(path_to_save)
    # Creates data_dir as well, since path_to_save lives underneath it.
    os.makedirs(path_to_save, exist_ok=True)

    # NOTE(review): read_pickle unpickles, which can execute arbitrary code;
    # only pass trusted files here.
    data = pd.read_pickle(args.path_to_file)
    all_accuracies = compute_accuracy(data)

    # save all_accuracies as a pkl file
    pickle_name = os.path.join(path_to_save, args.file_name)

    with open(pickle_name, "wb") as f:
        pickle.dump(all_accuracies, f)

    if args.do_plot:
        plot_generic(
            all_accuracies,
            dataset=args.dataset,
            metric="accuracy",
            # path_to_save is a local; argparse defines no such option.
            path_to_save=path_to_save,
            prefix=args.prefix,
        )
