import os
import glob
from io import StringIO
import warnings
import json
import argparse
import regex as re
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
import numpy as np
import pandas as pd
from scipy.optimize import minimize
from datasets import Dataset, DatasetDict

# BeautifulSoup emits MarkupResemblesLocatorWarning when a string it is asked to
# parse looks like a filename or URL rather than markup; short problem/solution
# snippets parsed below may trigger it, so it is silenced globally.
warnings.filterwarnings(action="ignore", category=MarkupResemblesLocatorWarning)


# Suggested (low, high) difficulty-level band for each contiguous run of problem
# numbers. Keys are (contest-name prefix, first problem, last problem); a row
# matches when its contest starts with the prefix and its problem index lies in
# the inclusive range. fit_item_response penalizes fitted levels outside the band.
SUGGESTED_LEVELS = {
    (contest_prefix, first_problem, last_problem): level_band
    for contest_prefix, bands in (
        ("AMC8", [(1, 12, (1, 1.25)), (13, 25, (1.5, 2))]),
        ("AMC10", [(1, 10, (1, 2)), (11, 20, (2, 3)), (21, 25, (3.5, 4.5))]),
        ("AMC12", [(1, 10, (1.5, 2)), (11, 20, (2.5, 3.5)), (21, 25, (4.5, 6))]),
        (
            "AIME",
            [(1, 5, (3, 3.5)), (6, 9, (4, 4.5)), (10, 12, (5, 5.5)), (13, 15, (6, 7))],
        ),
    )
    for first_problem, last_problem, level_band in bands
}


def list_item_difficulty_records(folder_path):
    """List the records available as item-difficulty CSV files in a folder.

    Every ``*.csv`` file directly inside ``folder_path`` contributes one tuple,
    obtained by splitting its base name (without ``.csv``) on ``"_"``; e.g.
    ``AMC10A_2015.csv`` -> ``("AMC10A", "2015")``.

    Args:
        folder_path: Directory containing the item-difficulty CSV files.

    Returns:
        List of string tuples, ordered by file path. ``glob`` itself returns
        files in arbitrary (filesystem-dependent) order; sorting makes the
        order of everything built from this list reproducible across runs.
    """
    csv_files = sorted(glob.glob(os.path.join(folder_path, "*.csv")))
    return [
        tuple(os.path.splitext(os.path.basename(csv_file))[0].split("_"))
        for csv_file in csv_files
    ]


def _expected_problem_count(contest):
    """Return the number of problems in a contest (25 for AMC, 15 for AIME)."""
    if contest.startswith("AMC"):
        return 25
    if contest.startswith("AIME"):
        return 15
    raise ValueError(f"Unsupported contest: {contest}")


def _require_row_count(actual, expected, file_path):
    """Raise ValueError unless a parsed table has one row per contest problem."""
    if actual != expected:
        raise ValueError(
            f"{file_path}: expected {expected} problem rows, found {actual}"
        )


def _starred_percentage(choices, file_path, problem_number):
    """Return the value of the choice cell marked with a trailing '*'."""
    for item in choices.astype(str):
        if item.endswith("*"):
            return float(item.rstrip("*"))
    raise ValueError(
        f"{file_path}: no answer choice marked with '*' for problem {problem_number}"
    )


