import argparse
import re
import tiktoken
import asyncio
from aiolimiter import AsyncLimiter
from openai import AsyncAzureOpenAI, RateLimitError
from tqdm.asyncio import tqdm_asyncio
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
import random, itertools, pandas as pd
import yaml

# Azure OpenAI REST API version passed to every client this script creates.
API_VERSION = "2024-02-15-preview"

# Request budget: at most REQUEST_PRE_MINUTE requests per 60-second window and
# at most MAX_CONCURRENT_REQUEST requests in flight at the same time.
REQUEST_PRE_MINUTE = 250
MAX_CONCURRENT_REQUEST = 10
# Non-rate-limit failures tolerated per prompt before it is given up on.
MAX_RETRY = 20

# Shared throttles, entered (limiter first, then semaphore) around every
# chat-completion call.
# NOTE(review): the semaphore is built at import time, outside any running
# loop. That is fine on Python 3.10+, but on 3.9 asyncio.Semaphore binds to
# get_event_loop() here rather than the loop asyncio.run() later creates —
# confirm the target Python version.
rate_limiter = AsyncLimiter(max_rate=REQUEST_PRE_MINUTE, time_period=60)
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_REQUEST)

# Token counter (cl100k_base BPE); not used elsewhere in this script.
encoding = tiktoken.get_encoding(encoding_name="cl100k_base")


# Dataset formatters
def gsm8k_formatter(sample):
    """Project a GSM8K row onto the fields the GSM8K prompt template uses.

    Args:
        sample: A row mapping with at least "question" and "answer" keys.

    Returns:
        A new dict holding only the "question" and "answer" values.
    """
    wanted_fields = ("question", "answer")
    return {field: sample[field] for field in wanted_fields}


def arc_formatter(sample):
    """Convert an ARC row into question / numbered-choice / target form.

    The raw question text is kept without any "Question: ... Answer:" wrapper,
    each choice text gets a trailing period, and the answer key is converted to
    the position of that key in the choice labels.

    Args:
        sample: ARC row with "question", "answerKey" and a "choices" mapping
            holding parallel "text" and "label" lists.

    Returns:
        Dict with "question", "choices" (list of str) and "target" (int).
    """
    options = sample["choices"]
    punctuated = [text + "." for text in options["text"]]
    correct_position = options["label"].index(sample["answerKey"])
    return {
        "question": sample["question"],
        "choices": punctuated,
        "target": correct_position,
    }


def winogrande_formatter(sample):
    """Split a Winogrande sentence at its blank into question and choices.

    The text before the first "_" becomes the question. Each option, followed
    by the text after the blank, becomes a choice.

    Args:
        sample: Row with "sentence" (containing "_"), "option1", "option2"
            and "answer" ("1" or "2").

    Returns:
        Dict with "question", two "choices" and a 0-based "target".
    """
    sentence = sample["sentence"]
    blank_at = sentence.index("_")
    before_blank = sentence[:blank_at]
    after_blank = sentence[blank_at + 1 :]
    completions = [sample[key] + after_blank for key in ("option1", "option2")]
    answer_to_target = {"1": 0, "2": 1}
    return {
        "question": before_blank,
        "choices": completions,
        "target": answer_to_target[sample["answer"]],
    }


def hellaswag_formatter(sample):
    """Turn a HellaSwag row into a question, cleaned endings and a target.

    Args:
        sample: Row with "activity_label", "ctx_a", "ctx_b", "endings" and a
            string/int "label".

    Returns:
        Dict with "question" (activity label + context), "choices" (cleaned
        endings) and integer "target".
    """

    def clean(text):
        # " [title]" markers become sentence breaks. Any other [bracketed] tag
        # is dropped, and the resulting double spaces collapse to one.
        text = text.replace(" [title]", ". ").strip()
        text = re.sub(r"\[.*?\]", "", text)
        return text.replace("  ", " ")

    context = sample["ctx_a"] + " " + sample["ctx_b"].capitalize()
    cleaned_endings = list(map(clean, sample["endings"]))
    return {
        "question": clean(sample["activity_label"] + ": " + context),
        "choices": cleaned_endings,
        "target": int(sample["label"]),
    }


# Dataset name (as passed via --dataset) -> row formatter applied with Dataset.map.
dataset_formatters = dict(
    GSM8K=gsm8k_formatter,
    ARC=arc_formatter,
    Winogrande=winogrande_formatter,
    HellaSwag=hellaswag_formatter,
)


# OpenAI API functions
# Patterns tried in order to pull the two scores out of a model reply; for the
# first pattern that matches anywhere, its LAST match is used.
_SCORE_PATTERNS = (
    # "[a, b]" with optional sign / decimals.
    re.compile(r"\[\s*(-?\d+(?:\.\d*)?),\s*(-?\d+(?:\.\d*)?)\s*\]"),
    # "Assistant A: a ... Assistant B: b".
    re.compile(r"Assistant A: (\d+(?:\.\d+)?)\D*Assistant B: (\d+(?:\.\d+)?)"),
)


