import re
import json
import os
import pprint
import asyncio
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool
import multiprocessing as mp


from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt, LLM
from src.evol.openai_utils import num_tokens_from_messages
from src.utils.data_utils import extract_answer_math, extract_answer_number
from src.utils.code_utils import execute_tora
from src.utils.file_utils import load_jsonl, load_jsonl_ml
from src.utils.math_utils import compare_ans, vote


def parse_args():
    """Parse the command-line options for a critique/refinement run.

    Every option except ``--verbose`` is a typed ``--name value`` pair with a
    default; ``--verbose`` is a boolean switch.

    Returns:
        argparse.Namespace holding one attribute per option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    # (flag, type, default) triples, registered in a fixed order so that the
    # generated --help listing keeps the same ordering.
    option_specs = [
        ("--prompt_path", str, None),
        ("--strategy_path", str, None),
        ("--dataset", str, "gsm"),
        ("--data_path", str, None),
        ("--model", str, "gpt-3.5-turbo"),
        ("--temperature", float, 0.0),
        ("--top_p", float, 1.0),
        ("--max_tokens", int, 1024),
        ("--num_seqs", int, 1),
        ("--num_skips", int, 0),
        ("--input_col", str, "question"),
        ("--output_col", str, "answer"),
        ("--max_iter", int, 3),
        ("--num_process", int, 1),
        ("--output_path", str, None),
        ("--batch_size", int, 10),
    ]
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, default=default, type=value_type)
    return parser.parse_args()


def load_prompt(prompt_path):
    """Read the prompt template at *prompt_path* as UTF-8 text.

    Returns:
        The file contents with leading and trailing whitespace removed.
    """
    with open(prompt_path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return raw_text.strip()


def stop_tora(result):
    """Report whether *result* contains a ``\\boxed`` marker.

    Returns:
        True if the literal text ``\\boxed`` occurs anywhere in *result*,
        otherwise False.
    """
    return "\\boxed" in result


def clean_gsm_output(output):
    """Turn a GSM-style reference answer into "<reasoning>\\nThe answer is <ans>".

    All ``<<...>>`` annotation spans are removed. The text is then split at
    the *last* ``####`` marker: everything before it becomes the reasoning,
    everything after it the final answer. Both parts are whitespace-stripped.

    Args:
        output: raw reference-answer text containing a ``####`` marker.

    Returns:
        The reformatted solution string.

    Raises:
        ValueError: if *output* contains no ``####`` marker.
    """
    output = re.sub(r"<<(.*?)>>", "", output)
    # rpartition tolerates extra "####" occurrences in the reasoning, where a
    # two-way unpack of split() would fail with an opaque error.
    reasoning, marker, answer = output.rpartition("####")
    if not marker:
        raise ValueError("reference answer has no '####' final-answer marker")
    return f"{reasoning.strip()}\nThe answer is {answer.strip()}"


def extract_judgement(s):
    """Extract the verdict written after the last ``\\underline`` in *s*.

    - If *s* contains no ``\\underline``, *s* is returned unchanged.
    - If the text after the last ``\\underline`` (ignoring leading whitespace)
      starts with ``{``, the brace-balanced content is returned, stripped.
      An unterminated group returns everything after the opening brace.
    - Otherwise the remaining text is returned, stripped.

    Leading whitespace is skipped before looking for ``{`` so that
    ``\\underline {correct}`` yields ``"correct"``, consistent with
    ``parse_output``, which ignores spaces when it checks for
    ``\\underline{correct}``.

    Args:
        s: feedback text produced by the critic model.

    Returns:
        The judgement string (possibly empty).
    """
    parts = s.split("\\underline")
    if len(parts) == 1:
        return s
    tail = parts[-1].lstrip()
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.strip()
    depth = 1
    chars = []
    for c in tail[1:]:
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                break
        chars.append(c)
    return "".join(chars).strip()


def clean_feedback(feedback):
    """Isolate the body of the feedback section in a model response.

    Keeps the text after the last ``## Feedback`` heading (or the whole input
    if that heading is absent) and truncates it at the next ``##``.

    Returns:
        The whitespace-stripped feedback text.
    """
    after_heading = feedback.rpartition("## Feedback")[2].strip()
    section = after_heading.partition("##")[0]
    return section.strip()


def parse_output(output):
    """Split a critic response into ``(feedback, refined_solution)``.

    - If the response contains exactly one ``## Refined Solution`` heading,
      the text before it is the feedback, and the text after it is the
      solution. A leading ``:`` and surrounding whitespace are trimmed from
      the solution.
    - If the heading appears more than once, the response is ambiguous and
      ``(None, None)`` is returned.
    - Without the heading, the response is accepted only if it contains
      ``\\underline{correct}`` (spaces ignored). In that case the solution is
      ``""``.

    The feedback part is passed through ``clean_feedback``.

    Returns:
        ``(feedback, solution)``, or ``(None, None)`` if the response cannot
        be used.
    """
    marker = "## Refined Solution"
    if marker in output:
        pieces = output.split(marker)
        if len(pieces) != 2:
            # Several refined solutions: no unambiguous answer, discard.
            return None, None
        feedback, solution = pieces
        solution = solution.strip().strip(":").strip()
    else:
        if "\\underline{correct}" not in output.replace(" ", ""):
            return None, None
        feedback, solution = output, ""
    feedback = clean_feedback(feedback)
    return feedback, solution


# Upper bound on prompt + completion tokens enforced for the default model.
# NOTE(review): presumably a ~4k-context model minus some headroom — confirm.
_CONTEXT_LIMIT = 4040
# Completion budget floor reached while shrinking, and the shrink step.
_MIN_COMPLETION_TOKENS = 512
_TOKEN_STEP = 100
# Fallback model for prompts that do not fit. The original code misspelled it
# as "gpt-3.5-tubro-16k", an invalid model id.
_LONG_CONTEXT_MODEL = "gpt-3.5-turbo-16k"


def _resolve_output_path(args, idx):
    """Return the JSONL path this worker writes to, creating parent dirs.

    A worker index ``idx != -1`` is inserted before ``.jsonl`` so that
    parallel workers write to separate files.
    """
    os.makedirs(f"result/{args.model}/{args.dataset}", exist_ok=True)
    if args.output_path is None:
        output_path = f"result/{args.model}/{args.dataset}/t{args.temperature}_n{args.num_seqs}-train_specific.jsonl"
    else:
        out_dir = os.path.dirname(args.output_path)
        if out_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        output_path = args.output_path
    if idx != -1:
        output_path = output_path.replace(".jsonl", f"_{idx}.jsonl")
    return output_path


def _drop_finished(samples, output_path):
    """Return the samples that do not yet appear in *output_path*.

    A sample is finished when its (question, first generation) pair matches a
    saved record's (question, first old_generation). Matching on the pair
    (rather than on question and solution independently) avoids skipping a
    sample whose question and solution each occur in different records.
    """
    if not os.path.exists(output_path):
        # Nothing written yet: every sample still needs processing.
        return samples
    finished = {
        (rec["question"], rec["old_generation"][0])
        for rec in load_jsonl_ml(output_path)
    }
    return [
        s for s in samples if (s["question"], s["generation"][0]) not in finished
    ]


def _build_messages(prompt, sample):
    """Build the chat messages asking the model to critique one sample.

    The template's ``{question}``, ``{solution}`` and ``{gold}`` placeholders
    are filled from ``sample["question"]``, ``sample["generation"][0]`` and
    ``sample["steps"]``.
    """
    user_content = (
        prompt.replace("{question}", sample["question"])
        .replace("{solution}", sample["generation"][0])
        .replace("{gold}", sample["steps"])
    )
    return [
        {
            "role": "system",
            "content": "You are a helpful expert for math problem solving.",
        },
        {"role": "user", "content": user_content},
    ]


def _select_model_and_budget(batch_messages, model, max_tokens):
    """Pick the model and completion budget for a whole batch.

    For each prompt (sized with ``num_tokens_from_messages``), the budget is
    reduced in ``_TOKEN_STEP`` steps until prompt + budget fits
    ``_CONTEXT_LIMIT`` or the budget reaches the floor. The batch switches to
    the long-context model when the budget fell below the floor, or when it
    still does not fit. The original code missed the second case, for
    budgets that land exactly on 512.

    NOTE(review): as in the original, a requested budget already below 512
    always selects the long-context model.

    Returns:
        ``(model, max_tokens)``.
    """
    for messages in batch_messages:
        num_tokens = num_tokens_from_messages(messages)
        while (
            num_tokens + max_tokens > _CONTEXT_LIMIT
            and max_tokens > _MIN_COMPLETION_TOKENS
        ):
            max_tokens -= _TOKEN_STEP
        if (
            max_tokens < _MIN_COMPLETION_TOKENS
            or num_tokens + max_tokens > _CONTEXT_LIMIT
        ):
            model = _LONG_CONTEXT_MODEL
    return model, max_tokens


def _grade_sample(sample, outputs):
    """Parse the model outputs for one sample and score them.

    Unparseable outputs, and outputs whose judgement is blank, are ignored.

    For a ``"correct"`` verdict:
    - the predicted answer comes from the original solution;
    - the verdict counts as verified when ``sample["score"] == 1``.

    For any other verdict:
    - the predicted answer comes from the refined solution;
    - the verdict counts as verified when ``sample["score"] == 0``.

    Returns:
        ``(record, score)``. *record* is a copy of *sample* with the
        refinement fields added. *score* is 1 if any predicted answer matches
        ``sample["label_answer"]``, otherwise 0.
    """
    feedbacks, solutions, pred_anss, verify_scores = [], [], [], []
    for o in outputs:
        feedback, solution = parse_output(o)
        if feedback is None:
            continue
        judgement = extract_judgement(feedback)
        if not judgement.strip():
            continue
        feedbacks.append(feedback)
        solutions.append(solution)
        if judgement == "correct":
            pred_anss.append(extract_answer_number(sample["generation"][0]))
            verify_scores.append(int(sample["score"] == 1))
        else:
            pred_anss.append(extract_answer_math(solution))
            verify_scores.append(int(sample["score"] == 0))
    label_ans = sample["label_answer"]
    score = int(any(compare_ans(p, label_ans) for p in pred_anss))
    verify_score = int(1 in verify_scores)

    # Copy so the caller's sample dict is not mutated; key insertion order
    # matches the original so the JSON layout is unchanged.
    record = dict(sample)
    record["old_generation"] = record["generation"]
    record["old_score"] = record["score"]
    record["generation"] = solutions
    record["feedback"] = feedbacks
    record["pred_answers"] = pred_anss
    record["verify_scores"] = verify_scores
    record["verify_score"] = verify_score
    record["score"] = score
    return record, score


def main(args, samples, idx):
    """Critique and refine *samples* with the LLM, writing one record per sample.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
        samples: list of dicts. Each needs ``question``, ``generation`` (its
            first element is the solution to critique), ``steps``, ``score``
            and ``label_answer``.
        idx: worker index for multi-process runs, or -1 for a single process.
            The output file name is suffixed with it when not -1.

    When ``args.num_skips != 0`` the run resumes: the output file is opened
    in append mode, and samples already recorded there are skipped. Records
    are written as indented JSON, so each one spans several lines.
    """
    prompt = load_prompt(args.prompt_path)
    if idx <= 0:
        print(prompt)
    # Resolve the final (possibly idx-suffixed) path before reporting it.
    output_path = _resolve_output_path(args, idx)
    print("%" * 30, "Tora", "%" * 30)
    print("Start PID %d and save to %s" % (os.getpid(), output_path))
    print(len(samples))

    resuming = args.num_skips != 0
    if resuming:
        samples = _drop_finished(samples, output_path)
        print(len(samples))

    llm = LLM()
    scores = []
    batch_size = args.batch_size
    # Explicit UTF-8: records are dumped with ensure_ascii=False.
    with open(output_path, "a" if resuming else "w", encoding="utf-8") as f:
        for i in tqdm(range(0, len(samples), batch_size)):
            batch_samples = samples[i : i + batch_size]
            batch_messages = [_build_messages(prompt, s) for s in batch_samples]
            model, max_tokens = _select_model_and_budget(
                batch_messages, args.model, args.max_tokens
            )
            batch_outputs = asyncio.run(
                llm.achat(
                    batch_messages,
                    model=model,
                    stop=["---"],
                    max_tokens=max_tokens,
                    temperature=args.temperature,
                    num_beams=args.num_seqs,
                )
            )
            if args.verbose:
                print(batch_outputs)
            for s, outputs in zip(batch_samples, batch_outputs):
                record, score = _grade_sample(s, outputs)
                scores.append(score)
                f.write(json.dumps(record, ensure_ascii=False, indent=4) + "\n")
                f.flush()
    if scores:
        print(f"Accuracy - {sum(scores) / len(scores)}")
    else:
        # Empty split or everything already finished: avoid ZeroDivisionError.
        print("Accuracy - N/A (no samples scored)")


if __name__ == "__main__":
    args = parse_args()
    samples = load_jsonl(args.data_path)
    if args.num_process == 1:
        main(args, samples, idx=-1)
    else:
        # Contiguous equal-size chunks; the last worker also takes the
        # remainder. Kept identical to earlier runs so that resuming with
        # --num_skips reads the same per-worker files.
        num_each_split = len(samples) // args.num_process
        pool = mp.Pool(args.num_process)
        pending = []
        for idx in range(args.num_process):
            start = idx * num_each_split
            if idx == args.num_process - 1:
                end = len(samples)
            else:
                end = (idx + 1) * num_each_split
            pending.append(
                pool.apply_async(main, args=(args, samples[start:end], idx))
            )
        pool.close()
        pool.join()
        # apply_async discards worker exceptions unless the result is
        # fetched; get() re-raises the first failure instead of hiding it.
        for result in pending:
            result.get()
        print("All of the child processes over!")
