import pandas as pd
from collections import defaultdict
import os
import yaml
from openai import OpenAI
from backend.gpt import query_gpt4, query_gpt
from tqdm import tqdm
import openai
import sys

# Resolve the config relative to this file (not the CWD) so it is found no
# matter where the script is launched from.
project_path = os.path.dirname(__file__)
# Use a context manager so the config file handle is closed; the previous
# inline open() was never closed.
with open(os.path.join(project_path, "config/global.yaml"), "r", encoding="utf-8") as config_file:
    global_config = yaml.safe_load(config_file)
backend_config = global_config.get("backend")
OPENAI_API_KEY = backend_config.get("openai_api_key")
# Optional endpoint override; None means the OpenAI client's default URL.
BASE_URL = backend_config.get("base_url", None)

# Shared client used by the embedding helpers below.
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=BASE_URL,
)

def get_embedding(text, model="text-embedding-3-small"):
    """Embed ``text`` with the module-level OpenAI ``client``.

    Newlines are turned into spaces before the request, and the API is asked
    for a 256-dimensional vector via ``dimensions=256``.

    Args:
        text: String to embed.
        model: OpenAI embedding model name.

    Returns:
        The ``data[0].embedding`` field of the API response.
    """
    single_line_text = " ".join(text.split("\n"))
    response = client.embeddings.create(
        model=model,
        input=[single_line_text],
        dimensions=256,
    )
    return response.data[0].embedding

def get_embedding_v2(text, model="text-embedding-ada-002"):
    """Embed ``text`` with the module-level OpenAI ``client`` and truncate it.

    Newlines are turned into spaces, the embedding is requested with
    ``encoding_format="float"``, and only its first 256 values are returned.

    Args:
        text: String to embed.
        model: OpenAI embedding model name.

    Returns:
        The first (up to) 256 entries of ``data[0].embedding``.
    """
    newline_to_space = {ord("\n"): " "}
    cleaned_text = text.translate(newline_to_space)
    response = client.embeddings.create(
        model=model,
        input=cleaned_text,
        encoding_format="float",
    )
    full_vector = response.data[0].embedding
    # NOTE(review): truncation presumably lines these vectors up with the
    # 256-dim output of get_embedding; it is unclear whether truncated
    # ada-002 vectors are meaningful for similarity search - confirm.
    return full_vector[:256]

def add_scene_index(df_data, df_scene):
    """Tag each utterance in ``df_data`` with a running scene counter.

    ``df_scene['text']`` lists, in order, the first utterance of each scene.
    Rows of ``df_data`` are scanned in order. When a row's stripped message
    equals the next expected opener (lower-cased and stripped), the counter
    is incremented *before* that row is tagged. Rows before the first match
    therefore get 0, and rows of the first scene get 1.

    Args:
        df_data: DataFrame whose columns at positions 1-4 are speaker,
            listener, message and timestamp. Column 0 is ignored (presumably
            a saved index column - verify against the CSV). Only the scene
            text is lower-cased here, so messages are expected to be
            lower-cased by the caller already.
        df_scene: DataFrame with a ``text`` column of scene-opening lines.

    Returns:
        DataFrame with columns speaker, listener, message, timestamp,
        index_scene.
    """
    ret_data = []
    index_scene = 0
    max_scene_index = len(df_scene)
    for _, line in df_data.iterrows():
        # Explicit positional access: plain line[1] relies on a deprecated
        # positional fallback and becomes label-based for integer labels.
        speaker = line.iloc[1]
        listener = line.iloc[2]
        message = line.iloc[3]
        timestamp = line.iloc[4]
        if index_scene < max_scene_index and message.strip() == df_scene.iloc[index_scene]['text'].lower().strip():
            index_scene += 1
        ret_data.append([speaker, listener, message, timestamp, index_scene])
    return pd.DataFrame(ret_data, columns=['speaker', 'listener', 'message', 'timestamp', 'index_scene'])


def aggregate_messages(df, person_name):
    """Group every message involving ``person_name`` by partner and scene.

    Rows where ``person_name`` is the speaker or the listener are grouped by
    (other participant, ``str(index_scene)``). Each message is rendered as
    ``"from <speaker> to <listener>: <message>"``.

    Args:
        df: DataFrame with ``speaker``, ``listener``, ``message`` and
            ``index_scene`` columns.
        person_name: Value compared (with ``==``) against ``speaker`` and
            ``listener`` to select relevant rows.

    Returns:
        One string per (partner, scene) group, in first-appearance order:
        ``"Messages with <partner>: "`` followed by the group's rendered
        messages joined with tabs. Empty list if no row matches.
    """
    relevant_conversations = df[(df['speaker'] == person_name) | (df['listener'] == person_name)]

    # Key on a (partner, scene) tuple. The previous "partner&scene" string key
    # was split back on "&", which truncated partner names containing "&".
    conversations_dict = defaultdict(list)
    for _, row in relevant_conversations.iterrows():
        other_person = row['listener'] if row['speaker'] == person_name else row['speaker']
        conversations_dict[(other_person, str(row['index_scene']))].append(
            "from {} to {}: {}".format(row['speaker'], row['listener'], row['message'])
        )

    return [
        "Messages with {}: ".format(person) + "\t".join(messages)
        for (person, _scene), messages in conversations_dict.items()
    ]

