import json
import re
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union

import sentence_transformers
import torch
from fire import Fire
from IPython import embed
from mteb import MTEB
from nltk import word_tokenize
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, WordEmbeddings
from torchtyping import TensorType as TT

from all_but_the_top import AllButTheTop
from eval_utils import (
    UnigramProb,
    WrappedTokenizer,
    load_unigram_prob_enwiki_vocab_min200,
    load_word2vec_model,
    remove_unused_words,
)
from modeling import CustomPooling
from SIF import SIF
from zipfian_whitening import UniformWhitening, ZipfianWhitening

# Source of the enwiki unigram counts:
# https://github.com/kawine/usif/raw/71ffef5b6d7295c36354136bfc6728a10bd25d32/enwiki_vocab_min200.txt
# NOTE(review): the upstream file is named "enwiki_vocab_min200.txt" (underscores),
# but the path below uses spaces; this constant is not referenced in this part of the
# file — confirm the on-disk file name before relying on it.
PATH_ENWIKI_VOCAB_MIN200 = "data/enwiki_vocab_min200/enwiki vocab min200.txt"

# Embedding transforms to evaluate, in iteration order. For each name:
#   - "whitening_transformer_class": class fitted on the word embeddings,
#     or None to use the raw embeddings;
#   - "pooling": pooling modes to evaluate with that fitted transformer.
# SIF is intentionally absent: it needs statistics of the test split, so it is
# fitted and evaluated on its own rather than through this table.
TRANSFORM_CONFIG = {
    # Baseline: no whitening, plain mean pooling.
    "normal": dict(
        whitening_transformer_class=None,
        pooling=["mean"],
    ),
    "uniform_whitening": dict(
        whitening_transformer_class=UniformWhitening,
        pooling=["centering_only", "whitening"],
    ),
    "zipfian_whitening": dict(
        whitening_transformer_class=ZipfianWhitening,
        pooling=["centering_only", "whitening"],
    ),
    # All-but-the-top: removal of dominant components.
    "abtp": dict(
        whitening_transformer_class=AllButTheTop,
        pooling=["component_removal"],
    ),
}


def evaluate(
    model: SentenceTransformer,
    model_name: str,
    whitening_transformer: Optional[
        Union[UniformWhitening, ZipfianWhitening, AllButTheTop, SIF]
    ],  # TODO: this is not whitening anymore, it's more general transformer. Change the name.
    pooling_mode: str,
    embedding_layer_index: int = 0,
    pooling_layer_index: int = 1,
    topk: Optional[int] = None,
    task_name: str = "SICK-R",
):
    """Run one MTEB task on ``model`` using a ``CustomPooling`` layer.

    The layer at ``pooling_layer_index`` of ``model`` is replaced *in place*
    by a ``CustomPooling`` built from ``pooling_mode`` and
    ``whitening_transformer``. The model keeps that layer after this call
    returns.

    Results are passed to MTEB with
    ``output_folder="results/{model_name}/{task_name}/{whitening_name}/{pooling_mode}"``.
    When ``topk`` is not None, ``"/{topk}"`` is appended to that folder.

    Args:
        model: Model whose pooling layer is swapped and then evaluated.
        model_name: Model identifier. Only its last path component is used in
            the output folder.
        whitening_transformer: Fitted transform handed to ``CustomPooling``, or
            None for no transform. Supported types are ``ZipfianWhitening``,
            ``UniformWhitening``, ``AllButTheTop`` and ``SIF``.
        pooling_mode: Pooling mode forwarded to ``CustomPooling``.
        embedding_layer_index: Index of the word-embedding module in ``model``.
            It is used to read the embedding dimension.
        pooling_layer_index: Index of the module in ``model`` that is replaced.
        topk: Optional extra component of the output folder.
        task_name: MTEB task to run (e.g. "SICK-R", "STSBenchmark").

    Returns:
        The object returned by ``MTEB.run`` (also pretty-printed).

    Raises:
        NotImplementedError: If ``whitening_transformer`` is not None and not
            one of the supported types.
    """
    # Keep only the last path component, e.g. drop an "org/" hub prefix.
    model_name = Path(model_name).name
    if whitening_transformer is None:
        whitening_name = "normal"
    elif isinstance(whitening_transformer, ZipfianWhitening):
        whitening_name = "zipfian_whitening"
    elif isinstance(whitening_transformer, UniformWhitening):
        whitening_name = "uniform_whitening"
    elif isinstance(whitening_transformer, AllButTheTop):
        whitening_name = "abtp"
    elif isinstance(whitening_transformer, SIF):
        whitening_name = "sif"
    else:
        raise NotImplementedError(
            'Only "ZipfianWhitening" and "UniformWhitening" and "AllButTheTop" and "SIF" are supported.'
        )

    # save_dir name rule:
    # results/{model_name}/{task_name}/{whitening_name}/{pooling_mode}[/{topk}]
    save_dir_name = f"results/{model_name}/{task_name}/{whitening_name}/{pooling_mode}"
    if topk is not None:  # `is not None` so that topk=0 still gets its own folder
        save_dir_name = f"{save_dir_name}/{topk}"

    pooling = CustomPooling(
        word_embedding_dimension=model[
            embedding_layer_index
        ].get_word_embedding_dimension(),
        pooling_mode=pooling_mode,
        whitening_transformer=whitening_transformer,
    )
    # Mutates the caller's model: subsequent calls see this pooling layer.
    model[pooling_layer_index] = pooling

    task = MTEB(tasks=[task_name])
    results = task.run(
        model,
        output_folder=save_dir_name,
    )
    print("#" * 50)
    print(f"Done {model_name} with {whitening_name} and {pooling_mode} pooling.")
    print("#" * 50)
    pprint(results)
    return results


