import os
import json
import random
import argparse
import pandas as pd
from tqdm import tqdm
from data_utils import load_dataset
from prompt_utils import PromptCalibration, KeyConceptExtraction, few_shot_train_data, AllureCalibration, geval_generate_cot
from utils import set_random_seed
from collections import defaultdict

RESULT_PATH = "../results"
DATASET_PATH = "../data"

def parse_arguments():
    """Parse the command-line options of the value-calibration script.

    Options are declared as (flag, keyword-arguments) pairs and registered in
    order. Boolean-like switches such as --with_definition, --cluster_train_data
    and --with_result are kept as the strings "True"/"False".

    Returns:
        argparse.Namespace: the parsed options, read from sys.argv.
    """
    option_table = [
        # general run configuration
        ("--dataset", dict(default="beavertails", type=str, help="Dataset to evaluate (beavertails, denevil, moral_choice, value_fulcra)")),
        ("--data_version", dict(default="original", type=str, help="Data version to use (original, pairwise, augmented)")),
        ("--model_name", dict(default="openai/gpt-3.5-turbo", type=str, help="Model to evaluate")),
        ("--prompt_method", dict(default="vanilla", type=str, help="Prompt method to use, vanilla means just prompts.")),
        ("--with_definition", dict(default="True", type=str, help="Whether to include the basic/prior definition in the prompt.")),
        ("--train_split", dict(default="train_small", type=str, help="Train split to use, small / large")),
        ("--key_concept_stages", dict(default="init,extract,extend,mapping")),
        ("--map_scope", dict(default="value", type=str, help="Mapping scope for key concepts, value (-specific) or all")),
        # evaluation / sampling settings
        ("--eval_max_tokens", dict(default=100, type=int, help="Max tokens for evaluation")),
        ("--eval_temp", dict(default=1.0, type=float, help="Temperature for sampling")),
        ("--eval_top_p", dict(default=1.0, type=float, help="Top-P parameter for top-p sampling")),
        ("--eval_num_samples", dict(default=10, type=int, help="Number of samples to evaluate on")),
        ("--remove_value", dict(default=None, type=str, help="Remove a specific value from the training data")),
        ("--evaluate", dict(action="store_true", help="Run evaluation on the dataset with the model and prompt method.")),
        # hyperparameters for prompt calibration method
        ("--monte_carlo_trial", dict(default=3, type=int, help="Number of monte carlo trials")),
        ("--few_shot_list", dict(nargs='+', type=int, help="Few shot list for calibration")),
        # hyperparameters for key concept extraction method
        ("--concept_num", dict(default=5, type=int, help="Maximum number of key concepts to extract")),
        ("--batch_size", dict(default=5, type=int, help="Batch size for key concept extraction")),
        ("--cluster_train_data", dict(default="False", type=str, help="Cluster the training data for key concept extraction")),
        ("--with_result", dict(default="True", type=str, help="Whether use the result when encoding key concepts.")),
        ("--top_concept_num", dict(default=2, type=int, help="number of key concepts mapped for each concepts extracted by gpt-4")),
    ]

    cli = argparse.ArgumentParser(description="Value Calibration")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()

def _renumber_concepts(concepts):
    """Rewrite *concepts* in place so that item i starts with "<i+1>. "."""
    for idx, concept in enumerate(concepts):
        expected_prefix = f"{idx + 1}. "
        if concept.startswith(expected_prefix):
            continue
        # Strip an existing prefix only when it really is a number. Splitting
        # blindly on the first ". " would cut un-numbered text that merely
        # contains ". " (for example "..., e.g. lying" would become "lying").
        head, separator, tail = concept.partition(". ")
        if separator and head.strip().isdigit():
            concept = tail
        concepts[idx] = expected_prefix + concept
    return concepts


