# -*- coding:utf-8 -*-
import os
import re
import sys
import time
import json
import copy
import random
import argparse
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from openai import OpenAI
from collections import Counter

sys.path.append("..")
from base_utils import experts_task
from data_utils import load_dataset
from utils import compute_metrics, your_api_key

# Upper bound on retries for a failing API request.
MAX_API_RETRY = 200
# NOTE(review): presumably the pause, in seconds, between API requests -- confirm the unit against usage.
REQ_TIME_GAP = 4

# Root directory under which the per-dataset/per-model result CSV files are written.
RESULT_PATH = "../results"
# NOTE(review): presumably the dataset root consumed by the data-loading helpers -- confirm.
DATASET_PATH = "../data"

def parse_arguments():
    """Build the command-line parser for the evaluation script and parse the process arguments.

    Returns:
        argparse.Namespace: one attribute per option declared below, filled
        with the command-line value or the listed default.
    """
    parser = argparse.ArgumentParser(description="Value Evaluation")

    # (flag, default, type, help) -- registered in this exact order so the
    # generated --help output keeps the same layout.
    valued_options = (
        ("--dataset", "moral_choice", str, "Dataset to evaluate (beavertails, denevil, moral_choice, value_fulcra)"),
        ("--data_version", "original", str, "Data version to use (original, pairwise, augmented)"),
        ("--model_name", "openai/gpt-3.5-turbo", str, "Model to evaluate"),
        ("--prompt_method", "vanilla", str, "Prompt method to use, vanilla means just prompts."),
        ("--with_definition", "True", str, "Whether to include the basic/prior definition in the prompt."),
        ("--few_shot_num", 3, int, "Number of few-shot examples to use"),
        ("--load_local_ckpt", None, str, "Load a local checkpoint for the model"),
        ("--train_split", "small", str, "Train split to use, small / large"),
        ("--eval_max_tokens", 100, int, "Max tokens for evaluation"),
        ("--eval_temp", 1.0, float, "Temperature for sampling"),
        ("--eval_top_p", 1.0, float, "Top-P parameter for top-p sampling"),
        ("--eval_repeats", 1, int, "Number of repeats for evaluation on one sample"),
        ("--eval_num_samples", 10, int, "Number of samples to evaluate on"),
        ("--batch_size", 8, int, "Batch size for evaluation"),
        ("--gpu_num", 8, int, "GPU number to use for evaluation"),
        ("--remove_value", None, str, "Remove a specific value from the training data"),
    )
    for flag, fallback, caster, description in valued_options:
        parser.add_argument(flag, default=fallback, type=caster, help=description)

    # The only boolean switch: absent -> False, present -> True.
    parser.add_argument("--evaluate", action="store_true", help="Run evaluation on the dataset with the model and prompt method.")

    return parser.parse_args()

# Per-dataset prompt builders, looked up as task_funct[args.dataset][slot]:
#   slot 0 -- used by aspect_layer when asp_num == -1
#   slot 1 -- used by aspect_layer for any other asp_num
#   slot 2 -- used by init_layer (one call per aspect)
#   slot 3 -- used by single_layer (one call per neuron)
# NOTE(review): "moral_choice", the default of --dataset, has no entry here,
# so running with the default dataset fails with KeyError at the first lookup.
task_funct = {
    "beavertails": [experts_task.gen_prompt_aspectu_beaver, experts_task.gen_prompt_aspect_beaver, experts_task.gen_prompt_init_beaver, experts_task.gen_prompt_beaver],
    "value_fulcra": [experts_task.gen_prompt_aspectu_value, experts_task.gen_prompt_aspect_value, experts_task.gen_prompt_init_value, experts_task.gen_prompt_value],
    "denevil": [experts_task.gen_prompt_aspectu_denevil, experts_task.gen_prompt_aspect_denevil, experts_task.gen_prompt_init_denevil, experts_task.gen_prompt_denevil],
}

def call_openai_gpt(prompt, sys_prompt="", model_name="gpt-4"):
    """Send one chat-completion request and return the reply text.

    Args:
        prompt: Content of the user message.
        sys_prompt: Content of the system message; when it is the empty
            string no system message is sent.
        model_name: Model identifier, optionally prefixed with a provider
            (e.g. "openai/gpt-3.5-turbo"). Only the part after the last "/"
            is considered.

    Returns:
        str: The first choice's message content with surrounding whitespace
        stripped, or "" when the API returns no text content.

    Raises:
        Whatever the OpenAI client raises once its retries are exhausted or
        for errors it does not retry.
    """
    # MAX_API_RETRY is handed to the SDK's built-in retry mechanism so that
    # transient failures do not abort a long evaluation run.
    client = OpenAI(api_key=your_api_key, max_retries=MAX_API_RETRY)

    messages = []
    if sys_prompt != "":
        messages.append({"role": "system", "content": sys_prompt})
    messages.append({"role": "user", "content": prompt})

    # Drop any "provider/" prefix before matching on the bare model name.
    model_name = model_name.split("/")[-1]
    max_tokens = 2048
    # The prefix has already been removed, so compare against the bare names.
    # (Comparing against "openai/gpt-3.5-turbo" here could never match, which
    # left the 1024-token cap unapplied for the default model.)
    if model_name in ("gpt-35-turbo", "gpt-3.5-turbo"):
        model_name = "gpt-3.5-turbo"
        max_tokens = 1024
    elif "gpt-4" in model_name:
        # Every gpt-4 variant is pinned to this snapshot.
        model_name = "gpt-4-1106-preview"

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=1.0,
        top_p=1.0,
        max_tokens=max_tokens,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )

    # content can be None (e.g. a filtered or empty completion); treat that
    # as an empty answer instead of failing on None.strip().
    content = response.choices[0].message.content
    return content.strip() if content else ""

