import os
import json
import gzip
import tqdm
import glob
import re
import csv
from src.utils.file_utils import load_jsonl
from src.utils.data_utils import extract_answer_math
from src.utils.eval_utils import delete_extra_zero

# Location of each dataset's file (or directory), relative to the data root.
# load_data() / load_data_mc() join the value onto ``data_base`` unless
# ``data_base`` is itself an existing file or the name is a key of DATA_SET,
# in which case ``data_base`` is used as-is and this table is not consulted.
DATA_PATH = {
    "gsm": "opencompass/gsm8k/test.jsonl",
    # load_data() keeps only the first 1000 records for gsm names containing "dev".
    "gsm_dev": "opencompass/gsm8k/train.jsonl",
    "svamp": "svamp/svamp.json",
    "csqa": "opencompass/commonsenseqa/dev_rand_split.jsonl",
    # "csqa_dev": "opencompass/commonsenseqa/train_rand_split.jsonl",
    "csqa_dev": "opencompass/commonsenseqa/train_rand_split.1000.jsonl",
    "aqua": "AQuA/AQuA.json",
    "aqua_dev": "AQuA/dev.json",
    "sat": "agieval/sat-math.jsonl",
    "multiarith": "arith/multiarith.json",
    "addsub": "arith/addsub.json",
    "singleeq": "arith/singleeq.json",
    "MATH": "MATH/test.json",
    "MATH500": "MATH/test500.json",
    # NOTE(review): "strategyqa" and "strategyqa_dev" point at the same file;
    # load_data() takes the first 200 records for the "dev" name and every
    # record otherwise, so the two splits overlap — confirm this is intended.
    "strategyqa": "opencompass/strategyqa/strategyQA_train.json",
    "strategyqa_dev": "opencompass/strategyqa/strategyQA_train.json",
    "coin_flip": "symbol/coin_flip.json",
    "last_letters": "symbol/last_letters.json",
    # "humaneval": "eval/codex_humaneval/HumanEval.jsonl.gz",
    "humaneval": "opencompass/humaneval/human-eval-v2-20210705.jsonl",
    # The bbh names share one directory of per-task *.json files; load_data()
    # uses the first 20 examples of each task for "dev" names and the rest
    # otherwise.
    "bbh": "opencompass/BBH/data",
    "bbh_symbol": "opencompass/BBH/data",
    "bbh_symbol_dev": "opencompass/BBH/data",
    "triviaqa": "opencompass/triviaqa/trivia-dev.qa.csv",
    "nq": "opencompass/nq/nq-dev.qa.csv",
    "arcc": "opencompass/ARC/ARC-c/ARC-Challenge-Test.jsonl",
    "arce": "opencompass/ARC/ARC-e/ARC-Easy-Test.jsonl",
    "hellaswag": "opencompass/hellaswag/hellaswag.jsonl",
    "asdiv": "asdiv/test.jsonl",
    # Directory holding the per-subject "<task>_test.csv" files read by load_data().
    "mmlu_math": "opencompass/mmlu/test",
    # pal: same data files as the plain names above
    "gsm_pal": "opencompass/gsm8k/test.jsonl",
    "svamp_pal": "svamp/svamp.json",
    "multiarith_pal": "arith/multiarith.json",
    "addsub_pal": "arith/addsub.json",
    "singleeq_pal": "arith/singleeq.json",
    # tora: same data files as the plain names above
    "gsm_tora": "opencompass/gsm8k/test.jsonl",
    "svamp_tora": "svamp/svamp.json",
    "multiarith_tora": "arith/multiarith.json",
    "addsub_tora": "arith/addsub.json",
    "singleeq_tora": "arith/singleeq.json",
    # data set: these names are also keys of DATA_SET, so the loaders in this
    # file pass ``data_base`` through unchanged instead of reading these values.
    "arith": "arith",
    "mwp": "mwp",
    "mwp_wo_gsm": "mwp_wo_gsm",
    "reason512": "reason512",
    "reason1024": "reason1024",
    "math": "math",
    "agieval": "agieval",
    "qa": "qa",
    "arc": "arc",
}

