import logging
import os
import yaml
import json

from avatarchat.sql import *
from avatarchat.mode import Mode
from avatarchat.util import AvatarLogger

project_path = os.path.dirname(__file__)
# Shared configuration, loaded once at import time. The ``with`` block closes
# the file handle deterministically instead of leaking it until GC.
with open(os.path.join(project_path, "config/global.yaml"), "r", encoding="utf-8") as _global_config_file:
    global_config = yaml.safe_load(_global_config_file)

def evaluate(sender, receiver, task_prompt, sample_id, is_load_conclusion):
    """Run one offline communication episode for a Friends sample.

    For the duration of the call, a logger named ``evaluate_<sample_id>``
    writes to ``<project_path>/exp/Friends/<logname>_<sample_id>_raw.log``.
    ``AvatarLogger`` is pointed at the same sample and logger.

    Args:
        sender: First character of the pair; passed to ``Mode`` as ``sender``.
        receiver: Second character; passed to ``Mode`` as ``receiver``.
        task_prompt: Question text; passed to ``Mode`` as ``task``.
        sample_id: Identifier used in the logger name, the log file name and
            the optional conclusion file name.
        is_load_conclusion: When true, seed the communication history with
            ``communication_history[1:3]`` taken from a previously saved
            ``./exp/Friends/FriendsV3_GPT3.5_Full_<sample_id>_result.json``
            (path relative to the current working directory).

    Returns:
        ``(response, communication)``, where ``response`` is the value
        returned by ``communication.communicate()``.

    The per-sample log handler is always detached and closed, even if the
    episode raises; the exception is then propagated to the caller.
    """
    logging_config = global_config.get('logging')
    log_filename = os.path.join(
        project_path, "exp", "Friends",
        logging_config.get('logname') + "_{}_raw.log".format(sample_id),
    )

    logger = logging.getLogger("evaluate_{}".format(sample_id))
    logger.setLevel(logging_config.get('level'))

    file_handler = logging.FileHandler(log_filename, encoding="utf-8")
    # Year-month-day timestamps (the previous '%Y-%d-%m' swapped day and month).
    formatter = logging.Formatter('[%(asctime)s %(levelname)s]\n%(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    try:
        AvatarLogger.set_evaluate_log_path("Friends", sample_id)
        AvatarLogger.set_logger(logger)

        mode = Mode(
            sender=sender,
            receiver=receiver,
            task=task_prompt,
            global_config=global_config,
            rewrite_prompt=False,
        )
        communication = mode.get_communication(is_offline=True)

        if is_load_conclusion:
            conclusion_path = "./exp/Friends/FriendsV3_GPT3.5_Full_{}_result.json".format(sample_id)
            with open(conclusion_path, "r", encoding="utf-8") as conclusion_file:
                load_json = json.load(conclusion_file)
            communication.set_communication_history(load_json['communication_history'][1:3])

        response = communication.communicate()
    finally:
        # Loggers are cached by name in the logging module, so a handler left
        # attached here would keep the file open and keep receiving records.
        logger.removeHandler(file_handler)
        file_handler.close()
        # NOTE(review): logging.shutdown() flushes and closes every handler the
        # logging module knows about, not only this one — confirm that calling
        # it once per sample is intended.
        logging.shutdown()

    return response, communication


# Dataset statistics, accumulated over every record regardless of the
# subset / resume filters applied below.
count_rationales = 0
count_comb = 0
count_scene = set()  # distinct data['episode'] values
# Resume support: combinations are skipped until continue_sample_id is seen.
continue_flag = False
continue_sample_id = "0_s01_e14_c01_0"
is_load_conclusion = False

# Sample ids (one per line) that should actually be evaluated.
subset = set()
with open("./data/Friends/FriendsV4_subset", "r", encoding="utf-8") as f:
    for line in f:
        subset.add(line.strip())

with open("./data/Friends/FriendsComQACleanV3.jsonl", "r", encoding="utf-8") as f:
    for overall_idx, line in enumerate(f):
        if not line.strip():
            # Tolerate blank lines; overall_idx still advances, so the sample
            # ids of later records are unchanged.
            continue
        data = json.loads(line.strip())

        count_scene.add(data['episode'])
        count_comb += len(data['character_combinations'])
        count_rationales += len(data['rationale_list'])

        print(data['generated_question'], data['question_v2'])
        print(data['answer'])
        print(data['character_combinations'])

        for idx, characters in enumerate(data['character_combinations']):
            print(characters)
            # "<line index>_<episode>_<combination index>"
            sample_id = str(overall_idx) + "_" + data['episode'] + "_" + str(idx)
            print(sample_id)

            # The flag must be set before the subset filter so resuming works
            # even when the resume id itself is not in the subset.
            if sample_id == continue_sample_id:
                continue_flag = True
            if not continue_flag or sample_id not in subset:
                continue

            answer, communication = evaluate(characters[0], characters[1], data['question_v2'], sample_id, is_load_conclusion)
            data['question_sender'] = characters[0]
            data['question_receiver'] = characters[1]
            data['predicted_answer'] = answer
            data['sample_id'] = sample_id
            data['communication_history'] = communication.communication_history

            result_path = os.path.join(
                project_path, "exp", "Friends",
                global_config.get('logging').get('logname') + "_{}_result.json".format(sample_id),
            )
            with open(result_path, "w", encoding="utf-8") as json_file:
                json.dump(data, json_file)

if not continue_flag:
    # Otherwise a mistyped resume id makes the whole run a silent no-op.
    print("Warning: continue_sample_id {!r} was never encountered; no samples were evaluated.".format(continue_sample_id))

print(count_rationales)
print(count_comb)
print(len(count_scene))