def _load_prompt(dataset):
    """Read ./prompts/{dataset}.yaml, closing the file afterwards."""
    with open(f"./prompts/{dataset}.yaml", encoding="utf-8") as prompt_file:
        return yaml.safe_load(prompt_file)


def _number_choices(choices):
    """Render choices as "(1) a (2) b ..." for the prompt template."""
    return " ".join(f"({idx + 1}) {choice}" for idx, choice in enumerate(choices))


def _build_prompt_kwargs(sample0, sample1):
    """Collect the template fields for a pair of formatted samples.

    Always provides question0/question1. Adds answer0/answer1 when sample0 has
    an "answer" key, and choices0/choices1 plus target0/target1 (the correct
    choice rendered as "(k) text") when sample0 has a "choices" key.
    """
    kwargs = {"question0": sample0["question"], "question1": sample1["question"]}
    if "answer" in sample0:
        kwargs |= {"answer0": sample0["answer"], "answer1": sample1["answer"]}
    if "choices" in sample0:
        for suffix, sample in (("0", sample0), ("1", sample1)):
            target = sample["target"]
            kwargs[f"choices{suffix}"] = _number_choices(sample["choices"])
            kwargs[f"target{suffix}"] = f"({target + 1}) {sample['choices'][target]}"
    return kwargs


def _parse_scores(content):
    """Return the (score1, score2) floats found in *content*, or None."""
    for pattern in _SCORE_PATTERNS:
        matches = pattern.findall(content)
        if matches:
            score1, score2 = map(float, matches[-1])
            return score1, score2
    return None


async def create_and_send_prompt(client, model, dataset, sample0, sample1):
    """Ask the model to score a pair of samples and parse the two scores.

    The system/user messages come from ``./prompts/{dataset}.yaml`` (keys
    ``system_prompt`` and ``user_prompt``). The user prompt is filled from the
    two samples.

    Args:
        client: Async OpenAI client exposing ``chat.completions.create``.
        model: Model / deployment name sent with each request.
        dataset: Dataset name; selects the prompt YAML file.
        sample0, sample1: Formatted samples. Which template fields are filled
            is decided by the keys present in ``sample0``.

    Returns:
        ``[content, score1, score2, user_message]``. Once ``MAX_RETRY``
        non-rate-limit failures have occurred, returns
        ``["", nan, nan, user_message]`` instead.
    """
    prompt = _load_prompt(dataset)
    system_message = prompt["system_prompt"]
    user_message = prompt["user_prompt"].format(
        **_build_prompt_kwargs(sample0, sample1)
    )

    message_text = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    data = {
        "model": model,
        "messages": message_text,
        "max_tokens": 800,
        "temperature": 0.3,
        "frequency_penalty": 0.1,
        "presence_penalty": 0,
        "top_p": 0.9,
        "stop": None,
    }
    counter = 0
    while True:
        async with rate_limiter:
            async with semaphore:
                try:
                    result = await client.chat.completions.create(**data)
                    content = result.choices[0].message.content
                    if not isinstance(content, str):
                        # e.g. None for a filtered reply. This used to fall through
                        # silently and re-query forever without counting a retry.
                        raise ValueError(f"Non-text completion content: {content!r}")
                    scores = _parse_scores(content)
                    if scores is None:
                        raise ValueError("Parsing fail!")
                    return [content, scores[0], scores[1], user_message]
                except RateLimitError:
                    # Not counted toward MAX_RETRY; the pause happens while still
                    # holding the semaphore, which also slows other requests.
                    await asyncio.sleep(5)
                except Exception as e:
                    print(f"Error: {e}")
                    counter += 1
                    if counter >= MAX_RETRY:
                        print(
                            "Max retry reached on question0: ",
                            sample0["question"][:60],
                            " ...",
                        )
                        print(
                            "Max retry reached on question1: ",
                            sample1["question"][:60],
                            " ...",
                        )
                        return ["", float("NaN"), float("NaN"), user_message]


async def evaluate_difficulty_pair(
    client, model, dataset, sample0, sample1, num_repeats
):
    """Score one sample pair ``num_repeats`` times concurrently.

    Returns:
        List with one ``create_and_send_prompt`` result per repeat, in order.
    """
    pending = (
        create_and_send_prompt(client, model, dataset, sample0, sample1)
        for _repeat in range(num_repeats)
    )
    return await asyncio.gather(*pending)