# Few-shot prompt file (or, for "bbh", prompt directory) per dataset name.
# load_few_shot_demos() joins the "bbh" entry onto its ``base_path`` and globs
# it for "*.txt"; every other entry is opened exactly as written here.
FEW_SHOT_PATH = {
    "gsm": "prompt/gsm8k/cot.txt",
    "svamp": "prompt/gsm8k/cot.txt",
    "aqua": "prompt/aqua/cot.txt",
    "csqa": "prompt/csqa/cot.txt",
    "multiarith": "prompt/gsm8k/cot.txt",
    "addsub": "prompt/gsm8k/cot.txt",
    "singleeq": "prompt/gsm8k/cot.txt",
    "strategyqa": "prompt/strategyqa/cot.txt",
    "gsm_pal": "prompt/gsm8k/pal.txt",
    "svamp_pal": "prompt/gsm8k/pal.txt",
    "multiarith_pal": "prompt/gsm8k/pal.txt",
    "addsub_pal": "prompt/gsm8k/pal.txt",
    "singleeq_pal": "prompt/gsm8k/pal.txt",
    # NOTE(review): "gsm_tora" used to be listed twice in this literal, first
    # as "prompt/gsm8k/tora.txt" and then as the entry below. A repeated key in
    # a dict literal keeps the last value, so the PAL prompt was always the
    # effective one; the dead first entry was dropped to make that explicit.
    # If the ToRA prompt was actually intended, change the value here.
    "gsm_tora": "prompt/gsm8k/pal.txt",
    "svamp_tora": "prompt/gsm8k/pal.txt",
    "multiarith_tora": "prompt/gsm8k/pal.txt",
    "addsub_tora": "prompt/gsm8k/pal.txt",
    "singleeq_tora": "prompt/gsm8k/pal.txt",
    "coin_flip": "prompt/coin_flip/cot.txt",
    "last_letters": "prompt/last_letters/cot.txt",
    "bbh": "eval/bbh/cot-prompts",
    "MATH": "prompt/MATH/cot.txt",
    "MATH_tora": "prompt/MATH/tora.txt",
    "MATH500": "prompt/MATH/cot.txt",
    "MATH8shot": "prompt/MATH/cot_8shot.txt",
}

# Answer-extraction suffix strings keyed by dataset name (plus the BBH
# sub-kinds "bbh_mcq" / "bbh_ff"). Nothing in the visible part of this file
# reads this table; presumably the suffix is appended after the model's
# reasoning to elicit the final answer — confirm against the caller.
# The multiple-choice variants deliberately end with an open "(" and the
# plain "bbh" entry is the empty string.
EXTRACT_PROMPT = {
    # "csqa": "\nAmong A through E, the answer is",
    # "csqa": "\nAmong (A) through (E), the answer is",
    "csqa": "\nTherefore, among (A) through (E), the correct option is (",
    "csqa_dev": "\nTherefore, among (A) through (E), the correct option is (",
    "strategyqa": "\nTherefore, among yes and no, the answer is",
    "strategyqa_dev": "\nTherefore, among yes and no, the answer is",
    # "strategyqa": "\nThe answer (Yes or No) is",
    # "strategyqa_dev": "\nThe answer (Yes or No) is",
    "coin_flip": "\nTherefore, the answer (Yes or No, yes means heads up, no means tails up) is",
    "last_letters": "\nTherefore, the answer (last letters) is",
    "bbh_mcq": "\nTherefore, the correct option is (",
    "aqua": "\nTherefore, the correct option is (",
    "aqua_dev": "\nTherefore, the correct option is (",
    "sat": "\nTherefore, the correct option is (",
    "arcc": "\nTherefore, the correct option is (",
    "arce": "\nTherefore, the correct option is (",
    "bbh_ff": "\nTherefore, the answer is",
    "bbh_symbol": "\nTherefore, the correct option is (",
    "bbh_symbol_dev": "\nTherefore, the correct option is (",
    "bbh": "",
}