def aspect_layer(args, value, value_details, asp_num, context=None, action=None):
    """Ask the model for the list of aspects used by the later layers.

    Args:
        args: Parsed CLI namespace; ``args.dataset`` selects the prompt
            builders and ``args.model_name`` the model.
        value: Value under evaluation, forwarded to the prompt builder.
        value_details: Details of that value, forwarded to the prompt builder.
        asp_num: Aspect count forwarded to the prompt builder; ``-1`` selects
            the slot-0 builder, anything else the slot-1 builder.
        context: Scenario text forwarded to the prompt builder. When omitted,
            the module-level ``context`` is used (legacy behaviour).
        action: Action text forwarded to the prompt builder. When omitted,
            the module-level ``action`` is used (legacy behaviour).

    Returns:
        list[str]: For each non-blank line of the reply, the text before the
        first ":" (whitespace-stripped); entries that end up empty are dropped.

    Raises:
        NameError: ``context``/``action`` were not passed and no module-level
            variable of that name exists.
    """
    # Historically this function read `context` and `action` straight from
    # module globals (they are not among the original parameters). Keep that
    # as a fallback so existing four-argument calls behave as before, while
    # letting callers pass the values explicitly.
    if context is None or action is None:
        module_scope = globals()
        try:
            if context is None:
                context = module_scope["context"]
            if action is None:
                action = module_scope["action"]
        except KeyError as missing:
            raise NameError(
                f"name {missing.args[0]!r} is not defined; pass it to aspect_layer() explicitly"
            ) from None

    builder_slot = 0 if asp_num == -1 else 1
    prompt = task_funct[args.dataset][builder_slot](context, action, value, value_details, asp_num)
    response = call_openai_gpt(prompt, model_name=args.model_name)

    # Blank lines in the reply would otherwise turn into empty aspects, each
    # of which costs extra model calls in the following layers.
    aspects = []
    for line in response.splitlines():
        aspect_name = line.split(":", 1)[0].strip()
        if aspect_name:
            aspects.append(aspect_name)
    return aspects

def init_layer(args, context, action, value, value_details, aspects):
    """Run the first layer: one model call per aspect.

    Args:
        args: Parsed CLI namespace (``dataset`` and ``model_name`` are read).
        context, action, value, value_details: Forwarded unchanged to the
            slot-2 prompt builder of the selected dataset.
        aspects: Sequence of aspects; element ``i`` drives neuron ``m{i+1}``.

    Returns:
        tuple[dict, dict]: ``(answers, contents)`` keyed by ``"m1"``, ``"m2"``,
        ... where ``contents`` holds the reply with the output markers removed
        and ``answers`` holds the text after the last ``": "`` of that reply.
    """
    answers, contents = {}, {}

    for position, aspect in enumerate(aspects, start=1):
        neuron_key = f"m{position}"
        system_text, user_text = task_funct[args.dataset][2](context, action, value, value_details, aspect)

        raw_reply = call_openai_gpt(user_text, system_text, model_name=args.model_name)
        # Strip the output delimiters the model may echo back.
        cleaned_reply = raw_reply.replace("<start output>", "").replace("<end output>", "").strip()

        # The verdict is whatever follows the final ": " in the reply.
        answers[neuron_key] = cleaned_reply.rsplit(": ", 1)[-1].strip()
        contents[neuron_key] = cleaned_reply

    return answers, contents

