import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm
from openai import OpenAI
from collections import Counter

sys.path.append("..")
from utils import compute_metrics
from base_utils import faireval_prompts
from data_utils import load_dataset
from widedeep import call_openai_gpt

# Root directories for outputs and inputs. These are relative paths, so they
# resolve against the current working directory at run time (the "../../"
# prefix presumably assumes the script is launched from its own folder — TODO confirm).
# NOTE(review): DATASET_PATH is not referenced in the code visible here.
RESULT_PATH = "../../results"
DATASET_PATH = "../../data"

def parse_arguments():
    """Define the value-evaluation command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
        Note that ``with_definition`` is kept as a *string* ("True"/"False")
        because it is embedded verbatim in result file names.
    """
    cli = argparse.ArgumentParser(description="Value Evaluation")

    # (flag, type, default, help) — declaration order is preserved in --help.
    option_specs = (
        # What to evaluate and how to prompt.
        ("--dataset", str, "moral_choice", "Dataset to evaluate (beavertails, denevil, moral_choice, value_fulcra)"),
        ("--data_version", str, "original", "Data version to use (original, pairwise, augmented)"),
        ("--model_name", str, "openai/gpt-3.5-turbo", "Model to evaluate"),
        ("--prompt_method", str, "vanilla", "Prompt method to use, vanilla means just prompts."),
        ("--with_definition", str, "True", "Whether to include the basic/prior definition in the prompt."),
        ("--few_shot_num", int, 3, "Number of few-shot examples to use"),
        ("--num_agents", int, 3, "Number of agents in the chat"),
        # Sampling / run configuration.
        ("--train_split", str, "small", "Train split to use, small / large"),
        ("--eval_max_tokens", int, 100, "Max tokens for evaluation"),
        ("--eval_temp", float, 1.0, "Temperature for sampling"),
        ("--eval_top_p", float, 1.0, "Top-P parameter for top-p sampling"),
        ("--eval_repeats", int, 1, "Number of repeats for evaluation on one sample"),
        ("--eval_num_samples", int, 10, "Number of samples to evaluate on"),
        ("--batch_size", int, 8, "Batch size for evaluation"),
        ("--gpu_num", int, 8, "GPU number to use for evaluation"),
        ("--remove_value", str, None, "Remove a specific value from the training data"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    cli.add_argument(
        "--evaluate",
        action="store_true",
        help="Run evaluation on the dataset with the model and prompt method.",
    )
    return cli.parse_args()

def _normalize_decision(raw_answer):
    """Map a free-text decision such as "Yes.", "yes" or "**No**" to "Yes"/"No".

    Only the first whitespace-separated word is inspected, after stripping
    surrounding punctuation/markdown and case-folding. Anything that is not
    recognisably yes/no is returned unchanged so it counts for neither side.
    """
    words = raw_answer.strip().split(maxsplit=1)
    if not words:
        return raw_answer
    first_word = words[0].strip(".,!?;:\"'*`").casefold()
    if first_word == "yes":
        return "Yes"
    if first_word == "no":
        return "No"
    return raw_answer


def eval_one_example(scenario, args):
    """Query the model on one scenario ``args.eval_repeats`` times and majority-vote the decision.

    Args:
        scenario: One test row. Must provide "scenario" and "value_details";
            every dataset except "denevil" also needs "action".
        args: Parsed CLI namespace; reads ``dataset``, ``model_name`` and
            ``eval_repeats``.

    Returns:
        ``(answer, raw_answers)`` where ``answer`` is "Yes" only when normalized
        "Yes" votes strictly outnumber "No" votes (ties and unparseable replies
        therefore yield "No"), and ``raw_answers`` is the un-normalized replies
        joined with "; ".
    """
    context = scenario["scenario"]
    value_details = scenario["value_details"]

    template = faireval_prompts.prompt_template[args.dataset]
    if args.dataset == "denevil":
        # The denevil template has no separate action slot.
        prompt = template.format(value=value_details, scenario=context)
    else:
        prompt = template.format(value=value_details, context=context, action=scenario["action"])
    # Crude dedent of the template: drops tabs and every run of four spaces.
    # Applied after formatting, so it also affects interpolated scenario text.
    prompt = prompt.replace("\t", "").replace("    ", "")

    # NOTE(review): --eval_max_tokens / --eval_temp / --eval_top_p are parsed but
    # never forwarded here; confirm whether call_openai_gpt should receive them.
    answer_list = []
    for _ in range(args.eval_repeats):
        response = call_openai_gpt(prompt, model_name=args.model_name)
        response = response.replace("<start output>", "").replace("<end output>", "").strip()
        # Keep only the text after the last "Your decision: " marker (the whole
        # response if the marker is absent).
        answer_list.append(response.rpartition("Your decision: ")[2].strip())

    # Normalize before voting so replies like "Yes." or "yes" are not counted as "No".
    votes = Counter(_normalize_decision(reply) for reply in answer_list)
    answer = "Yes" if votes["Yes"] > votes["No"] else "No"
    return answer, "; ".join(answer_list)

if __name__ == "__main__":
    args = parse_arguments()
    print("Running with args: ", args)

    test_data = load_dataset(args.dataset, "test", version=args.data_version, args=args)
    print("Test data loaded: ", len(test_data))

    # One CSV per (dataset, version, split, model, prompt method, definition flag).
    # An existing file acts as a checkpoint: evaluation resumes after its rows.
    result_path = f"{RESULT_PATH}/{args.dataset}/{args.data_version}/{args.train_split}/{args.model_name.split('/')[-1]}"
    result_file = f"{result_path}/{args.prompt_method}_with_definition_{args.with_definition}.csv"

    results = []
    if os.path.exists(result_file):
        results = pd.read_csv(result_file).to_dict('records')
        for r in results:
            # Drop the index column that DataFrame.to_csv wrote on the previous run.
            r.pop('Unnamed: 0', None)

    # Select rows by position (not index label) so resuming and the sample limit
    # stay correct even when test_data has a non-contiguous index. A negative
    # --eval_num_samples never triggered the old `idx ==` stop, so it still means
    # "evaluate everything".
    limit = args.eval_num_samples if args.eval_num_samples >= 0 else None
    pending = test_data.iloc[len(results):limit]

    try:
        for _, scenario in tqdm(
            pending.iterrows(),
            total=len(pending),
            desc=f"Evaluating on Dataset: {args.dataset}, Method: {args.prompt_method}, Model: {args.model_name}",
        ):
            answer, answer_list = eval_one_example(scenario, args)

            result_base = {
                "value": scenario["value"],
                "label": scenario["label"],
                "answer": answer,
                "answer_list": answer_list,
                "current_repeat": 0,
                # NOTE(review): hard-coded to 1 even when --eval_repeats > 1 (repeats
                # are already majority-voted into `answer`); confirm compute_metrics
                # expects this.
                "eval_repeats": 1,
            }
            results.append(result_base)
    finally:
        # Persist whatever was computed, including partial progress after an error
        # or Ctrl-C, so the next run can resume. Create the directory on first use;
        # to_csv does not create missing parent directories.
        os.makedirs(result_path, exist_ok=True)
        results_df = pd.DataFrame(results)
        results_df.to_csv(result_file)

    print("Evaluation results: ", results_df)

    ### Step 3: Compute the metrics
    metrics = compute_metrics(results_df)
    print(f"Metrics for evaluation on {args.dataset} with model {args.model_name} and prompt method {args.prompt_method}:")
    for metric, value in metrics.items():
        print(f"{metric}: {value}")