def aggregate_message_summary(df, person_name):
    """Summarize, with an LLM, each scene-level conversation involving a person.

    Rows where ``person_name`` is the speaker or the listener are grouped by
    (other participant, ``str(index_scene)``). Each message is rendered as
    ``"from <speaker> to <listener>: <message>"``. A group's messages are
    joined with tabs, placed in a summarization prompt and sent to
    ``query_gpt4``. Every truthy reply is printed and collected.

    Args:
        df: DataFrame with ``speaker``, ``listener``, ``message`` and
            ``index_scene`` columns.
        person_name: Value compared (with ``==``) against ``speaker`` and
            ``listener`` to select relevant rows.

    Returns:
        List of ``"Summary of a talk with <partner>: <summary>"`` strings, in
        first-appearance order of each group. Groups whose reply is falsy are
        skipped. Groups whose processing raised are also skipped, after their
        traceback and content are printed.
    """
    import traceback

    relevant_conversations = df[(df['speaker'] == person_name) | (df['listener'] == person_name)]

    # Key on a (partner, scene) tuple. The previous "partner&scene" string key
    # was split back on "&", which truncated partner names containing "&".
    conversations_dict = defaultdict(list)
    for _, row in relevant_conversations.iterrows():
        other_person = row['listener'] if row['speaker'] == person_name else row['speaker']
        conversations_dict[(other_person, str(row['index_scene']))].append(
            "from {} to {}: {}".format(row['speaker'], row['listener'], row['message'])
        )

    # Loop-invariant prompt template; build it once.
    prompt_template = """Here is a conversation between you and {}:\n
{}\n
Now please summarize this conversation in detail without ignoring any key information (any entity or fact). 
Remove any offensive content and replace it with the character's attitude.
Return ONLY the summary."""

    ret_messages = []
    for (person, _scene), messages in tqdm(conversations_dict.items()):
        joined_messages = "\t".join(messages)
        query_prompt = prompt_template.format(person, joined_messages)
        try:
            summary = query_gpt4(query_prompt, woretry=True)
            if summary:
                summary_in_one_scene = "Summary of a talk with {}: ".format(person) + summary
                print(summary_in_one_scene)
                ret_messages.append(summary_in_one_scene)
        except Exception:
            # Best-effort: one failing group is reported and skipped so the
            # remaining groups are still summarized.
            traceback_str = traceback.format_exc()
            # NOTE(review): stdout is reset to the interpreter's original
            # stream before printing - presumably because the backend call can
            # leave sys.stdout redirected; confirm.
            sys.stdout = sys.__stdout__
            print(traceback_str)
            print("Bad Content : {}".format(joined_messages))

    return ret_messages


# # ------------------- write memory -------------------
# Collect, for every character across episodes 1-24, their conversations
# grouped by partner and scene; then embed each group and write one TSV per
# character.
name2message = defaultdict(list)
for idx in range(1, 25):
    print(idx)
    # Zero-pad the episode number to two digits ("01" ... "24"). The loading
    # code used to be nested inside the padding `if`, so episodes 10-24 were
    # silently skipped.
    idx_str = "{:02d}".format(idx)
    df = pd.read_csv("./data/Friends/s01/e{}/s01e{}_labeled_withscene.csv".format(idx_str, idx_str))
    # Lower-case string cells and convert every other cell (numbers, NaN) to
    # its str() form, so grouping/concatenation downstream sees only strings.
    # NOTE(review): DataFrame.applymap is deprecated since pandas 2.1 in
    # favour of DataFrame.map; switch once the minimum pandas version allows.
    df = df.applymap(lambda x: x.lower() if isinstance(x, str) else str(x))
    all_person = set(df['speaker']) | set(df['listener'])
    for person in all_person:
        name2message[person] += aggregate_messages(df, person)

for person, messages in name2message.items():
    print(person, len(messages))

# pandas' to_csv raises if the target directory is missing, so create it first.
os.makedirs("./memory/raw_by_scene_v3", exist_ok=True)
for name, messages in name2message.items():
    print(name)
    data = pd.DataFrame([[get_embedding(line), line] for line in messages])
    data.to_csv("./memory/raw_by_scene_v3/{}.tsv".format(name), header=['emb', 'text'], sep="\t")

# # # ------------------- write memory with summary -------------------
# name2message = defaultdict(list)
# for idx in range(1, 25):
#     print(idx)
#     idx_str = str(idx)
#     if len(idx_str) == 1:
#         idx_str = "0" + idx_str
#         df = pd.read_csv("./data/Friends/s01/e{}/s01e{}_labeled_withscene.csv".format(idx_str, idx_str))
#         df = df.applymap(lambda x: x.lower() if isinstance(x, str) else str(x))
#         all_person = set(df['speaker']) | set(df['listener'])
#         for person in all_person:
#             name2message[person] += aggregate_message_summary(df, person)
# for person in name2message:
#     print(person, len(name2message[person]))
# for name in name2message:
#     print(name)
#     data = [[get_embedding(line), line] for line in name2message[name]]
#     data = pd.DataFrame(data)
#     data.to_csv("./memory/summary_by_scene_v3/{}.tsv".format(name), header=['emb','text'] , sep="\t")

# # ------------------- add scene info -------------------
# for idx in range(1, 25):
#     print(idx)
#     idx_str = str(idx)
#     if len(idx_str) == 1:
#         idx_str = "0" + idx_str
#     # Read the CSV file into a DataFrame
#     df_data = pd.read_csv("./data/Friends/s01/e{}/s01e{}_labeled.csv".format(idx_str, idx_str))
#     df_data = df_data.applymap(lambda x: x.lower() if isinstance(x, str) else str(x))
#     df_scene = pd.read_csv("./data/Friends/s01/e{}/scene_info.csv".format(idx_str))
#     df_data_withscene = add_scene_index(df_data, df_scene)
#     df_data_withscene.to_csv("./data/Friends/s01/e{}/s01e{}_labeled_withscene.csv".format(idx_str, idx_str))