import sys
sys.path.append("..")
from backend.gpt import query_gpt4

import os
import yaml
import json
import numpy as np
from openai import OpenAI
from pprint import pprint

# load global config
# Resolve paths from this script's absolute location: on Python < 3.9 the
# main script's ``__file__`` may be relative, which would make the parent
# directory below resolve to "" instead of the project root.
file_path = os.path.dirname(os.path.abspath(__file__))
project_path = os.path.dirname(file_path)
# Use a context manager so the config file handle is closed deterministically.
with open(os.path.join(project_path, "config/global.yaml"), "r", encoding="utf-8") as _config_file:
    # safe_load yields None for an empty document; normalise to a dict so the
    # lookups below do not fail on None.
    global_config = yaml.safe_load(_config_file) or {}

# A missing (or empty) section becomes an empty dict, so each key falls back to
# its default (None / 10) instead of raising AttributeError on None.get(...).
_backend_config = global_config.get("backend") or {}
_agent_config = global_config.get("agent") or {}

OPENAI_API_KEY = _backend_config.get("openai_api_key")
BASE_URL = _backend_config.get("base_url", None)
MAX_RETRY_TIMES = _agent_config.get("max_query_retry_times", 10)

# Module-level client shared by get_embedding().
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=BASE_URL,
)

def get_embedding(text, model="text-embedding-3-small"):
    """Fetch an embedding vector for *text* via the module-level OpenAI client.

    Newline characters are turned into spaces before the request is made.

    Args:
        text: The string to embed.
        model: Name of the embedding model passed through to the API.

    Returns:
        The ``embedding`` attribute of the first entry in the response's
        ``data`` list.
    """
    single_line = text.replace("\n", " ")
    response = client.embeddings.create(input=[single_line], model=model)
    first_item = response.data[0]
    return first_item.embedding

def get_cosine_similarity(embeddingi, embeddingj):
    """Return the cosine similarity between two vectors.

    Both inputs are converted with ``np.array``. The result is their dot
    product divided by the product of their Euclidean norms. A zero vector
    makes the denominator 0, so the result is NaN (or inf), as NumPy
    float division produces.

    Args:
        embeddingi: First vector (any sequence accepted by ``np.array``).
        embeddingj: Second vector, same length as the first.

    Returns:
        The similarity as a NumPy scalar.
    """
    vec_a = np.array(embeddingi)
    vec_b = np.array(embeddingj)
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    numerator = np.dot(vec_a, vec_b)
    return numerator / denominator

# Selects which files in ./Needle are judged: a file is processed when its
# name contains this prefix (and ends in .json). Judged rows are appended to
# the file named file_prefix + "result" in the current working directory.
file_prefix = "Needle_GPTV3.5_DoubleMind_"

# Prompt template for the LLM judge. Its three positional "{}" slots are
# filled via str.format, in order, with the question, the ground-truth answer
# and the model prediction. The judge is told to reply with only the number
# 1 (correct) or 0 (incorrect).
# NOTE(review): "a experienced" is a grammar slip inside the prompt text. It is
# left as-is because editing the prompt could shift judge outputs relative to
# results already recorded in the result file.
llm_judge_prompt = """
You are a experienced human labeler for reading comprehension task.
Given a ground truth answer and a model prediction,
you have to judge whether the model prediction is correct.
The question is {}.
The ground truth answer is {}.
The model prediction is {}.

return 1 if the model prediction is correct else 0.
the model prediction may be a little different on the expression, as long as the meaning or key entity is correct, the answer can be regarded as correct.
ONLY RETURN THE NUMBER.
"""

def _load_evaluated_ids(result_path):
    """Return the sample ids already recorded in the TSV file at *result_path*.

    The id is the first tab-separated column of each line. A missing file
    yields an empty set, so a fresh run evaluates every sample.
    """
    evaluated = set()
    if os.path.exists(result_path):
        # Only the first column is needed; errors="replace" keeps the resume
        # scan working even if older rows were written in another encoding.
        with open(result_path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                evaluated.add(line.strip().split("\t")[0])
    return evaluated


def _tsv_field(value):
    """Return *value* as text that fits in exactly one TSV column.

    Carriage returns and newlines are removed, and tab characters become
    spaces. A stray tab would shift the columns. A stray line break would split
    the row when the file is re-read for resuming, because text-mode reads also
    treat a lone carriage return as a line ending.
    """
    return str(value).replace("\r", "").replace("\n", "").replace("\t", " ")


# Judge every ./Needle/*.json file whose name contains file_prefix and append
# one TSV row per sample to file_prefix + "result":
#   sample_id <TAB> ground_truth <TAB> prediction <TAB> judge verdict
# Samples whose id is already in that file are skipped, so an interrupted run
# can simply be restarted.
evaluated_set = _load_evaluated_ids(file_prefix + "result")

# The context manager guarantees the result file is flushed and closed even if
# the judge call raises part-way through the batch.
with open(file_prefix + "result", "a", encoding="utf-8") as fw:
    # Sorted so samples are processed in a deterministic order across runs.
    for filename in sorted(os.listdir("./Needle")):
        if file_prefix not in filename or not filename.endswith(".json"):
            continue
        print(filename)
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(os.path.join("Needle", filename), "r", encoding="utf-8") as f:
            data = json.load(f)
        sample_id = str(data['sample_id'])
        if sample_id in evaluated_set:
            continue
        question = data['task_prompt'].replace("\n", "")
        ground_truth = data['answer'].replace("\n", "")
        predict = data['predicted_answer']
        if not predict:
            # An empty/missing prediction is scored 0 without calling the judge.
            predict = "None"
            judge_result = "0"
        else:
            predict = predict.replace("\n", "")
            query = llm_judge_prompt.format(question, ground_truth, predict)
            judge_result = query_gpt4(query)
            if judge_result is None:
                # NOTE(review): assumes query_gpt4 may return None when it gives
                # up -- confirm. The sample is left unrecorded so a later run
                # retries it instead of crashing the whole batch on join().
                print("no judge result for sample", sample_id, "- skipping")
                continue
            # Surrounding whitespace/newlines in the reply would otherwise break
            # the one-row-per-sample layout the resume logic relies on.
            judge_result = str(judge_result).strip()
        row = [sample_id, ground_truth, predict, judge_result]
        fw.write("\t".join(_tsv_field(col) for col in row) + "\n")
        # Flush per row so completed judgements survive a process crash/kill
        # and are picked up by the resume logic on the next run.
        fw.flush()
        # Also prevents re-judging if two files in this run share a sample_id.
        evaluated_set.add(sample_id)








# for evaluation_result in os.listdir("./exp"):
#     if "evaluation_result_" not in evaluation_result:
#         continue
#     print(evaluation_result)
#     totalscore = 0.0
#     idx = 0
#     fw = open("./exp/" + evaluation_result.replace(".txt", "_emb.txt"), "w")
#     for sample in open("./exp/" + evaluation_result, "r").read().split("----------")[:30]:
#         idx += 1
#         final_answer_emb = get_embedding(final_answer)
#         ground_truth_emb = get_embedding(ground_truth)
#         score = get_cosine_similarity(final_answer_emb, ground_truth_emb)
#         print(score)
#         fw.write(str(score) + "\n")
#         totalscore += score
#     fw.write("total score: " + str(totalscore / 30))
#     fw.close()