import json
import pandas as pd
import jsonlines

# Running index of the input record currently being converted. It is appended
# as a suffix to speaker names and incremented once per record by the
# conversion loop further down in this script.
sample_idx = 0

def text2list(text, sample_idx):
    """Convert a "Speaker: message" transcript into sender/receiver/message rows.

    The transcript is split on newlines and empty lines are dropped. Each
    remaining line is split at its FIRST ``": "``: the part before it is the
    speaker, everything after it is the message (later ``": "`` occurrences
    inside the message are preserved verbatim). A line without ``": "`` is
    treated as a speaker with an empty message.

    Args:
        text: Transcript with one utterance per line.
        sample_idx: Value appended (after an underscore) to every lower-cased
            speaker name so names are unique per sample.

    Returns:
        A list of ``[sender, receiver, message]`` lists, one per non-empty
        line, in transcript order. The receiver is another speaker of the
        same transcript; with exactly two speakers this is simply the other
        one, with more than two it is the alphabetically smallest other
        speaker so the result is deterministic.

    Raises:
        ValueError: If the transcript has utterances from only one speaker,
            so no receiver can be determined.
    """
    utterances = []
    for raw_line in text.split("\n"):
        if not raw_line:
            continue
        # partition() splits only at the first separator, so the message keeps
        # any further ": " it contains.
        speaker, _, message = raw_line.partition(": ")
        utterances.append((speaker, message))

    speakers = {speaker for speaker, _ in utterances}
    suffix = "_" + str(sample_idx)

    rows = []
    for speaker, message in utterances:
        others = speakers - {speaker}
        if not others:
            raise ValueError(
                "cannot determine a receiver: {!r} is the only speaker "
                "in the conversation".format(speaker)
            )
        # min() gives a stable choice regardless of set iteration order.
        receiver = min(others)
        rows.append([speaker.lower() + suffix, receiver.lower() + suffix, message])
    return rows

# Convert every record of ./dataset_all.jsonl: parse its three conversations
# into [sender, receiver, message] rows, store the rows' string form under
# 'db_format' in ./dataset_all_dbformat.jsonl, and collect all rows of all
# records into ./dataset_all_dbformat.csv.
data_list = []
# The input is opened first so a missing input file does not truncate the
# output; both files are closed by the context managers even if a record fails.
with open("./dataset_all.jsonl", "r", encoding="utf-8") as f, \
        jsonlines.open("./dataset_all_dbformat.jsonl", "w") as fw:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing in json.loads.
            continue
        data = json.loads(line)
        if "needle_detail" not in data:
            data['needle_detail'] = data["task_prompt"] + ": " + data["answer"]

        # Parse each conversation once; the rows feed both the per-sample
        # 'db_format' field and the global CSV.
        sample_data_list = []
        for conversation_key in (
            'modified_alice_bob_conversation',
            'modified_charlie_dave_conversation',
            'chat_bob_charlie',
        ):
            sample_data_list += text2list(data[conversation_key], sample_idx)
        data_list += sample_data_list

        # Every sample must involve exactly these four participants.
        set_groundtruth = {
            '{}_{}'.format(name, sample_idx)
            for name in ('alice', 'bob', 'charlie', 'dave')
        }
        participants = {name for row in sample_data_list for name in row[:2]}
        if participants != set_groundtruth:
            # Explicit raise (not assert) so the check survives `python -O`
            # and the error actually carries the offending names.
            raise ValueError(
                "sample {}: unexpected participants {} (expected {})".format(
                    sample_idx, sorted(participants), sorted(set_groundtruth)
                )
            )

        sample_idx += 1
        # The rows are plain lists of str, so their str() form matches what
        # DataFrame(...).values.tolist() would produce.
        data['db_format'] = str(sample_data_list)
        fw.write(data)

df = pd.DataFrame(data_list, columns=['sender','receiver','message'])
df.to_csv("./dataset_all_dbformat.csv")