def _load_unigram_prob(
    model: "SentenceTransformer", model_vocab_size: int, topk: Optional[int]
) -> Tuple[set, torch.Tensor]:
    """Load enwiki unigram probabilities for ``model``'s vocabulary.

    Returns:
        ``(unused_vocab_ids, prob)``, where ``prob`` has been moved to
        ``model.device``.
    """
    unigramprob: UnigramProb = load_unigram_prob_enwiki_vocab_min200(
        model.tokenizer, model_vocab_size, topk=topk
    )
    return unigramprob.unused_vocab_ids, unigramprob.prob.to(model.device)


def _block_unused_words(model: "SentenceTransformer", unused_vocab_ids: set) -> None:
    """Register every unused vocabulary entry as a stop word of the wrapped tokenizer."""
    # HACK: Setting unused words as stop words; the model never recognizes the token.
    model.tokenizer.original_tokenizer.stop_words = {
        model.tokenizer.vocab[index] for index in unused_vocab_ids
    }


def _zero_unused_and_renormalize(
    prob: torch.Tensor, unused_vocab_ids: set
) -> torch.Tensor:
    """Set p(w) = 0 for unused words, then renormalize ``prob`` to sum to 1.

    ``prob`` is modified in place by the zeroing step. The renormalized
    tensor that is returned is a new tensor.

    Raises:
        ValueError: If the result does not sum to 1 within float tolerance.
            For example, a zero total (every word unused) produces NaNs.
    """
    prob[list(unused_vocab_ids)] = 0
    prob = prob / prob.sum()
    total = prob.sum()
    # Exact `== 1` is unreliable in floating point: re-summing the normalized
    # tensor can land a few ULPs away from 1. NaN never compares close, so an
    # all-zero input is still rejected.
    if not bool(torch.isclose(total, torch.ones_like(total), rtol=0.0, atol=1e-4)):
        raise ValueError(
            f"Unigram probabilities sum to {total.item()} after renormalization, expected 1."
        )
    return prob


