import pandas as pd
import sys
sys.path.append("../..")
from datetime import datetime, timedelta
import pandas as pd
from backend.gpt import query_gpt4
import sys
import os
import logging


import re

# pick episode: usage is `script <season> <episode>`; validate before doing any work
if len(sys.argv) < 3:
    sys.exit("usage: python {} <season> <episode>".format(sys.argv[0]))
season_num = sys.argv[1]
episode_num = sys.argv[2]

# load raw data; every string cell is lower-cased so later matching is case-insensitive
df = pd.read_csv('Friends.csv', delimiter=',', header=0)
df = df.applymap(lambda x: x.lower() if isinstance(x, str) else x)

# create output folder
output_path = "./s{}/e{}".format(season_num, episode_num)
try:
    os.makedirs(output_path)
    print("output to {}".format(output_path))
except FileExistsError:
    print("{} created, output would be overwrite".format(output_path))
except OSError as e:
    print("Failed to create output path with {}, {}".format(output_path, e))
    sys.exit(1)

logging.basicConfig(filename=output_path + "/data_generation.log",
                    level="INFO",
                    format='[%(asctime)s]\n%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    encoding="utf-8")

# Match the number exactly: a plain substring test would let "season-1" also
# select "season-10" (and "episode-1" select "episode-10".."episode-19").
# re.escape keeps user input from being interpreted as a regex; na=False keeps
# missing Season/Episode cells from breaking the boolean mask.
season_pattern = "season-" + re.escape(season_num) + r"(?!\d)"
episode_pattern = "episode-" + re.escape(episode_num) + r"(?!\d)"
sample = df[df['Season'].str.contains(season_pattern, na=False)]
sample = sample[sample['Episode'].str.contains(episode_pattern, na=False)].dropna().reset_index(drop=True)

logging.info("Picked Episode")
logging.info(sample)

# nothing to label: stop here instead of failing later with an obscure error
if sample.empty:
    message = "No script lines found for season {} episode {}".format(season_num, episode_num)
    logging.error(message)
    sys.exit(message)

# some scripts abbreviate the main characters' names; expand them back to the full names
def norm_name(name):
    """Strip *name* and expand known character-name abbreviations in it.

    "phoe", "rach" and "chan" are expanded only where they stand as whole
    words, so already-full names ("phoebe", "rachel", "chandler") are left
    untouched while "phoe's", "rach/joey" or "ross, chan" are fixed. "mnca"
    is expanded wherever it appears. Any change is logged.

    Args:
        name: a speaker or listener string, e.g. "rach" or "ross/phoe".

    Returns:
        The stripped string with the abbreviations expanded.
    """
    import re

    name = name.strip()
    orig = name
    # word boundaries replace the old substring guards, which still ran an
    # unconditional str.replace and turned "phoebe/phoe" into "phoebebe/phoebe"
    for pattern, full_name in ((r"\bphoe\b", "phoebe"),
                               (r"\brach\b", "rachel"),
                               (r"mnca", "monica"),
                               (r"\bchan\b", "chandler")):
        name = re.sub(pattern, full_name, name)
    if orig != name:
        logging.info("{} --> {}".format(orig, name))
    return name

# locate scene
# scene_flag stays True when the script marks scenes explicitly; each scene
# index then points at a header row whose Text describes the scene
scene_flag = True
speakers = sample["Speaker"]
# some data are dirty with '[scene' in the cell
scene_indexs = [idx for idx, speaker in speakers.items() if "[scene" in speaker]
# usually it contains only 'scene'
if not scene_indexs:
    scene_indexs = [idx for idx, speaker in speakers.items() if "scene" in speaker]
# Some scripts contain no scene marker; split them evenly into (at most) 10
# chunks to keep each GPT query small. A fixed step of len(sample) // 10 is 0
# (ValueError) for scripts under 10 lines and can produce 11 chunks.
if not scene_indexs:
    n_lines = len(sample)
    scene_indexs = sorted({fold * n_lines // 10 for fold in range(10)}) if n_lines else []
    logging.info("No explicit scene info, evenly divided into 10 scenes")
    scene_flag = False

logging.info("Index of each scene in the script")
logging.info(scene_indexs)

# write scene info: each scene's start row and its first utterance
# (the row after the header when scenes are explicit)
scene_pd = []
for idx in scene_indexs:
    first_utterance = idx + 1 if scene_flag else idx
    if first_utterance < len(sample):
        scene_pd.append([idx, sample.iloc[first_utterance]['Text']])
scene_pd = pd.DataFrame(scene_pd, columns=['index', 'text'])
scene_pd.to_csv(os.path.join(output_path, "scene_info.csv"))

# query gpt to label the listener of conversation
import csv


def _parse_gpt_answer(ans):
    """Return the data rows of a GPT labeling answer, each split on ",".

    The answer is lower-cased and split into lines. Only lines whose first
    field parses as an integer index are kept. This drops blank lines, the
    echoed "index, speaker, listener" header, code fences and other
    commentary, all of which would otherwise crash the int() conversion.
    NOTE(review): assumes query_gpt4 returns a str (a falsy answer yields no
    rows) — confirm against backend.gpt.
    """
    rows = []
    for raw_line in (ans or "").lower().split("\n"):
        fields = raw_line.split(",")
        try:
            int(fields[0])
        except ValueError:
            continue
        rows.append(fields)
    return rows


for scene_idx, start_index in enumerate(scene_indexs):
    scene_number = scene_idx + 1
    logging.info("Scene {}".format(str(scene_number)))
    # a scene runs up to (not including) the next scene's start; the last one runs to the end
    if scene_number == len(scene_indexs):
        end_index = len(sample)
    else:
        end_index = scene_indexs[scene_number]
    logging.info("From {} to {} in scripts, {} lines".format(str(start_index), str(end_index), str(end_index - start_index)))

    first_row = sample.iloc[start_index]
    episode_info = first_row['Episode']
    season_info = first_row['Season']
    if scene_flag:
        scene_info = first_row['Speaker'] + str(scene_number)
        background_info = first_row['Text']
        # the first row of an explicit scene is its header, not an utterance
        start_index += 1
    else:
        scene_info = ""
        background_info = ""
    scene_part = sample.iloc[start_index:end_index].reset_index(drop=True)

    if scene_part.empty:
        # e.g. two scene headers in a row: nothing to label, skip the GPT call
        logging.info("Scene {} has no utterances, skip labeling".format(str(scene_number)))
        ans = ""
    else:
        content = ["Index\tUtterance\tSpeaker"]
        for utterance_index, (text, speaker) in enumerate(zip(scene_part['Text'], scene_part['Speaker']), start=1):
            content.append(str(utterance_index) + "\t" + text + "\t" + speaker)
        content = "\n".join(content)
        query_prompt = """this is a script of {} in the popular tv series Friends, {} in {}. 
the background of this scene is {}. Here is the script:

{}

each line is the index and utterance with the speaker of this utterance.
you need to annotate the listener for each utterance based on the context of dialogue and scene description.
the utterances are not organized in speaker-listener-speaker mode, there may be multiple consecutive queries from multiple speakers to the same listener.
so you must annotate based on the content of dialogue.
if there are multiple listener for one utterance, annotate in the format of "name1/name2/name3"
if all character in this scene are the listeners for one utterance, annotate it with "ALL"
if there exist pronouns and references, you need to annotate precise attribution, e.g, "father" should be annotate as "someone's father"
you have to give answer. if you are not sure, annotate the listener as "ALL"
the returned index and speaker should be exactly the same as the given script, all you need to do is only add the listener
return in the format like:
index, speaker, listener
1, the name of speaker, the name of listener""".format(scene_info, episode_info, season_info, background_info, content)
        ans = query_gpt4(query_prompt)

    final_list = []
    for line in _parse_gpt_answer(ans):
        utterance_index = int(line[0])
        # validate shape and index before touching line[1]/line[2] or scene_part
        if len(line) != 3:
            logging.info("len({}) !=3: Bad Generated utterance, skip".format(str(utterance_index)))
            continue
        if not 1 <= utterance_index <= len(scene_part):
            logging.info("{}: index out of range 1..{}, skip".format(str(utterance_index), str(len(scene_part))))
            continue
        source_row = scene_part.iloc[utterance_index - 1]
        speaker_input = source_row['Speaker'].strip()
        speaker_output = line[1].strip()
        listener = line[2].strip()
        if speaker_input != speaker_output:
            logging.info("{}: {}[Input Speaker] != {}[Output Speaker]\nOriginal Input:\n{}\nFull Output:\n{}\nmodel output will be used as speaker".format(str(utterance_index),
                                                                                                                                                       speaker_input,
                                                                                                                                                       speaker_output,
                                                                                                                                                       str(source_row),
                                                                                                                                                       str(line)))
            if len(speaker_output) > 0:
                speaker_input = speaker_output
        final_list.append([norm_name(speaker_input), norm_name(listener), source_row['Text']])

    scene_file = "{}/s{}e{}_scene{}.tsv".format(output_path, season_num, episode_num, str(scene_number))
    # csv.writer quotes cells containing tabs, quotes or newlines, so the later
    # pd.read_csv(delimiter='\t') reads them back intact; explicit UTF-8 matches read_csv's default
    with open(scene_file, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, delimiter="\t", lineterminator="\n").writerows(final_list)

# Synthetic timestamps for the offline conversation dataset: the clock starts
# at this fixed date (presumably the show's premiere — confirm) and
# current_time is the next timestamp to hand out, in steps of time_interval seconds.
start_time = datetime(1994, 9, 22)
time_interval = 1  # seconds between consecutive timestamps
current_time = start_time

# data clean for gpt-labeled dataset
# 1. replace "all" with every other character in this scene
# 2. normalize names, e.g. "the monica" -> "monica", "rachel and/+/& monica" -> "rachel" and "monica"
# 3. explode the dataframe so each row has exactly one speaker and one listener
def _split_names(field):
    """Split one speaker/listener cell into individual character names.

    Drops a standalone "the " ("the waiter" -> "waiter") and treats " and ",
    "+", " & " and "/" as separators. Parts are stripped, and empty parts are
    dropped.
    """
    import re

    field = re.sub(r"\bthe\s+", "", str(field).strip())
    for separator in (" and ", "+", " & "):
        field = field.replace(separator, "/")
    return [part.strip() for part in field.split("/") if part.strip()]


def _expand_all(names, others):
    """Replace each "all" in *names* by the sorted names in *others*, keeping first occurrences only."""
    expanded = []
    for name in names:
        if name == "all":
            # sorted: set iteration order varies between runs (string hashing)
            expanded.extend(sorted(others))
        else:
            expanded.append(name)
    return list(dict.fromkeys(expanded))


def process_df(df):
    """Clean one scene's GPT-labeled frame into one row per (speaker, listener) pair.

    Args:
        df: frame with columns 0 (speaker), 1 (listener) and 2 (utterance
            text), as read from a per-scene TSV with header=None.

    Returns:
        A DataFrame with columns 0..3: speaker, listener, utterance text and a
        '%Y-%m-%d %H:%M:%S' timestamp string.

    Side effects:
        Advances the module-level ``current_time`` by ``time_interval``
        seconds for every input row, including rows skipped for a missing
        speaker or listener.
    """
    global current_time
    # the scene's cast, taken from the speaker column
    all_people = set()
    for cell in df[0].dropna():
        all_people.update(_split_names(cell))
    all_people.discard("all")
    # too few distinct speakers to tell who "all" means: assume the six main characters
    if len(all_people) <= 2:
        all_people = {"rachel", "monica", "phoebe", "ross", "joey", "chandler"}

    ret = []
    for _, line in df.iterrows():
        mysql_timestamp = current_time.strftime('%Y-%m-%d %H:%M:%S')
        current_time += timedelta(seconds=time_interval)
        if pd.isna(line[0]) or pd.isna(line[1]):
            logging.info("Missing speaker or listener, skipped: {}".format(line.tolist()))
            continue
        senders = _split_names(line[0])
        receivers = _split_names(line[1])
        # compare whole tokens: a substring test would also hit names like "sally"
        if "all" in receivers:
            receivers = _expand_all(receivers, all_people - set(senders))
        if "all" in senders:
            senders = _expand_all(senders, all_people - set(receivers))
        for character_sender in senders:
            for character_receiver in receivers:
                ret.append([character_sender, character_receiver, line[2], mysql_timestamp])
    return pd.DataFrame(ret)


# read every per-scene TSV back, clean it with process_df and merge the episode
file_name_prefix = "{}/s{}e{}_scene".format(output_path, season_num, episode_num)
df_list = []
for idx in range(1, len(scene_indexs) + 1):
    scene_file = file_name_prefix + str(idx) + ".tsv"
    try:
        scene_df = pd.read_csv(scene_file, delimiter='\t', header=None)
    except pd.errors.EmptyDataError:
        logging.info("Empty tsv: {} skipped".format(scene_file))
        # must skip: falling through would re-process the previous iteration's
        # frame (or, for the first scene, the whole raw dataset)
        continue
    scene_df = scene_df.applymap(lambda x: x.strip() if isinstance(x, str) else x)
    cleaned = process_df(scene_df)
    if not cleaned.empty:
        df_list.append(cleaned)

# pd.concat([]) raises an unhelpful ValueError; report the real problem instead
if not df_list:
    message = "No labeled utterances for {}*, nothing to write".format(file_name_prefix)
    logging.error(message)
    sys.exit(message)

concat_df = pd.concat(df_list, axis=0, ignore_index=True)
all_character = set(concat_df[0]) | set(concat_df[1])
logging.info("All characters in this episode")
logging.info(all_character)
concat_df.to_csv("{}/s{}e{}_labeled.csv".format(output_path, season_num, episode_num))
logging.info("final cleaned data")
logging.info(concat_df)