def number_key_concepts(key_concepts, target_value=None):
    """Number the concept strings of one entry, or of every entry, of a mapping.

    Each concept list is renumbered sequentially starting at 1. A stale numeric
    prefix ("7. foo") is replaced; text without a numeric prefix is kept whole
    and simply gets the prefix prepended.

    Args:
        key_concepts: dict mapping a key to a list of concept strings. The
            lists are modified in place.
        target_value: when given, only ``key_concepts[target_value]`` is
            renumbered; when None, every entry is.

    Returns:
        The same ``key_concepts`` dict that was passed in.

    Raises:
        KeyError: if ``target_value`` is given but is not a key of ``key_concepts``.
    """
    selected_keys = key_concepts if target_value is None else [target_value]
    for key in selected_keys:
        key_concepts[key] = _renumber_concepts(key_concepts[key])
    return key_concepts

if __name__ == "__main__":
    # The seed is hard-coded to 2024 and set before anything else runs.
    set_random_seed(2024)
    args = parse_arguments()
    print("Running with args: ", args)
    # Results live under <RESULT_PATH>/<dataset>/<prompt_method>/<model>, where
    # <model> is the part of --model_name after the last "/".
    result_path = f"{RESULT_PATH}/{args.dataset}/{args.prompt_method}/{args.model_name.split('/')[-1]}"
    # exist_ok=True replaces the separate exists() check, which could race with
    # another run creating the same directory between the check and makedirs().
    os.makedirs(result_path, exist_ok=True)

    test_data = load_dataset(args.dataset, "test", version=args.data_version, args=args)
    train_data = load_dataset(args.dataset, args.train_split, version=args.data_version, args=args)
    print("Test data loaded: ", len(test_data), "Train data loaded: ", len(train_data))
    
    if args.prompt_method == "calibration":
        # One JSON file caches the refined definition of every value handled so
        # far; it is rewritten after each value so a rerun can skip finished ones.
        definition_file = f"{result_path}/refined_value_definition_{args.train_split}.json"
        refined_value_definition = {}
        if os.path.exists(definition_file):
            with open(definition_file, "r") as cache_in:
                refined_value_definition = json.load(cache_in)

        all_values = sorted(train_data["value"].unique().tolist())
        for target_value in all_values:
            if target_value not in refined_value_definition:
                calibrator = PromptCalibration(args, train_data, target_value)

                print("Calibrating for value: ", target_value)
                # The three calibrator steps run in this fixed order; only the
                # last one returns the rule that is stored.
                print("step 1: calibrate draft")
                calibrator.calibrate_draft()

                print("step 2: calibrate revisit")
                calibrator.calibrate_revisit()

                print("step 3: calibrate apply")
                calibrated_rule = calibrator.calibrate_apply()
                refined_value_definition[target_value] = calibrated_rule

                print(f"calibrated rule for value {target_value}: {calibrated_rule}")
                print("=====================================\n")

                with open(definition_file, "w") as cache_out:
                    json.dump(refined_value_definition, cache_out)

        print("Refined value definition: ")
        for value, rule in refined_value_definition.items():
            print(f"Value: {value}, Rule: {rule}")
    
    if args.prompt_method == "allure":
        # Per-value examples are cached in one JSON file that is rewritten after
        # every newly processed value.
        allure_file = f"{result_path}/allure_{args.train_split}.json"
        allure_example_memory = {}
        if os.path.exists(allure_file):
            with open(allure_file, "r") as cache_in:
                allure_example_memory = json.load(cache_in)

        print("Find Allure examples for each value...")
        allure = AllureCalibration(args, train_data)
        for target_value in sorted(train_data["value"].unique().tolist()):
            # Values already present in the cache are not recomputed.
            if target_value not in allure_example_memory:
                allure_example_memory[target_value] = allure.filter_error_examples(target_value)
                with open(allure_file, "w") as cache_out:
                    json.dump(allure_example_memory, cache_out)
    
    if args.prompt_method == "g-eval":
        print("Generate evaluation data for g-eval...")
        geval_generate_cot(args)

    elif args.prompt_method == "key_concepts":
        # step 1. extract key concepts from the training samples
        # NOTE: stages are detected by substring match on "<stage>,", so every
        # stage name in --key_concept_stages (including the last one) must be
        # followed by a comma to be picked up.
        if "extract_separate," in args.key_concept_stages:
            print("Extracting key concepts from separate examples ...")
            key_concepts_path = f"{result_path}/key_concepts_{args.train_split}_extract_separate_cluster_train_{args.cluster_train_data}.json"
            example_2_concepts_path = f"{result_path}/example_2_concepts_{args.train_split}_extract_separate_cluster_train_{args.cluster_train_data}.json"

            key_concepts_extract_separate = {}  # {value: [numbered concept, ...]}
            example_2_concepts = {}  # {formatted example string: [numbered concept, ...]}
            if os.path.exists(key_concepts_path):
                with open(key_concepts_path, "r") as f:
                    key_concepts_extract_separate = json.load(f)

                # The two caches are written one after the other, so an
                # interrupted run can leave only the first file behind. Start
                # from an empty example map in that case instead of crashing;
                # missing examples are re-extracted individually further down.
                if os.path.exists(example_2_concepts_path):
                    with open(example_2_concepts_path, "r") as f:
                        example_2_concepts = json.load(f)

            # target_value deliberately outlives the loop: the extractor built
            # for the mapping step below receives the last value iterated.
            # Pre-binding it avoids a NameError when there are no values.
            target_value = None
            for target_value in sorted(train_data["value"].unique().tolist()):
                print("Extracting key concepts for value: ", target_value)
                if target_value in key_concepts_extract_separate:
                    continue
                concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value)
                key_concepts_extract_separate[target_value], example_2_concept_split = concept_extractor.extract_key_concepts_separate()

                # Checkpoint after every value so a rerun resumes where it stopped.
                key_concepts_extract_separate = number_key_concepts(key_concepts_extract_separate, target_value)
                with open(key_concepts_path, "w") as f:
                    json.dump(key_concepts_extract_separate, f)

                example_2_concept_split = number_key_concepts(example_2_concept_split)
                example_2_concepts = {**example_2_concepts, **example_2_concept_split}
                with open(example_2_concepts_path, "w") as f:
                    json.dump(example_2_concepts, f)

            map_source = "extract_separate"
            mapped_train_path = f"{result_path}/mapped_concepts_train_{args.train_split}_{map_source}.json"
            if not os.path.exists(mapped_train_path):
                mapped_concepts = []
                concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value)
                for idx, row in tqdm(train_data.iterrows(), desc="Mapping key concepts to train samples..."):
                    current = f"Scenario: {row['scenario']}\tAction: {row['action']}\tValue: {row['value']}\tDecision: {row['label']}"
                    concepts = example_2_concepts.get(current)
                    if not concepts:
                        # No (or empty) concepts recorded for this sample:
                        # extract them for this row and register the new
                        # concepts at the end of the value's concept list.
                        print("not found or no extrated features: ", idx)
                        concepts = concept_extractor.extract_key_concepts_individual(row)
                        example_2_concepts[current] = concepts[:]  # [1. concept1, 2. concept2, ...]
                        for concept in concepts:
                            concept = concept.split(". ", 1)[-1]
                            concept = "{}. {}".format(len(key_concepts_extract_separate[row["value"]]) + 1, concept)
                            key_concepts_extract_separate[row["value"]].append(concept)
                    mapped_concepts.append(concepts[:])

                # [[concept1, concept2, ...], ...] with one entry per train row, in iteration order
                with open(mapped_train_path, "w") as f:
                    json.dump(mapped_concepts, f)

                with open(key_concepts_path, "w") as f:
                    json.dump(key_concepts_extract_separate, f)

                with open(example_2_concepts_path, "w") as f:
                    json.dump(example_2_concepts, f)

            key_concepts = key_concepts_extract_separate  # {"value1": [key concept list], "value2": [key concept list], ...}
            print("Load key concepts from path: ", key_concepts_path)
    
        # step 2. encode and integrate key concepts to build key concept pool
        # NOTE(review): this stage reads map_source, key_concepts and
        # target_value, which are only assigned by the extract_separate stage of
        # the same run.
        if "encode," in args.key_concept_stages:
            print("Encoding key concepts...")
            embed_path = f"{result_path}/key_concepts_{args.train_split}_{map_source}_embed_with_result_{args.with_result}.json"
            if not os.path.exists(embed_path):
                concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value)
                # {"value1": [{"concept", "embed"}, ...], "value2": [{"concept", "embed"}, ...], ...}
                key_concepts_embed = concept_extractor.encode_key_concepts(key_concepts)

                with open(embed_path, "w") as embed_out:
                    json.dump(key_concepts_embed, embed_out)
            else:
                # Reuse the embeddings cached by an earlier run.
                with open(embed_path, "r") as embed_in:
                    key_concepts_embed = json.load(embed_in)
            print("Load key concepts embeddings from path: ", embed_path)

        if "integrate," in args.key_concept_stages:
            print("Integrating key concepts...")
            # All paths are built from the pre-integration map_source; it is only
            # extended with "_integrate" at the very end of this stage.
            integrate_embed_path = f"{result_path}/key_concepts_{args.train_split}_{map_source}_integrate_embed_with_result_{args.with_result}.json"
            integrate_dict_path = f"{result_path}/integrate_dict_{args.train_split}_{map_source}_with_result_{args.with_result}.json"
            raw_mapped_path = f"{result_path}/mapped_concepts_train_{args.train_split}_{map_source}.json"
            integrated_mapped_path = f"{result_path}/mapped_concepts_train_{args.train_split}_{map_source}_integrate_with_result_{args.with_result}.json"

            # Both cache files are needed together. If an interrupted run left
            # only one of them, recompute and rewrite the pair instead of failing
            # on the missing file.
            if os.path.exists(integrate_embed_path) and os.path.exists(integrate_dict_path):
                with open(integrate_embed_path, "r") as f:
                    key_concepts_integrate_embed = json.load(f)

                with open(integrate_dict_path, "r") as f:
                    key_concepts_integrate_dict = json.load(f)
            else:
                concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value=None)
                key_concepts_integrate_embed, key_concepts_integrate_dict = concept_extractor.integrate_key_concepts(key_concepts_embed)

                with open(integrate_embed_path, "w") as f:
                    json.dump(key_concepts_integrate_embed, f)

                with open(integrate_dict_path, "w") as f:
                    json.dump(key_concepts_integrate_dict, f)

            # revise mapped concepts for the training samples
            if not os.path.exists(integrated_mapped_path):
                with open(raw_mapped_path, "r") as f:
                    raw_mapped_concepts = json.load(f)

                mapped_concepts = []
                # raw_mapped_concepts is a plain list with one entry per train
                # row, so it is addressed by row position. The DataFrame index
                # label is not guaranteed to be 0..n-1.
                for position, (_, row) in enumerate(tqdm(train_data.iterrows(), desc="Mapping integrated key concepts to train samples...")):
                    value = row["value"]
                    one_mapped_concepts = []
                    for concept in raw_mapped_concepts[position]:
                        # Drop the "<n>. " numbering to recover the raw concept text.
                        concept = concept.split(". ", 1)[-1]
                        if args.with_result == "False":
                            # Keep only the part before the "-->" marker.
                            concept = concept.split("-->", 1)[0].strip()
                        raw_2_integrated = key_concepts_integrate_dict[value][concept]
                        one_mapped_concepts.append(key_concepts_integrate_embed[value][raw_2_integrated]["concept"])
                    # De-duplicate while keeping first-seen order, so the saved
                    # file is identical from run to run (set() order is not).
                    mapped_concepts.append(list(dict.fromkeys(one_mapped_concepts)))

                with open(integrated_mapped_path, "w") as f:
                    json.dump(mapped_concepts, f)

            map_source = f"{map_source}_integrate"
            key_concepts_embed = key_concepts_integrate_embed
            # Report the file actually used; rebuilding the path from the updated
            # map_source would double the "_integrate" part.
            print("Load integrated key concepts from path: ", integrate_embed_path)

        # step 3. key concepts inference and mapping with the LLM
        if "infer," in args.key_concept_stages:
            print("Inferring key concepts for samples...")
            concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value=None)
            inferred_path = f"{result_path}/inferred_concepts_test_{args.data_version}.json"
            if os.path.exists(inferred_path):
                # Use a context manager so the cache file handle is closed
                # deterministically instead of being left to the garbage collector.
                with open(inferred_path, "r") as f:
                    inferred_concepts = json.load(f)
            else:
                # The extractor is handed the cache path and a fresh result list.
                inferred_concepts = concept_extractor.infer_key_concepts(test_data, store_path=inferred_path, inferred_concepts=[])

                with open(inferred_path, "w") as f:
                    json.dump(inferred_concepts, f)
            print("Load inferred concepts from path: ", inferred_path)
        
        if "mapping_embed," in args.key_concept_stages:
            print("Mapping key concepts to samples based on embeddings...")
            concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value=None)

            with open(f"{result_path}/inferred_concepts_test_{args.data_version}.json", "r") as f:
                inferred_concepts = json.load(f)

            value_scope_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_value_embed_with_result_{args.with_result}.json"
            all_scope_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_all_embed_with_result_{args.with_result}.json"
            requested_scope_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_{args.map_scope}_embed_with_result_{args.with_result}.json"
            mixed_path = f"{result_path}/mixed_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_{args.map_scope}_embed_with_result_{args.with_result}.json"

            # The value-scope mapping is the input of the two steps below and is
            # always read back from the scope_value file. It must therefore be
            # written to that same file whatever --map_scope is; writing it to
            # the scope_<map_scope> file made the read below fail for
            # --map_scope all.
            if not os.path.exists(value_scope_path):
                mapped_concepts = concept_extractor.embed_mapping_key_concepts(key_concepts_embed, inferred_concepts, store_path=value_scope_path, mapped_concepts=[])

                with open(value_scope_path, "w") as f:
                    json.dump(mapped_concepts, f)

            with open(value_scope_path, "r") as f:
                embed_mapped_concepts = json.load(f)

            if not os.path.exists(all_scope_path):
                mapped_concepts = concept_extractor.embed_mapping_key_concepts_all(key_concepts_embed, embed_mapped_concepts, store_path=all_scope_path)

                with open(all_scope_path, "w") as f:
                    json.dump(mapped_concepts, f)

            print("Mixed mapping key concepts to samples based on embeddings...")
            if not os.path.exists(mixed_path):
                mapped_concepts = concept_extractor.mixed_mapping_key_concepts(key_concepts_embed, embed_mapped_concepts, store_path=mixed_path)

                with open(mixed_path, "w") as f:
                    json.dump(mapped_concepts, f)

            print("Save mapped concepts to path: ", requested_scope_path)

        if "mapping_llm," in args.key_concept_stages:
            print("Mapping key concepts to samples based on LLMs...")
            concept_extractor = KeyConceptExtraction(args, train_data, test_data, target_value=None)

            with open(f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_value_embed_with_result_{args.with_result}.json", "r") as f:
                embed_mapped_concepts = json.load(f)

            value_llm_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_value_embed_llm_with_result_{args.with_result}.json"
            all_llm_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_all_embed_llm_with_result_{args.with_result}.json"
            requested_llm_path = f"{result_path}/mapped_concepts_test_{args.data_version}_{args.train_split}_{map_source}_scope_{args.map_scope}_embed_llm_with_result_{args.with_result}.json"

            if os.path.exists(requested_llm_path):
                # Context manager instead of a bare open() so the handle is closed.
                with open(requested_llm_path, "r") as f:
                    llm_mapped_concepts = json.load(f)
            else:
                # The value-scope result feeds the all-scope step. Load it when it
                # is already cached; previously the variable was left undefined
                # in that case and the all-scope call raised NameError.
                if os.path.exists(value_llm_path):
                    with open(value_llm_path, "r") as f:
                        llm_mapped_concepts = json.load(f)
                else:
                    # store_path now names the file this step actually writes,
                    # rather than the scope_<map_scope> file.
                    llm_mapped_concepts = concept_extractor.llm_mapping_key_concepts(key_concepts_embed, embed_mapped_concepts, store_path=value_llm_path)

                    with open(value_llm_path, "w") as f:
                        json.dump(llm_mapped_concepts, f)

                if not os.path.exists(all_llm_path):
                    # Likewise, the all-scope step is given its own file so it
                    # can no longer be pointed at the value-scope cache.
                    llm_mapped_concepts = concept_extractor.llm_mapping_key_concepts_all(key_concepts_embed, llm_mapped_concepts, store_path=all_llm_path)

                    with open(all_llm_path, "w") as f:
                        json.dump(llm_mapped_concepts, f)

            print("Save llm mapped concepts to path: ", requested_llm_path)