def main(model_name: str, topk: Optional[int] = None) -> None:
    """Evaluate every transform in ``TRANSFORM_CONFIG``, then SIF, on one model.

    Steps:
      1. Load ``model_name``. The GoogleNews word2vec path goes through
         ``load_word2vec_model``; anything else goes through
         ``SentenceTransformer``.
      2. Wrap the tokenizer and load the enwiki unigram probabilities.
      3. Fit each configured transformer on the embeddings with unused words
         removed, and evaluate it with each of its pooling modes.
      4. Fit SIF on renormalized probabilities and evaluate it separately.

    Args:
        model_name: Local path or hub name of the model.
        topk: Forwarded to ``load_unigram_prob_enwiki_vocab_min200`` and to
            ``evaluate``. It presumably limits the vocabulary to the top-k
            words (confirm in eval_utils). None disables it.

    Raises:
        ValueError: If the SIF unigram distribution does not sum to 1 after
            renormalization.
    """
    print(f"topk: {topk}")
    # Note: maybe won't work for BERT-based models, need model specific config
    embedding_layer_index = 0
    pooling_layer_index = 1

    if model_name == "models/GoogleNews-vectors-negative300":
        model: SentenceTransformer = load_word2vec_model(model_name)
    else:
        model = SentenceTransformer(model_name)

    model.tokenizer.stop_words = set()  # start with no stop words
    model.tokenizer.do_lower_case = True
    model.tokenizer = WrappedTokenizer(model.tokenizer)
    # The embedding matrix and the tokenizer vocabulary can differ in size
    # (special tokens such as padding), so take the size from the weights.
    model_vocab_size = model[embedding_layer_index].emb_layer.weight.shape[0]

    # --- Non-SIF methods --------------------------------------------------
    unused_vocab_ids, unigramprob_tensor = _load_unigram_prob(
        model, model_vocab_size, topk
    )
    # To reduce the noise for the whitening, remove the unused words from the
    # embeddings and unigram probabilities. This is common to every method
    # EXCEPT SIF: SIF computes its common components on sentence embeddings,
    # so unused words do not affect it.
    embedding_for_whitening, unigramprob_tensor = remove_unused_words(
        unused_vocab_ids,
        model[embedding_layer_index].emb_layer.weight,
        unigramprob_tensor,
    )
    _block_unused_words(model, unused_vocab_ids)

    for transform_config in TRANSFORM_CONFIG.values():
        transformer_class = transform_config["whitening_transformer_class"]
        whitening_transformer = (
            None
            if transformer_class is None
            else transformer_class().fit(embedding_for_whitening, p=unigramprob_tensor)
        )
        for pooling_mode in transform_config["pooling"]:
            evaluate(
                model=model,
                model_name=model_name,
                whitening_transformer=whitening_transformer,
                pooling_mode=pooling_mode,
                embedding_layer_index=embedding_layer_index,
                pooling_layer_index=pooling_layer_index,
                topk=topk,
            )

    # --- SIF ------------------------------------------------------------
    # SIF needs statistics of the "test" split, so its logic is separate FOR NOW.
    # TODO: integrate this with the above loop.
    # Reload the probabilities. The tensor above had unused words removed,
    # whereas here p(w) is indexed by vocabulary id.
    unused_vocab_ids, unigramprob_tensor = _load_unigram_prob(
        model, model_vocab_size, topk
    )
    # Setting p(w) = 0 for the unused words does not affect the SIF
    # computation, since the tokenizer blocks those words anyway.
    unigramprob_tensor = _zero_unused_and_renormalize(
        unigramprob_tensor, unused_vocab_ids
    )
    _block_unused_words(model, unused_vocab_ids)
    # NOTE(review): model[pooling_layer_index] is still the pooling layer
    # installed by the last evaluate() call above. Confirm that SIF does not
    # depend on it when it reads the test split.
    sif = SIF(model, data_split="test")  # use the test split for the sentence-level task
    sif.fit(None, unigramprob_tensor)  # W won't be used in SIF
    evaluate(
        model=model,
        model_name=model_name,
        whitening_transformer=sif,
        pooling_mode="sif_w_component_removal",
        embedding_layer_index=embedding_layer_index,
        pooling_layer_index=pooling_layer_index,
        topk=topk,
    )


if __name__ == "__main__":
    # Expose `main` as a command-line interface via python-fire; positional
    # arguments map to `model_name`, and `--topk` sets the optional top-k.
    Fire(component=main)