# Composite dataset name -> list of member dataset names.
# load_data() and load_data_mc() expand a composite by calling load_data() on
# every member (with the same ``data_base``) and tagging each returned sample
# with the member's name under the "type" key.
DATA_SET = {
    "arith": ["multiarith", "addsub", "singleeq"],
    "mwp": ["multiarith", "addsub", "singleeq", "gsm", "svamp"],
    "mwp_pal": ["multiarith_pal", "addsub_pal", "singleeq_pal", "gsm_pal", "svamp_pal"],
    "mwp_tora": [
        "multiarith_tora",
        "addsub_tora",
        "singleeq_tora",
        "gsm_tora",
        "svamp_tora",
    ],
    "mwp_wo_gsm": ["multiarith", "addsub", "singleeq", "svamp"],
    "no_mwp": ["csqa", "strategyqa", "coin_flip", "last_letters"],
    "cs": ["csqa", "strategyqa"],
    "symbol": ["coin_flip", "last_letters"],
    "reason512": ["gsm", "svamp", "aqua", "bbh_symbol", "humaneval"],
    "reason1024": ["csqa", "strategyqa"],
    "dev512": ["gsm_dev", "bbh_symbol_dev", "aqua_dev"],
    "dev1024": ["csqa_dev", "strategyqa_dev"],
    "math": ["multiarith", "addsub", "singleeq", "gsm", "svamp", "MATH500"],
    "agieval": ["sat", "aqua"],
    "qa": ["triviaqa", "nq"],
    "arc": ["arcc", "arce"],
}

# BBH task names grouped as multiple-choice, judging by the variable name.
# Not referenced in the visible part of this file; presumably used by callers
# to pick the "bbh_mcq" extraction prompt — confirm against the caller.
bbh_multiple_choice_sets = [
    "temporal_sequences",
    "disambiguation_qa",
    "date_understanding",
    "tracking_shuffled_objects_three_objects",
    "penguins_in_a_table",
    "geometric_shapes",
    "snarks",
    "ruin_names",
    "tracking_shuffled_objects_seven_objects",
    "tracking_shuffled_objects_five_objects",
    "logical_deduction_three_objects",
    "hyperbaton",
    "logical_deduction_five_objects",
    "logical_deduction_seven_objects",
    "movie_recommendation",
    "salient_translation_error_detection",
    "reasoning_about_colored_objects",
]
# BBH task names grouped as free-form answers, judging by the variable name.
# Not referenced in the visible part of this file; presumably used by callers
# to pick the "bbh_ff" extraction prompt — confirm against the caller.
bbh_free_form_sets = [
    "multistep_arithmetic_two",
    "navigate",
    "dyck_languages",
    "word_sorting",
    "sports_understanding",
    "boolean_expressions",
    "object_counting",
    "formal_fallacies",
    "causal_judgement",
    "web_of_lies",
]
# MMLU subject names (57 entries). Not referenced in the visible part of this
# file; the loaders here only handle the four math subjects hard-coded in the
# "mmlu_math" branches.
mmlu_all_sets = [
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_physics",
    "electrical_engineering",
    "astronomy",
    "anatomy",
    "abstract_algebra",
    "machine_learning",
    "clinical_knowledge",
    "global_facts",
    "management",
    "nutrition",
    "marketing",
    "professional_accounting",
    "high_school_geography",
    "international_law",
    "moral_scenarios",
    "computer_security",
    "high_school_microeconomics",
    "professional_law",
    "medical_genetics",
    "professional_psychology",
    "jurisprudence",
    "world_religions",
    "philosophy",
    "virology",
    "high_school_chemistry",
    "public_relations",
    "high_school_macroeconomics",
    "human_sexuality",
    "elementary_mathematics",
    "high_school_physics",
    "high_school_computer_science",
    "high_school_european_history",
    "business_ethics",
    "moral_disputes",
    "high_school_statistics",
    "miscellaneous",
    "formal_logic",
    "high_school_government_and_politics",
    "prehistory",
    "security_studies",
    "high_school_biology",
    "logical_fallacies",
    "high_school_world_history",
    "professional_medicine",
    "high_school_mathematics",
    "college_medicine",
    "high_school_us_history",
    "sociology",
    "econometrics",
    "high_school_psychology",
    "human_aging",
    "us_foreign_policy",
    "conceptual_physics",
]


