import json
import sys
import re
sys.path.append("../..")
from pprint import pprint
from backend.gpt import query_gpt, query_gpt4
import pandas as pd
from collections import defaultdict
import jsonlines
from itertools import combinations


def norm(text):
    """Return a lookup key for an utterance.

    The text is lowercased, trimmed of surrounding whitespace, and stripped of
    every space character, so texts that differ only in letter case or spacing
    map to the same key.
    """
    key = text.lower().strip()
    return "".join(key.split(" "))

# Resume support: the main loop skips every scene until the one whose title
# equals `last_breakpoint` is reached; `continue_flag` records that it was seen.
last_breakpoint = "s01_e16_c13"
continue_flag = False


def _load_json(path):
    """Read and parse a UTF-8 JSON file, closing the file handle afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# load QA dataset
train_data = _load_json("path_to_FriendsQA/friendsqa_trn.json")
dev_data = _load_json("path_to_FriendsQA/friendsqa_dev.json")
test_data = _load_json("path_to_FriendsQA/friendsqa_tst.json")

# load labeled dialogue dataset (label speaker and listener for each utterance)
# Builds: normalized message text -> set of characters (speaker + listener)
# attached to that message.
# NOTE(review): the main loop subtracts lowercased first names taken from
# FriendsQA from these sets; confirm the CSV stores names in that same form.
concat_df = pd.read_csv("s01.csv")
utterance2character = defaultdict(set)
for message, speaker, listener in zip(
    concat_df['message'], concat_df['speaker'], concat_df['listener']
):
    # Blank CSV cells come back from pandas as NaN: a NaN message would crash
    # norm(), and a NaN speaker/listener would be counted as a character and
    # inflate the ">= 2 characters" test used downstream.
    if pd.isna(message):
        continue
    characters = utterance2character[norm(message)]
    for character in (speaker, listener):
        if pd.notna(character):
            characters.add(character)

# output: every accepted QA combination is appended as one JSON object per line.
# flush=True pushes each record to disk right after it is written, so records
# accepted earlier in the interactive session survive an interrupted run.
fw = jsonlines.open("./FriendsComQA.jsonl", mode="a", flush=True)

# Combinations the annotator marked "g"; they are injected into later prompts
# as example "good cases".
good_case = []

# Matches one numbered item per line ("1. ...", optionally "**1. ..."). Anchoring
# at line start keeps stray fragments such as "3." or "$1.50" inside multi-line
# explanations from being mistaken for new items.
_NUMBERED_ITEM = re.compile(r"^[\s*]*\d+\.\s*(.*)", re.MULTILINE)

# Number of numbered items the prompt below asks the model to return.
_EXPECTED_ITEMS = 5

_QUERY_PROMPT = """
Here is a conversation:
{}
Here is a question-answer pair on this conversation:
{}
Here is another question-answer pair on this conversation:
{}
Please generate a new question-answer pair based on the given two question-answer pairs.
The new question-answer pair must satisfies the following requirements:
a. You can only ask one question in the new generated question.
b. The new question can be answered only if both answers of the given two questions are known. Both answers are directly necessary for solving the new question.
c. The answer to this new question is simple, certain and unambiguous, usually only one word or phrase, which can be easily solved once the two answers to the above questions are known.
d. The new question should contain as much detailed description on the context/event/situation as possible but not leaking any answers.

Here are some good cases for generating the new questions

{}