def single_layer(args, context, action, value, value_details, aspects, M):
    """Run one follow-up layer in which each neuron also sees its neighbours' outputs.

    Args:
        args: Parsed CLI namespace (``dataset`` and ``model_name`` are read).
        context, action, value, value_details: Forwarded unchanged to the
            slot-3 prompt builder of the selected dataset.
        aspects: Sequence of aspects, indexed in parallel with the neurons.
        M: Mapping ``"m1"``, ``"m2"``, ... -> previous-layer content of each neuron.

    Returns:
        tuple[dict, dict]: ``(answers, contents)`` keyed like ``M``;
        ``contents`` holds the reply with the output markers removed and
        ``answers`` the text after the last ``": "`` of that reply.
    """
    total_neurons = len(M)
    reach = 2  # each neuron looks at indices [i - reach + 1, i + reach)
    answers, contents = {}, {}

    for position in range(total_neurons):
        neuron_key = f"m{position + 1}"
        own = [copy.deepcopy(M[neuron_key])]  # a list

        # Window of neighbouring neuron indices, clipped to the valid range.
        first = max(position - reach + 1, 0)
        last = min(total_neurons, position + reach)
        window = range(first, last)

        window_aspects = [aspects[j] for j in window]
        others = [copy.deepcopy(M[f"m{j + 1}"]) for j in window if j != position]

        # The builder's third return value is not used here.
        system_text, user_text, _ = task_funct[args.dataset][3](context, action, value, value_details, window_aspects, own, others)
        raw_reply = call_openai_gpt(user_text, system_text, model_name=args.model_name)
        cleaned_reply = raw_reply.replace("<start output>", "").replace("<end output>", "").strip()

        answers[neuron_key] = cleaned_reply.rsplit(": ", 1)[-1].strip()
        contents[neuron_key] = cleaned_reply

    return answers, contents

def widedeep_eval(input_, args):
    """Evaluate one sample with the layered pipeline and return the majority vote.

    Args:
        input_: Five-element sequence ``(context, action, value,
            value_details, label)``. ``label`` is unpacked but not used.
        args: Parsed CLI namespace, forwarded to every layer.

    Returns:
        str: ``"Yes"`` when strictly more neuron answers across all layers
        equal ``"Yes"`` than equal ``"No"``; otherwise ``"No"`` (ties and
        the no-vote case therefore resolve to ``"No"``).
    """
    context, action, value, value_details, _label = input_
    num_neuro, num_layer = 3, 2

    aspects = aspect_layer(args, value, value_details, num_neuro)

    # Layer 1: independent judgement per aspect.
    answers, contents = init_layer(args, context, action, value, value_details, aspects)
    hist_answers = {"l1": answers}

    # Layers 2..num_layer: each neuron refines using the previous layer's contents.
    for layer_no in range(2, num_layer + 1):
        answers, contents = single_layer(args, context, action, value, value_details, aspects, contents)
        hist_answers[f"l{layer_no}"] = answers

    # Every neuron of every layer casts one vote; only exact "Yes"/"No" count.
    votes = Counter(
        answer
        for layer_answers in hist_answers.values()
        for answer in layer_answers.values()
    )
    return "Yes" if votes["Yes"] > votes["No"] else "No"

if __name__ == "__main__":
    # Script entry point: load the test split, evaluate each sample (resuming
    # from a previous result file when one exists), save results, print metrics.
    args = parse_arguments()
    print("Running with args: ", args)

    test_data = load_dataset(args.dataset, "test", version=args.data_version, args=args)
    print("Test data loaded: ", len(test_data))

    result_path = f"{RESULT_PATH}/{args.dataset}/{args.data_version}/{args.train_split}/{args.model_name.split('/')[-1]}"
    result_file = f"{result_path}/{args.prompt_method}_with_definition_{args.with_definition}.csv"
    # Create the output directory up front; otherwise the first to_csv fails
    # only after model calls have already been spent.
    os.makedirs(result_path, exist_ok=True)

    # Resume: previously saved rows are reloaded and their samples skipped.
    results = []
    if os.path.exists(result_file):
        results = pd.read_csv(result_file).to_dict('records')
        for r in results:
            # Drop the index column that to_csv wrote on the previous run.
            r.pop('Unnamed: 0', None)

    # Count rows by position rather than by DataFrame index label so that the
    # sample limit and the resume offset stay correct for any index.
    for idx, (_, scenario) in enumerate(tqdm(test_data.iterrows(), desc=f"Evaluating on Dataset: {args.dataset}, Method: {args.prompt_method}, Model: {args.model_name}")):
        if idx == args.eval_num_samples:
            break
        if idx < len(results):
            continue
        # NOTE: `context` and `action` must stay module-level names with
        # exactly these spellings -- aspect_layer reads them from module
        # globals when they are not passed to it.
        context, action, value, value_details, label = scenario["scenario"], scenario["action"], scenario["value"], scenario["value_details"], scenario["label"]

        answer = widedeep_eval([context, action, value, value_details, label], args)
        results.append({
            "value": value,
            "label": label,
            "answer": answer,
            "current_repeat": 0,
            "eval_repeats": 1,
        })
        # Periodic checkpoint so an interrupted run can be resumed.
        if idx % 50 == 0:
            pd.DataFrame(results).to_csv(result_file)

    results_df = pd.DataFrame(results)
    results_df.to_csv(result_file)

    print("Evaluation results: ", results_df)

    ### Step 3: Compute the metrics
    metrics = compute_metrics(results_df)
    print(f"Metrics for evaluation on {args.dataset} with model {args.model_name} and prompt method {args.prompt_method}:")
    for metric, metric_value in metrics.items():
        print(f"{metric}: {metric_value}")