def read_problems(evalset_file):
    """Read a jsonl evaluation file and return its tasks, one per ``task_id``.

    When the same ``task_id`` occurs more than once, the last record wins but
    keeps the position of the first occurrence.
    """
    tasks_by_id = {}
    for record in stream_jsonl(evalset_file):
        tasks_by_id[record["task_id"]] = record
    return list(tasks_by_id.values())


def stream_jsonl(filename: str):
    """Yield one parsed JSON object per non-blank line of a jsonl file.

    Args:
        filename: Path to the file. Names ending in ".gz" are read through
            gzip; anything else is read as plain text.

    Yields:
        The value decoded from each line that contains at least one
        non-whitespace character.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8, so decode explicitly instead of relying on the
    # platform's locale encoding.
    if filename.endswith(".gz"):
        handle = gzip.open(filename, "rt", encoding="utf-8")
    else:
        handle = open(filename, "r", encoding="utf-8")
    with handle as fp:
        for line in fp:
            # Skip empty / whitespace-only lines.
            if line.strip():
                yield json.loads(line)


def load_few_shot_demos(data_names, base_path):
    """Load few-shot demonstrations for every requested dataset.

    Args:
        data_names: Iterable of dataset names.
        base_path: Root directory. It is used to locate the "bbh" prompt
            directory and the "mmlu_math" dev csv files. All other datasets
            open ``FEW_SHOT_PATH[name]`` exactly as written (i.e. relative to
            the current working directory), which is the historical behavior.

    Returns:
        A dict keyed by dataset name:
          * "bbh": ``{task_name: prompt_text}``, where each prompt is the
            file content minus its first two lines.
          * "mmlu_math": ``{task_name: [{"input": ..., "target": ...}, ...]}``
            with at most five demonstrations per task.
          * anything else: ``[{"input": question, "output": answer}, ...]``
            parsed from a "Q: ... A: ..." formatted text file.

    Raises:
        KeyError: If a name has no entry in FEW_SHOT_PATH (or DATA_PATH for
            "mmlu_math").
        ValueError: If a "Q:" chunk of a prompt file does not contain exactly
            one "A:" marker.
    """
    max_mmlu_shots = 5
    mmlu_math_tasks = (
        "high_school_mathematics",
        "elementary_mathematics",
        "college_mathematics",
        "abstract_algebra",
    )
    demos_dict = {}
    for data_name in data_names:
        if data_name == "bbh":
            all_prompts = {}
            cot_prompt_files = glob.glob(
                os.path.join(base_path, FEW_SHOT_PATH[data_name], "*.txt")
            )
            for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
                task_name = os.path.basename(cot_prompt_file).split(".")[0]
                with open(cot_prompt_file, "r", encoding="utf-8") as f:
                    # The first two lines of every prompt file are dropped
                    # (presumably a header — confirm against the prompt files).
                    all_prompts[task_name] = "".join(f.readlines()[2:])
            demos_dict[data_name] = all_prompts
        elif data_name == "mmlu_math":
            # This branch previously referenced an undefined ``data_path`` and
            # therefore always raised NameError.
            # NOTE(review): assumes the "<task>_dev.csv" files live in a "dev"
            # directory next to the test directory named in DATA_PATH —
            # confirm against the data layout.
            dev_dir = os.path.join(
                base_path, os.path.dirname(DATA_PATH[data_name]), "dev"
            )
            all_prompts = {}
            for task_name in mmlu_math_tasks:
                task_file = os.path.join(dev_dir, f"{task_name}_dev.csv")
                demos = []
                with open(task_file, "r", encoding="utf-8") as f:
                    for row in csv.reader(f):
                        question = f"Question: {row[0]}\nA. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}\nAnswer: "
                        demos.append({"input": question, "target": row[5]})
                        if len(demos) == max_mmlu_shots:
                            break
                # Register the task even when its file had no rows.
                all_prompts[task_name] = demos
            # Keyed by dataset name, like the other branches.
            demos_dict[data_name] = all_prompts
        else:
            with open(FEW_SHOT_PATH[data_name], "r", encoding="utf-8") as f:
                demos_str = f.read().strip()
            demos = []
            for qa in demos_str.split("\nQ:"):
                if qa.strip() == "":
                    continue
                q, a = qa.split("A:")
                demos.append(
                    {"input": q.replace("Q:", "").strip(), "output": a.strip()}
                )
            demos_dict[data_name] = demos
    return demos_dict


def load_data_mc(data_name, data_base):
    """Load a multiple-choice style dataset as a list of sample dicts.

    Only "hellaswag" and the composite names in DATA_SET produce samples; any
    other name yields an empty list. Composite members are loaded through
    load_data(). Every sample ends up with a "type" key, and the number of
    samples is printed before returning.

    Args:
        data_name: Dataset name or composite name.
        data_base: Data root directory, or the path of the data file itself.

    Returns:
        List of sample dicts.
    """
    use_base_as_is = os.path.isfile(data_base) or data_name in DATA_SET
    data_path = (
        data_base if use_base_as_is else os.path.join(data_base, DATA_PATH[data_name])
    )

    save_samples = []
    if data_name == "hellaswag":
        for record in load_jsonl(data_path):
            # Keep only what follows the first ": " of the query.
            prompt = record["query"].split(": ", 1)[-1]
            entry = {
                "input": prompt,
                "question": prompt,
                "output": [" " + option for option in record["choices"]],
                "answer": record["gold"],
            }
            save_samples.append(entry)
    elif data_name in DATA_SET:
        for member_name in DATA_SET[data_name]:
            member_samples = load_data(member_name, data_base)
            for sample in member_samples:
                sample["type"] = member_name
            save_samples.extend(member_samples)

    # Samples not tagged by a composite member get the requested name.
    for sample in save_samples:
        sample.setdefault("type", data_name)
    print(len(save_samples))
    return save_samples


def load_data(data_name, data_base):
    """Load one evaluation dataset, or a composite of several, as sample dicts.

    The loader is chosen from ``data_name``; several families are matched by
    substring ("gsm", "bbh", "mmlu", "svamp", "asdiv", "csqa", "aqua"), the
    rest by exact name. Every sample carries at least "question" and "answer"
    and is stamped with a "type" key (the member name for composite sets,
    otherwise ``data_name``). The number of samples is printed before
    returning, except for "humaneval", which returns early.

    Args:
        data_name: Dataset name, or a composite name from DATA_SET.
        data_base: Data root directory. If it is an existing file, or
            ``data_name`` is a composite, it is used as-is; otherwise
            ``DATA_PATH[data_name]`` is joined onto it.

    Returns:
        List of sample dicts. A name that matches no loader yields ``[]``.

    Raises:
        KeyError: If ``data_name`` is not in DATA_PATH and ``data_base`` is
            neither a file nor paired with a composite name.
        AssertionError: If a triviaqa/nq row does not have exactly 2 columns.
    """
    import ast  # local import: only needed to parse triviaqa/nq answer lists

    if os.path.isfile(data_base) or data_name in DATA_SET:
        data_path = data_base
    else:
        data_path = os.path.join(data_base, DATA_PATH[data_name])

    arith_names = (
        "multiarith",
        "addsub",
        "singleeq",
        "multiarith_pal",
        "addsub_pal",
        "singleeq_pal",
        "multiarith_tora",
        "addsub_tora",
        "singleeq_tora",
    )
    symbol_names = ("coin_flip", "last_letters")
    mmlu_math_tasks = (
        "high_school_mathematics",
        "elementary_mathematics",
        "college_mathematics",
        "abstract_algebra",
    )
    bbh_symbol_tasks = (
        "reasoning_about_colored_objects",
        "date_understanding",
        "penguins_in_a_table",
    )

    save_samples = []
    # Composite names such as "mwp_wo_gsm" contain "gsm" as a substring; they
    # must fall through to the DATA_SET branch at the end instead of being
    # read as a single GSM8K file.
    if "gsm" in data_name and data_name not in DATA_SET:
        samples = load_jsonl(data_path)
        if "dev" in data_name:
            samples = samples[:1000]
        for s in samples:
            # The final answer follows the last "####" marker.
            answer = s["answer"].split("####")[-1].strip()
            save_samples.append(
                {"question": s["question"], "answer": answer, "steps": s["answer"]}
            )
    elif "bbh" in data_name:
        task_files = glob.glob(os.path.join(data_path, "*.json"))
        for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
            task_name = os.path.basename(task_file).split(".")[0]
            if "bbh_symbol" in data_name and task_name not in bbh_symbol_tasks:
                continue
            with open(task_file, "r", encoding="utf-8") as f:
                sub_samples = json.load(f)["examples"]
            # The first 20 examples of each task form the dev split; the
            # remainder is the evaluation split.
            if "dev" in data_name:
                sub_samples = sub_samples[:20]
            else:
                sub_samples = sub_samples[20:]
            for sub_sample in sub_samples:
                sub_sample["task_name"] = task_name
                sub_sample["question"] = sub_sample["input"]
                sub_sample["answer"] = sub_sample["target"]
            save_samples.extend(sub_samples)
    elif "mmlu" in data_name:
        # Only "mmlu_math" is implemented; other mmlu names load nothing.
        if data_name == "mmlu_math":
            for task_name in mmlu_math_tasks:
                task_file = os.path.join(data_path, f"{task_name}_test.csv")
                # A stray json.load() on this already-consumed csv handle
                # used to follow the row loop and always raised; removed.
                with open(task_file, "r", encoding="utf-8") as f:
                    for row in csv.reader(f):
                        question = f"Question: {row[0]}\nA. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}\nAnswer: "
                        save_samples.append(
                            {
                                "question": question,
                                "answer": row[5],
                                "task_name": task_name,
                            }
                        )
    elif data_name == "MATH" or data_name == "MATH500":
        if data_name == "MATH":
            with open(data_path, "r", encoding="utf-8") as f:
                samples = json.load(f)
        else:
            samples = load_jsonl(data_path)
        for s in samples:
            answer = extract_answer_math(s["solution"])
            save_samples.append(
                {
                    "question": s["problem"],
                    "answer": answer,
                    "steps": s["solution"],
                    # NOTE(review): the fallback is the literal string
                    # "subject", not s["subject"] — confirm this is intended.
                    "subtype": s["subtype"] if "subtype" in s else "subject",
                    "level": s["level"],
                }
            )
    elif "svamp" in data_name:
        with open(data_path, "r", encoding="utf-8") as f:
            samples = json.load(f)
        for s in samples:
            body = s["Body"].strip()
            if not body.endswith("."):
                body += "."
            q = body + " " + s["Question"].strip()
            a = str(s["Answer"])
            if a[-2:] == ".0":
                a = a[:-2]
            a = delete_extra_zero(a)
            save_samples.append({"question": q, "equation": s["Equation"], "answer": a})
    elif "asdiv" in data_name:
        for s in load_jsonl(data_path):
            q = s["body"].strip() + " " + s["question"].strip()
            # Drop parenthesised parts of the answer.
            a = re.sub(r"\(.*?\)", "", s["answer"])
            save_samples.append({"question": q, "answer": a})
    elif "csqa" in data_name:
        samples = load_jsonl(data_path)
        if "dev" in data_name:
            samples = samples[:1000]
        for s in samples:
            choice = "Answer Choices:" + "".join(
                " (" + c["label"] + ") " + c["text"] for c in s["question"]["choices"]
            )
            q = s["question"]["stem"].strip() + " " + choice
            save_samples.append({"question": q, "answer": s["answerKey"]})
    elif "aqua" in data_name:
        for s in load_jsonl(data_path):
            # Re-wrap the option labels in parentheses and space them out.
            choice = "(" + "(".join(s["options"])
            choice = choice.replace("(", " (").replace(")", ") ")
            choice = "Answer Choices:" + choice
            q = s["question"].strip() + " " + choice
            save_samples.append({"question": q, "answer": s["correct"]})
    elif data_name == "sat":
        for s in load_jsonl(data_path):
            choice = "Answer Choices: "
            choice += " ".join(s["options"])
            # Insert a space after each "(X)" option label.
            choice = re.sub(r"\((A|B|C|D|E)\)(.*?)", r"(\1) \2", choice)
            q = s["question"].strip() + " " + choice
            save_samples.append({"question": q, "answer": s["label"]})
    elif data_name == "arcc" or data_name == "arce":
        with open(data_path, "r", encoding="utf-8", errors="ignore") as in_f:
            for line in in_f:
                item = json.loads(line.strip())
                question = item["question"]
                # Only four-option questions are kept.
                if len(question["choices"]) != 4:
                    continue
                choice = "Answer Choices:\n" + "".join(
                    "(" + c["label"] + ") " + c["text"] + "\n"
                    for c in question["choices"]
                )
                q = question["stem"].strip() + "\n" + choice.strip()
                save_samples.append({"question": q, "answer": item["answerKey"]})
    elif data_name == "humaneval":
        samples = read_problems(data_path)
        for s in samples:
            s["question"] = s["prompt"]
            s["type"] = "humaneval"
        # Early return: no sample count is printed for humaneval.
        return samples
    elif (
        data_name in arith_names
        or data_name in symbol_names
        or data_name in ("strategyqa", "strategyqa_dev")
    ):
        with open(data_path, encoding="utf-8") as f:
            json_data = json.load(f)
        if data_name in symbol_names:
            json_data = json_data["examples"]
        if "dev" in data_name:
            json_data = json_data[:200]
        for idx, line in enumerate(json_data):
            if "strategyqa" in data_name:
                q = line["question"].strip()
                a = "yes" if int(line["answer"]) == 1 else "no"
            elif data_name in symbol_names:
                q = line["question"]
                a = line["answer"]
            else:
                q = line["sQuestion"].strip()
                a = str(line["lSolutions"][0])
            sample_id = "temp_{}".format(idx)
            save_samples.append({"question": q, "answer": a, "id": sample_id})
    elif data_name == "triviaqa" or data_name == "nq":
        with open(data_path, "r", encoding="utf-8") as f:
            for row in csv.reader(f, delimiter="\t"):
                assert len(row) == 2
                # SECURITY: this column holds a Python literal taken from the
                # data file. It was previously run through eval(), which
                # executes arbitrary code; literal_eval only accepts literals.
                answers = ast.literal_eval(row[1])
                save_samples.append({"question": row[0], "answer": answers})
    elif data_name == "refine":
        save_samples = load_jsonl(data_path)
    elif data_name in DATA_SET:
        for sub_data_name in DATA_SET[data_name]:
            cur_samples = load_data(sub_data_name, data_base)
            for s in cur_samples:
                s["type"] = sub_data_name
            save_samples.extend(cur_samples)

    # Anything not tagged by a composite member gets the requested name.
    for s in save_samples:
        if "type" not in s:
            s["type"] = data_name
    print(len(save_samples))
    return save_samples

