import logging
import os
import yaml
import json

from avatarchat.sql import *
from avatarchat.mode import Mode
from avatarchat.util import AvatarLogger

# Directory containing this script. The config file path and the log/result
# output paths below are built relative to it.
project_path = os.path.dirname(__file__)

# Global configuration, loaded once at import time. Using a context manager
# guarantees the file handle is closed (the bare open() previously leaked it);
# yaml.safe_load avoids constructing arbitrary Python objects from the file.
with open(os.path.join(project_path, "config/global.yaml"), "r", encoding="utf-8") as _config_file:
    global_config = yaml.safe_load(_config_file)

def evaluate(sender, receiver, task_prompt, sample_id, is_load_conclusion):
    """Run one offline Schedule-task communication between two agents.

    A per-sample logger named ``evaluate_<sample_id>`` is created. It writes to
    ``<project_path>/exp/Schedule/<logname>_<sample_id>_raw.log`` and is
    registered with ``AvatarLogger`` for the duration of the run. The file
    handler is always detached and closed, even if the run raises.

    Args:
        sender: Passed to ``Mode`` as ``sender``.
        receiver: Passed to ``Mode`` as ``receiver``.
        task_prompt: Passed to ``Mode`` as ``task``.
        sample_id: Dataset sample index; used in the logger name and file names.
        is_load_conclusion: If true, seed the communication history with
            entries ``[1:3]`` of ``communication_history`` from a previously
            saved result file for this sample.

    Returns:
        A ``(response, communication)`` tuple: the value returned by
        ``communication.communicate()`` and the communication object itself.

    Raises:
        AssertionError: If a loaded history's first entry does not contain
            both ``A<sample_id>`` and ``B<sample_id>``.
    """
    logging_config = global_config.get('logging')
    log_dir = os.path.join(project_path, "exp", "Schedule")
    # logging.FileHandler does not create missing parent directories.
    os.makedirs(log_dir, exist_ok=True)
    log_filename = os.path.join(log_dir, logging_config.get('logname') + "_{}_raw.log".format(sample_id))

    logger = logging.getLogger("evaluate_{}".format(sample_id))
    logger.setLevel(logging_config.get('level'))

    file_handler = logging.FileHandler(log_filename, encoding="utf-8")
    # Year-month-day order (the previous '%Y-%d-%m' swapped month and day).
    formatter = logging.Formatter('[%(asctime)s %(levelname)s]\n%(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    try:
        AvatarLogger.set_evaluate_log_path("Schedule", sample_id)
        AvatarLogger.set_logger(logger)

        mode = Mode(sender=sender, receiver=receiver, task=task_prompt, global_config=global_config, rewrite_prompt=False)
        communication = mode.get_communication(is_offline=True)

        if is_load_conclusion:
            # NOTE(review): this path is relative to the current working
            # directory, unlike log/result paths which use project_path.
            conclusion_path = "./exp/Schedule/Schedule_Simple_GPT3.5_DoubleMind_{}_result.json".format(sample_id)
            with open(conclusion_path, "r", encoding="utf-8") as conclusion_file:
                load_json = json.load(conclusion_file)
            load_conclusion = load_json['communication_history'][1:3]
            # Sanity check that the saved history belongs to this sample.
            assert "A{}".format(sample_id) in load_conclusion[0]
            assert "B{}".format(sample_id) in load_conclusion[0]
            communication.set_communication_history(load_conclusion)
        response = communication.communicate()
    finally:
        # Always detach and close the per-sample handler so a failing run does
        # not leak an open log file.
        logger.removeHandler(file_handler)
        file_handler.close()
        # NOTE(review): logging.shutdown() flushes and closes every registered
        # handler process-wide, not just this one; kept for parity.
        logging.shutdown()

    return response, communication



# Index of the first sample to evaluate; earlier samples are skipped (resume support).
continue_sample_id = 0
# When True, evaluate() seeds each run with a previously saved communication history.
is_load_conclusion = False

with open("./data/Schedule/dataset_hard.jsonl", "r", encoding="utf-8") as f:
    # sample_id is the 0-based line index in the JSONL file.
    for sample_id, line in enumerate(f):
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of crashing in json.loads.
            continue
        data = json.loads(line)

        print(sample_id)
        print(data['question'], data['answer'])

        # Skipped samples are still printed above, matching prior behaviour.
        if sample_id < continue_sample_id:
            continue

        answer, communication = evaluate(data['QA agents'][0], data['QA agents'][1], data['question'], sample_id, is_load_conclusion)
        data['predicted_answer'] = answer
        data['sample_id'] = sample_id
        data['communication_history'] = communication.communication_history

        result_path = os.path.join(project_path, "exp", "Schedule", global_config.get('logging').get('logname') + "_{}_result.json".format(sample_id))
        with open(result_path, "w", encoding="utf-8") as json_file:
            json.dump(data, json_file)
        