Return in the format of:         
1. the new question
2. explain in detail that why it must needs answers to the given first questions to solve, and how it be solved
3. explain in detail that why it must needs answers to the given second questions to solve, and how it be solved
4. the answer to the new question
5. how to induce/deduce the answer to the new question from the answers of given two questions, which are "{}" and "{}"
"""


def _characters_of(utterance):
    """Return the labeled character set (speaker + listener) for an utterance text."""
    return utterance2character[norm(utterance)]


def _parse_numbered_response(response):
    """Return the stripped text of each numbered line of *response*, in order."""
    return [match.strip() for match in _NUMBERED_ITEM.findall(response)]


for data in train_data['data']:
    # Resume support: skip scenes until the breakpoint scene is reached.
    if data['title'] == last_breakpoint:
        continue_flag = True
    if not continue_flag:
        continue

    # Only season 1 is processed (the labeled dialogue file is s01.csv).
    if "s01" not in data['title']:
        continue

    # usually only one paragraph
    for paragraph in data['paragraphs']:
        # all the information in one scene
        qas_list = paragraph['qas']
        # NOTE: the key is read with a trailing colon ('utterances:'); keep it.
        utterances_list = paragraph['utterances:']
        dialogue = ["/".join(item['speakers']) + ": " + item['utterance'] for item in utterances_list]

        # filter out scenes with exactly two participants (no third party exists)
        all_participants = set()
        for item in utterances_list:
            all_participants |= set(item['speakers'])
        if len(all_participants) == 2:
            continue

        # Per-answer mappings keyed by a qa id of the form "Q_<question idx>_A_<answer idx>".
        uid2qas = defaultdict(list)  # utterance id -> qa ids using it as rationale
        idx2qa = {}        # qa id -> "question: answer" (stored verbatim in the output)
        idx2question = {}  # qa id -> question text
        idx2answer = {}    # qa id -> answer text (kept separately: texts may contain ": ")
        idx2speaker = {}   # qa id -> lowercased first name of the rationale's first speaker
        idx2uid = {}       # qa id -> utterance id of the rationale
        for qidx, qas in enumerate(qas_list):
            for aidx, answer in enumerate(qas['answers']):
                qa_id = "Q_{}_A_{}".format(qidx, aidx)
                uid = answer['utterance_id']
                uid2qas[uid].append(qa_id)
                idx2uid[qa_id] = uid
                idx2speaker[qa_id] = utterances_list[uid]['speakers'][0].lower().split(" ")[0]
                idx2question[qa_id] = qas['question']
                idx2answer[qa_id] = answer['answer_text']
                idx2qa[qa_id] = qas['question'] + ": " + answer['answer_text']

        # Rationale utterances: they back at least one QA pair and carry >= 2
        # labeled characters. At least two are needed to combine questions under
        # an information-asymmetry scenario.
        rationale_idxs = [
            idx for idx, item in enumerate(utterances_list)
            if uid2qas[idx] and len(_characters_of(item['utterance'])) >= 2
        ]
        if len(rationale_idxs) < 2:
            continue

        print("\n" * 3)
        print(data['title'])
        print('-' * 10)
        rationale_list = []
        rationale_list_print = []
        candidate_qa_ids = []
        for idx in rationale_idxs:
            item = utterances_list[idx]
            characters = _characters_of(item['utterance'])
            rationale_list_print.append("[{}] {}: {} \t {} \t {}".format(
                str(item['uid']), "/".join(item['speakers']), item['utterance'],
                str(uid2qas[idx]), str(characters)))
            # saved to file with every accepted combination
            rationale_list.append({
                "rationale_uid": item['uid'],
                "speaker": item['speakers'],
                "rationale_utterance": item['utterance'],
                "related_qas": uid2qas[idx],
                "related_character": list(characters),
            })
            candidate_qa_ids.extend(uid2qas[idx])

        # Deduplicate while keeping rationale order, so combinations are shown in
        # the same order on every run (set iteration order depends on the hash seed).
        candidate_qa_ids = list(dict.fromkeys(candidate_qa_ids))
        for comb in combinations(candidate_qa_ids, 2):
            qa_a, qa_b = comb
            # both answers must belong to different questions
            if qa_a.split("_")[1] == qa_b.split("_")[1]:
                continue

            chars_a = _characters_of(utterances_list[idx2uid[qa_a]]['utterance'])
            chars_b = _characters_of(utterances_list[idx2uid[qa_b]]['utterance'])
            involve_speakers = {idx2speaker[qa_a], idx2speaker[qa_b]}
            # Each rationale must involve someone who is neither in the other
            # rationale nor one of the two speakers; otherwise skip before any output.
            unique_a = chars_a - chars_b - involve_speakers
            unique_b = chars_b - chars_a - involve_speakers
            if not unique_a or not unique_b:
                continue

            query_prompt = _QUERY_PROMPT.format(
                "\n".join(dialogue), idx2qa[qa_a], idx2qa[qa_b], "\n".join(good_case),
                idx2answer[qa_a], idx2answer[qa_b])
            print("-" * 20)
            print(query_prompt)
            print("-" * 20)
            print("QA Combination: ", comb)

            show_df = pd.DataFrame(
                [
                    [idx2question[qa_a], idx2answer[qa_a], idx2speaker[qa_a], unique_a, idx2uid[qa_a]],
                    [idx2question[qa_b], idx2answer[qa_b], idx2speaker[qa_b], unique_b, idx2uid[qa_b]],
                ],
                columns=['Question', 'Answer', 'Speaker', 'Unique Participants', "Rationale Uid"],
            )
            print("*" * 10)
            print(show_df)
            print("*" * 10)

            print("Rationale Utterance:")
            print("\n".join(rationale_list_print))

            # check if the participants and questions satisfy the need
            s = input("y for continue, n for refuse, s for skip this scene\n")
            if s == "s":
                break
            if s != "y":
                continue

            print(">" * 10)
            response = query_gpt4(query_prompt, temperature=0.1)
            response_list = _parse_numbered_response(response)
            print(response)
            print()
            # A malformed reply would otherwise raise IndexError below and end the session.
            if len(response_list) < _EXPECTED_ITEMS:
                print("Could not find {} numbered items in the response; skipping this combination.".format(_EXPECTED_ITEMS))
                continue

            # check if the generated question and answer satisfy the need
            s = input("y for accept, n for refuse, s for skip this scene, g for good case\n")
            if s == "s":
                break
            if s not in ("y", "g"):
                continue

            ret_dict = {
                "episode": data['title'],
                "context": dialogue,
                "rationale_list": rationale_list,
                "QA1": {
                    "idx": qa_a,
                    "content": idx2qa[qa_a],
                },
                "QA2": {
                    "idx": qa_b,
                    "content": idx2qa[qa_b],
                },
                "generated_question": response_list[0],
                "reason1": response_list[1],
                "reason2": response_list[2],
                "answer": response_list[3],
                "QA1_third_party": list(chars_a - involve_speakers),
                "QA2_third_party": list(chars_b - involve_speakers),
            }
            pprint(ret_dict)
            fw.write(ret_dict)
            if s == "g":
                good_case.append("{} + {} --> {}, why is a good combination: {}".format(
                    idx2qa[qa_a], idx2qa[qa_b], response_list[0] + ": " + response_list[3], response_list[4]))
            input("finished, write to file, enter to continue")