async def run_evaluation(
    dataset, dataset_name, pair_index, num_repeats, api_endpoint, api_model
):
    """Score every pair in ``pair_index`` against an Azure OpenAI deployment.

    Args:
        dataset: Formatted dataset indexable by integer row.
        dataset_name: Name selecting the prompt file (see create_and_send_prompt).
        pair_index: Iterable of ``(row_a, row_b)`` index pairs.
        num_repeats: Number of independent scorings per pair.
        api_endpoint: Azure endpoint URL (``None`` lets the SDK read
            ``AZURE_OPENAI_ENDPOINT``).
        api_model: Deployment / model name.

    Returns:
        One list of per-repeat results per pair, in ``pair_index`` order.
    """
    # API_BASE (endpoint -> API key) does not appear to be defined in this
    # module, so a direct lookup raises NameError. Look it up defensively; a
    # None key makes AsyncAzureOpenAI read AZURE_OPENAI_API_KEY from the env.
    api_keys = globals().get("API_BASE") or {}
    client = AsyncAzureOpenAI(
        azure_endpoint=api_endpoint,
        api_key=api_keys.get(api_endpoint),
        api_version=API_VERSION,
    )
    async with client:  # closes the underlying HTTP client on exit
        tasks = [
            evaluate_difficulty_pair(
                client,
                api_model,
                dataset_name,
                dataset[int(first)],
                dataset[int(second)],
                num_repeats,
            )
            for first, second in pair_index
        ]
        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating samples")
    return results


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--dataset",
        type=str,
        default="GSM8K",
        # Reject unknown names up front instead of failing later in
        # load_dataset() or the dataset_formatters lookup.
        choices=list(dataset_formatters),
        help="Name of the Dataset",
    )
    argparser.add_argument(
        "--chunk_id",
        type=int,
        required=True,
        help="Index of the Chunk",
    )
    argparser.add_argument(
        "--endpoint",
        type=str,
        help="Name of the Endpoint",
    )
    argparser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the Model",
    )
    args = argparser.parse_args()

    num_repeats = 3  # independent scorings per pair
    chunk_size = 100  # pairs evaluated (and pushed) per run
    num_pairs = 100000  # size of the fixed pair table shared by all chunks

    # Explicit check instead of `assert` (stripped under `python -O`); a
    # negative id used to pass the assert and silently select an empty slice.
    if not 0 <= args.chunk_id < num_pairs // chunk_size:
        argparser.error(
            f"--chunk_id must be in [0, {num_pairs // chunk_size - 1}], "
            f"got {args.chunk_id}"
        )

    dataset = load_dataset(
        f"mcding-org/Easy2Hard-{args.dataset}",
        "v1",
        split="test" if args.dataset in ["GSM8K", "ARC"] else "validation",
    )

    # Do not change the seed: every chunk must slice the same pair table.
    # NOTE(review): this materialises all n*(n-1)/2 pairs before sampling, which
    # is costly for large splits; any lazier replacement must reproduce
    # random.sample's exact draws to stay consistent with existing chunks.
    random.seed(7)
    pair_index_table = random.sample(
        list(itertools.combinations(np.arange(dataset.num_rows), 2)), num_pairs
    )
    chunk_start = chunk_size * args.chunk_id
    pair_index = pair_index_table[chunk_start : chunk_start + chunk_size]
    df_dict = {
        "question0": [dataset[int(pair[0])]["sorted_index"] for pair in pair_index],
        "question1": [dataset[int(pair[1])]["sorted_index"] for pair in pair_index],
    }
    df_dict["prompt"] = ""
    for m in range(num_repeats):
        df_dict[f"explanation_{m}"] = ""
        df_dict[f"score_QuestionA_{m}"] = ""
        df_dict[f"score_QuestionB_{m}"] = ""
    df = pd.DataFrame(df_dict)

    dataset = dataset.map(dataset_formatters[args.dataset])

    results = asyncio.run(
        run_evaluation(
            dataset,
            args.dataset,
            pair_index,
            num_repeats=num_repeats,
            api_endpoint=args.endpoint,
            api_model=args.model,
        )
    )
    # results[n][m] is [explanation, score_A, score_B, prompt] for repeat m of
    # pair n; the prompt is identical across repeats, so take it from repeat 0.
    for n, repeats in enumerate(results):
        df.loc[n, "prompt"] = repeats[0][3]
        for m, (explanation, score_a, score_b, _prompt) in enumerate(repeats):
            df.loc[n, f"explanation_{m}"] = explanation
            df.loc[n, f"score_QuestionA_{m}"] = score_a
            df.loc[n, f"score_QuestionB_{m}"] = score_b

    DatasetDict({"default": Dataset.from_pandas(df)}).push_to_hub(
        "mcding-org/Easy2Hard-" + args.dataset + "-GPT", f"chunk_{args.chunk_id}"
    )