def parse_item_difficulty_csv(file_path, contest):
    """Read the correct percentage of every problem from one item-difficulty CSV.

    Two file layouts are supported:

    * Type II (processed from PDF or converted): a plain CSV whose content
      starts with ``Question,Rate``; the ``Rate`` column holds the correct
      percentages.
    * Type I (downloaded from the website directly, decrypted): the table is
      the second-to-last chunk when the file is split on blank lines from the
      right. For AMC, the correct percentage of a problem is the A-E cell
      suffixed with ``*``. For AIME, column ``A`` holds it, and only rows whose
      A-D columns sum to 100 (within a tolerance) are problem rows.

    Known data issue: for AMC10B 2011 problem 17 the correct answer is C, but
    the CSV lacks the asterisk. Such a row raises a ``ValueError`` naming the
    file and problem (previously an opaque ``TypeError`` from ``float(None)``).

    Args:
        file_path: Path to the CSV file.
        contest: Contest name; must start with "AMC" or "AIME".

    Returns:
        List of correct percentages, one per problem, in file order.

    Raises:
        ValueError: For an unsupported contest, a table without exactly 25
            (AMC) or 15 (AIME) problem rows, or an AMC row with no starred
            choice.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()
    expected_rows = _expected_problem_count(contest)

    # Type II: the whole file is the table.
    if content.startswith("Question,Rate"):
        df = pd.read_csv(StringIO(content))
        _require_row_count(len(df), expected_rows, file_path)
        return df["Rate"].to_list()

    # Type I: the table sits between blank lines near the end of the file;
    # split from the right and keep the second-to-last chunk.
    csv_content = content.rsplit("\n\n", 2)[-2]
    df = pd.read_csv(StringIO(csv_content))

    if contest.startswith("AMC"):
        _require_row_count(len(df), expected_rows, file_path)
        choices = df[["A", "B", "C", "D", "E"]]
        return [
            _starred_percentage(row, file_path, problem_number)
            for problem_number, (_, row) in enumerate(choices.iterrows(), start=1)
        ]

    # AIME (the only other contest accepted by _expected_problem_count): the
    # files contain extra rows that are not problems, so keep only rows whose
    # A-D columns sum to 100 within the tolerance.
    row_sums = df[["A", "B", "C", "D"]].sum(axis=1)
    tolerance = 1.0
    df_filtered = df[row_sums.between(100 - tolerance, 100 + tolerance)]
    _require_row_count(len(df_filtered), expected_rows, file_path)
    return df_filtered["A"].to_list()


def parse_all_item_difficulties(folder_path="./data/AMC/item_difficulty"):
    """Collect the correct percentage of every problem from all difficulty CSVs.

    Args:
        folder_path: Directory holding ``<contest>_<year>.csv`` files. Defaults
            to the location previously hard-coded here.

    Returns:
        DataFrame with one row per problem and columns ``contest`` (string),
        ``year`` (float64; the year string "2021a" is stored as 2021.5),
        ``problem_index`` (int64, starting at 1 in file order) and
        ``correct_percentage`` (float64). When no CSV files are found, an empty
        frame with the same columns and dtypes is returned.
    """
    column_dtypes = {
        "contest": "string",
        "year": "float64",
        "problem_index": "int64",
        "correct_percentage": "float64",
    }
    frames = []
    for contest, year in list_item_difficulty_records(folder_path):
        file_path = os.path.join(folder_path, f"{contest}_{year}.csv")
        correct_percentages = parse_item_difficulty_csv(file_path, contest)
        num_problems = len(correct_percentages)
        # Years are numeric strings except "2021a", which is encoded as 2021.5
        # so the column can stay float64.
        year_value = float(year) if year != "2021a" else 2021.5
        frames.append(
            pd.DataFrame(
                {
                    "contest": pd.Series([contest] * num_problems, dtype="string"),
                    "year": pd.Series([year_value] * num_problems, dtype="float64"),
                    "problem_index": pd.Series(
                        range(1, num_problems + 1), dtype="int64"
                    ),
                    "correct_percentage": pd.Series(
                        correct_percentages, dtype="float64"
                    ),
                }
            )
        )
    if not frames:
        # pd.concat rejects an empty list; return the typed, empty schema instead.
        return pd.DataFrame(columns=list(column_dtypes)).astype(column_dtypes)
    # One concat at the end instead of growing a frame inside the loop, which is
    # quadratic and (onto an initially empty frame) deprecated in recent pandas.
    return pd.concat(frames, ignore_index=True)


def fit_item_response(df):
    """Fit a difficulty level and a rating to every problem from its correct rate.

    Each problem's level is modelled as a linear function of the logit of its
    correct rate, with one offset per contest family::

        level = x[1] + x[2 + g] - x[0] * logit(correct_percentage / 100)

    where ``g`` indexes the family matched by contest-name prefix, in the order
    AMC8, AMC10, AMC12, AIME. (Only ``x[1] + x[2 + g]`` affects the level, so
    ``x[1]`` and ``x[2:]`` are kept finite only by their bounds.) ``x`` is fit
    so that the levels fall inside the bands given by ``SUGGESTED_LEVELS``.
    The levels are then mapped to a rating centred on 1500.

    Args:
        df: DataFrame with at least ``contest``, ``problem_index`` and
            ``correct_percentage`` (in percent) columns. It is not modified.

    Returns:
        A new DataFrame with the input columns plus ``level`` and ``rating``,
        sorted by ascending rating.

    Raises:
        ValueError: If a correct percentage is not strictly between 0 and 100
            (its logit would be infinite or NaN), or a row matches no entry of
            ``SUGGESTED_LEVELS``. Either would make the objective non-finite.
    """
    # Work on a copy so the caller's frame does not gain the helper columns.
    df = df.copy()
    # Get the correct rates between 0 and 1
    correct_rate = df["correct_percentage"] / 100.0
    # NaN compares False on both sides, so missing values are rejected too.
    invalid_rate = ~((correct_rate > 0) & (correct_rate < 1))
    if invalid_rate.any():
        raise ValueError(
            "correct_percentage must be strictly between 0 and 100; "
            f"offending rows:\n{df.loc[invalid_rate]}"
        )
    # Calculate the inverse sigmoid (logit) of the correct rates
    df["inverse_sigmoid"] = np.log(correct_rate / (1 - correct_rate))
    # Look up the suggested (low, high) level band: first SUGGESTED_LEVELS entry
    # whose prefix matches the contest and whose inclusive range holds the index.
    df["level_low"], df["level_high"] = zip(
        *df.apply(
            lambda row: next(
                (
                    (levels[0], levels[1])
                    for (
                        contest_prefix,
                        start_idx,
                        end_idx,
                    ), levels in SUGGESTED_LEVELS.items()
                    if row["contest"].startswith(contest_prefix)
                    and start_idx <= row["problem_index"] <= end_idx
                ),
                (None, None),
            ),
            axis=1,
        )
    )
    unmatched = df["level_low"].isna() | df["level_high"].isna()
    if unmatched.any():
        raise ValueError(
            "No SUGGESTED_LEVELS entry matches these rows:\n"
            f"{df.loc[unmatched, ['contest', 'problem_index']]}"
        )
    # (4, N) boolean mask: row g is True where the contest belongs to family g.
    broadcast_matrix = np.stack(
        [
            df["contest"].str.startswith(contest_prefix).values
            for contest_prefix in ["AMC8", "AMC10", "AMC12", "AIME"]
        ]
    )
    # Per row, the two terms sum to (level_high - level_low) while the level is
    # inside the band and grow linearly outside it. This is a hinge penalty on
    # leaving the band, shifted by a constant.
    sol = minimize(
        lambda x: np.sum(
            np.maximum(
                (x[1] + x[2:]) @ broadcast_matrix
                - (x[0] * df["inverse_sigmoid"])
                - df["level_low"],
                0,
            )
        )
        + np.sum(
            np.maximum(
                df["level_high"]
                - (x[1] + x[2:]) @ broadcast_matrix
                + (x[0] * df["inverse_sigmoid"]),
                0,
            )
        ),
        (0.5, -2.0, 1.5, 2.75, 3.75, 5),
        bounds=[(0.1, 10), (-7, 1), (0.5, 2.5), (0.5, 5), (1, 6.5), (2.5, 7.5)],
    )
    df["level"] = (sol.x[1] + sol.x[2:]) @ broadcast_matrix - sol.x[0] * df[
        "inverse_sigmoid"
    ]
    # NOTE(review): one logit unit corresponds to sol.x[0] level units. An
    # Elo-style conversion (400 points per factor-10 odds) would divide the level
    # spread by sol.x[0] (and by ln 10) rather than multiply; confirm the
    # intended scale before changing it.
    df["rating"] = (df["level"] - df["level"].mean()) * (sol.x[0] * 400) + 1500
    return df.sort_values(["rating"]).drop(
        columns=["inverse_sigmoid", "level_low", "level_high"]
    )


def _normalize_answer_number(number):
    """Canonicalize a non-negative decimal numeral taken from a boxed answer.

    Leading zeros of the integer part are removed, keeping a single "0".
    Trailing zeros of the fractional part are removed, and the decimal point
    is dropped when no fractional digits remain. Examples: "007" -> "7",
    "000" -> "0", "0.50" -> "0.5", "1.250" -> "1.25", "2.0" -> "2".
    (The previous regex turned "0" into "" and "0.5" into ".5".)
    """
    integer_part, _, fraction_part = number.partition(".")
    integer_part = integer_part.lstrip("0") or "0"
    fraction_part = fraction_part.rstrip("0")
    return f"{integer_part}.{fraction_part}" if fraction_part else integer_part


def _clean_question(raw_question, contest):
    """Convert a problem statement from HTML to single-line plain text."""
    question = BeautifulSoup(raw_question, "html.parser").get_text()
    if contest.startswith("AMC"):
        # The answer choices are expected on the last line. When they share a
        # line with the statement this leaves nothing, and the caller skips it.
        question = "\n".join(question.split("\n")[:-1])
    # Collapse newlines and whitespace runs into single spaces.
    return re.sub(r"\s+", " ", question.strip())


def _clean_solution(raw_solution):
    r"""Return a solution as plain text with its boxed answer normalized.

    Returns None unless the text contains exactly one ``\boxed{...}`` (with at
    most one level of nested braces) holding exactly one number. Fractions and
    other multi-number expressions are therefore rejected. The boxed expression
    is replaced by ``\boxed{<normalized number>}``.
    """
    solution = BeautifulSoup(raw_solution, "html.parser").get_text()
    # Try to remove an author signature at the end: a "-" or "~" followed only
    # by word characters, whitespace, parentheses and periods up to the end.
    # NOTE(review): no length limit is enforced, so any such dash-led tail is
    # removed regardless of length; confirm this is intended.
    solution = re.sub(r"[-~][\w\s().]+$", "", solution)
    solution = re.sub(r"\s+", " ", solution.strip())
    boxed_matches = re.findall(r"\\boxed{(?:[^{}]|{[^{}]*})*}", solution)
    if len(boxed_matches) != 1:
        return None
    number_matches = re.findall(r"(\d+(?:\.\d+)?)", boxed_matches[0])
    if len(number_matches) != 1:
        return None
    answer_number = _normalize_answer_number(number_matches[0])
    return solution.replace(boxed_matches[0], "\\boxed{" + answer_number + "}")


def parse_all_problems(item_difficulty, by_answer=False):
    """Join fitted difficulties with the problem and solution JSON files.

    For each row of ``item_difficulty``, the file
    ``./data/AMC/problems/<contest>/<contest>_<year>_<index>.json`` is read
    (year 2021.5 maps to "2021_Fall"). Its ``problem`` and ``solution_<i>``
    fields (i = 0, 1, ... until a key is missing) are cleaned. Problems whose
    cleaned question is empty, or which yield no usable solution, are skipped.

    Args:
        item_difficulty: DataFrame with columns contest, year, problem_index,
            correct_percentage, level and rating.
        by_answer: If True, emit one row per usable solution (``answer`` is a
            string); otherwise one row per problem (``answer`` is a list).

    Returns:
        DataFrame with the input columns plus ``question`` and ``answer``.

    Raises:
        FileNotFoundError: If a problem's JSON file is missing.
    """
    columns = ["contest", "year", "problem_index", "correct_percentage", "level", "rating"]
    problems = []
    for (
        contest,
        year,
        problem_index,
        correct_percentage,
        level,
        rating,
    ) in item_difficulty[columns].values:
        year_label = str(int(year)) if year != 2021.5 else "2021_Fall"
        json_path = (
            f"./data/AMC/problems/{contest}/{contest}_{year_label}_{problem_index}.json"
        )
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"File {json_path} does not exist.")
        with open(json_path, encoding="utf-8") as json_file:
            data = json.load(json_file)

        question = _clean_question(data["problem"], contest)
        if not question:
            continue

        answers = []
        solution_index = 0
        while f"solution_{solution_index}" in data:
            cleaned = _clean_solution(data[f"solution_{solution_index}"])
            if cleaned is not None:
                # The whole cleaned solution (not just the number) is kept.
                answers.append(cleaned)
            solution_index += 1

        record = {
            "contest": contest,
            "year": year,
            "problem_index": problem_index,
            "correct_percentage": correct_percentage,
            "level": level,
            "rating": rating,
            "question": question,
        }
        if by_answer:
            problems.extend({**record, "answer": answer} for answer in answers)
        elif answers:
            problems.append({**record, "answer": answers})
    return pd.DataFrame(problems)


def main(huggingface_path, version_tag):
    """Build the AMC difficulty dataset (one row per answer) and push it to the Hub.

    Args:
        huggingface_path: Hub repository id to push to.
        version_tag: Passed to ``push_to_hub`` as the dataset's ``config_name``.
    """
    item_difficulty = parse_all_item_difficulties()
    item_difficulty = fit_item_response(item_difficulty)
    dataset_by_answer = parse_all_problems(item_difficulty, by_answer=True)
    dataset = DatasetDict(
        {
            "default": Dataset.from_pandas(dataset_by_answer.reset_index(drop=True)),
        }
    )
    # Pass the tag by keyword. The second positional parameter of push_to_hub
    # has changed across `datasets` releases (older releases took `private`
    # there), so a positional string could be misinterpreted.
    dataset.push_to_hub(huggingface_path, config_name=version_tag)


if __name__ == "__main__":
    # Command-line entry point: every option is a string flag with a default.
    parser = argparse.ArgumentParser()
    for flag, default, help_text in (
        (
            "--huggingface_path",
            "mcding-org/Easy2Hard-AMC",
            "Path to the Hugging Face dataset",
        ),
        (
            "--version_tag",
            "v1",
            "Version tag for the Hugging Face dataset",
        ),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    cli_args = parser.parse_args()

    main(
        huggingface_path=cli_args.huggingface_path,
        version_tag=cli_args.version_tag,
    )
