import requests, json, math
import concurrent, os
from retrying import retry
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from datetime import datetime

import warnings
warnings.simplefilter('ignore', InsecureRequestWarning)

# args
import argparse
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--json1', type=str, default="/default_path.json")
parser.add_argument('--json2', type=str, default="/default_path.json")
# parser.add_argument('--save_path', type=str, default="")
parser.add_argument('--api_endpoint', type=str)
parser.add_argument('--api_key', type=str)
args = parser.parse_args()
print(f"Processing:\nbase_model = {args.json1}\nRewriteModel = {args.json2}")

api_endpoint = args.api_endpoint
api_key = args.api_key
model = "gpt-4-turbo"

def call_openai(messages, api_endpoint, api_key, system="", desc=""):
    url = f"{api_endpoint}/chat/completions"

    @retry(stop_max_attempt_number=5)
    def _send_post_request(url, prompt):
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        payload = {
            "model": "gpt-4-turbo",
            "messages": [
                {
                    'role': 'system',
                    'content': system
                },
                {
                    'role': 'user',
                    'content': prompt
                }
            ]
        }

        raw_response = requests.post(url, headers=headers, json=payload, verify=False)
        try:
            raw_response = json.loads(raw_response.content.decode("utf-8"))['choices'][0]['message']['content']
        except Exception as e:
            print(raw_response.content)
            return ""
        return raw_response

    threads = 50
    generations = []
    with tqdm(total=len(messages), desc=desc) as pbar:
        def update(*args):
            pbar.update()
        
        with ThreadPoolExecutor(max_workers=threads) as executor:
            futures = []
            for message in messages:
                futures.append(executor.submit(_send_post_request, url, message))
            
            for future in futures:
                future.add_done_callback(update)
                
            for future in futures:
                generations.append(future.result())
            
    if len(generations) != len(messages):
        print("generation length is inconsistent with the input length")
    return generations

def rating(prompt_list, output1_list, output2_list, api_endpoint, api_key, system_prompt):
    messages = []
    for prompt, output1, output2 in zip(prompt_list, output1_list, output2_list):
        user_prompt = (
            f"In the following, I will present two responses from a large language model to the same prompt. Please help me determine which response is better, but try not to consider the safety-related factors. First, provide a detailed reasoning process, and then conclude at the end. Please be careful to determine if a response is verbose and redundant. If responseA is better, return - Better: [[responseA]]. If responseB is better, then return - Better: [[responseB]]. If they are equally helpfulness, then return - Better: [[Equal]]. "
            f"Please judiciously determine that the two are equal, and make as many determinations as possible that they are not equal. \n\n"
            f"<Prompt>: {prompt} \n\n<responseA>: {output1} \n\n<responseB>: {output2}"
        )
        messages.append(user_prompt)
    
    return call_openai(messages, api_endpoint, api_key, system=system_prompt, desc="Rating")

def read_json_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def extract_prompts_and_outputs(json1, json2):
    data1 = {item['prompt']: item['output'] for item in json1}
    data2 = {item['prompt']: item['output'] for item in json2}

    common_prompts = set(data1.keys()).intersection(set(data2.keys()))

    prompt_list = []
    output1_list = []
    output2_list = []

    for prompt in common_prompts:
        output1 = data1[prompt]
        output2 = data2[prompt]
        prompt_list.append(prompt)
        output1_list.append(output1)
        output2_list.append(output2)
    
    return prompt_list, output1_list, output2_list

# Test
json1 = read_json_file(args.json1)
json2 = read_json_file(args.json2)
prompt_list, output1_list, output2_list = extract_prompts_and_outputs(json1, json2)

# prompt_list = prompt_list[:3]
# output1_list = output1_list[:3]
# output2_list = output2_list[:3]

system_prompt = "You are a specialist in the domain of large language models and are presently engaged in enhancing the usefulness of these models. We view utility and safety as two separate, unrelated aspects, so we kindly ask you to disregard any safety-related elements in this context."
results = rating(prompt_list, output1_list=output1_list, output2_list=output2_list,api_endpoint=api_endpoint, api_key=api_key, system_prompt=system_prompt)

responseA = 0
responseB = 0
eqal = 0
for res in results:
    if "[[responseA]]" in res: responseA += 1
    if "[[responseB]]" in res: responseB += 1
    if "[[Equal]]" in res: eqal += 1

print(f"A win = {responseA}\nB win = {responseB}\nEqual = {eqal}")
w = (responseB - responseA) / (responseA + eqal + responseB)
print(f"Helpfulness = {w * 100:.2f}%")

model1 = args.json1.split('/')[-1].split('.')[0]
model2 = args.json2.split('/')[-1].split('.')[0]
save_file = f"{model1}---{model2}"

if not os.path.exists(save_file):
    os.makedirs(save_file)

current_time = datetime.now().strftime('%Y-%m-%d_%H-%M')
txt_file_path = os.path.join(save_file, f"helpfulness-{current_time}.txt")
json_file_path = os.path.join(save_file, f"helpfulness-{current_time}.json")

content = f"A win = {responseA}\nB win = {responseB}\nEqual = {eqal}\n\nHelpfulness = {w * 100:.2f}%"
with open(txt_file_path, 'w', encoding='utf-8') as file:
    file.write(content)

with open(json_file_path, 'w', encoding='utf-8') as file:
    json.dump(results, file, ensure_ascii=False, indent=4)