import json
import re

import pandas as pd
import yaml

# Run settings (model1, model2, persona, max_turn are read further down).
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# locale default encoding; safe_load avoids constructing arbitrary objects.
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

from tqdm import tqdm
# Register tqdm's pandas integration so that `progress_apply` (used for the
# per-row conversation run near the end of the script) shows a progress bar
# labelled 'Progress'.
tqdm.pandas(desc='Progress')

import openai_api
import anthropic_api
from persona import persona_inter

# Load the question bank, keep only rows whose question_type is 'lm', project
# to the question/options/answer columns and drop repeated questions.
data = (
    pd.read_csv('../data/event_blocks.csv')
    .loc[
        lambda frame: frame['question_type'] == 'lm',
        ['question_text', 'choice_a', 'choice_b', 'choice_c', 'choice_d', 'answer_text'],
    ]
    .drop_duplicates()
)
# For a quick smoke run, limit the rows here, e.g. with `.iloc[:2]`.

# System prompts given to model1: 'normal' on intermediate turns, 'last_turn'
# on the final allowed turn to force a committed answer.
sys_message = dict(
    normal=(
        'You are trying to choose the correct answer to the given question. '
        'Please ask sub-questions to approach answers. '
        'In each turn, please only ask one sub-question to interact with an assistant. '
        'In the sub-questions, please include all necessary information, such as the question and options, in the original question. '
        'If you know the answer, please output "So, the answer is: A, B, C, or D."'
    ),
    last_turn='Please choose the correct answer to the given question. Please output "So, the answer is: A, B, C, or D."',
)

# A non-default persona description is prepended to the intermediate-turn
# prompt only; the last-turn prompt is left untouched.
if config['persona'] != 'general':
    sys_message['normal'] = persona_inter[config['persona']] + ' ' + sys_message['normal']

def _format_question(line):
    """Render one question row as the multiple-choice prompt shown to model1.

    Reads the ``question_text`` and ``choice_a``..``choice_d`` fields of *line*.
    """
    return 'Question:\n{q}\nA.{a}\nB.{b}\nC.{c}\nD.{d}'.format(
        q=line['question_text'],
        a=line['choice_a'],
        b=line['choice_b'],
        c=line['choice_c'],
        d=line['choice_d'],
    )


def _build_completion_prompt(messages):
    """Flatten the chat history into one prompt for completion-style models.

    The system text and the question are joined first. Every later message is
    rendered alternately as ``You:`` (model1's own earlier sub-questions) and
    ``Assistant:`` (model2's replies). The prompt ends with an open ``You: ``
    turn for the model to continue.
    """
    prompt = '\n\n'.join([messages[0]['content'], messages[1]['content']])
    conversation = ''
    for idx, message in enumerate(messages[2:]):
        prefix = 'You:' if idx % 2 == 0 else 'Assistant:'
        rendered = ' '.join([prefix, message['content'].strip()]) + '\n'
        conversation = '\n'.join([conversation, rendered])
    conversation += 'You: '
    return '\n\n'.join([prompt, conversation])


def _call_model1(messages):
    """Ask model1 (the simulated user) for its next turn; return stripped text.

    The backend is chosen from the ``config['model1']`` name prefix.
    """
    model = config['model1']
    if model.startswith('claude'):
        # Stripped like the OpenAI branches so trailing whitespace cannot defeat
        # the answer / trailing-'?' checks.
        # NOTE(review): assumes anthropic_api.call returns a str — confirm.
        return anthropic_api.call(messages).strip()
    if model.startswith('gpt-'):
        return openai_api.call_chat(messages, model).strip()
    return openai_api.call_completion(_build_completion_prompt(messages), model).strip()


def _parse_answer(prediction):
    """Return the answer letter (A-D) stated in *prediction*, or '' if none.

    Forms recognised, checked in this order:
      * the whole reply is a single option letter (any case; returned upper-case);
      * "... the answer is: X" / "... The answer is: X" (the format the system
        prompts ask for);
      * a reply starting with "The correct answer is X".
    A looser "X." pattern is deliberately not used: the rendered question
    labels its options as "A.", "B.", ..., so an echoed option list would
    presumably be mistaken for an answer.
    """
    letter = prediction.upper()
    if letter in ('A', 'B', 'C', 'D'):
        return letter
    for pattern in (r'[Tt]he answer is: ([ABCD])', r'^The correct answer is ([ABCD])'):
        matches = re.search(pattern, prediction)
        if matches is not None:
            return matches.group(1)
    return ''


def extract_line(line):
    """Run the simulated multi-turn conversation for one question row.

    model1 plays the user. It either answers the question or asks a
    sub-question. Each sub-question is answered by model2 and fed back to
    model1. The loop runs for at most ``config['max_turn']`` model1 turns. On
    the final turn the system prompt is switched to ``sys_message['last_turn']``
    to push for an answer. Both models' outputs are printed as they arrive.

    Args:
        line: one row of ``data`` as passed by ``progress_apply(axis=1)``. It
            needs ``question_text`` and ``choice_a``..``choice_d``.

    Returns:
        The same row, with these entries added:
          * ``choice``: the last text produced — model1's reply if it answered,
            otherwise model2's final response;
          * ``user_queries``: list of model1's sub-questions;
          * ``lm_responses``: list of model2's replies;
          * ``user_answer``: the answer letter, or '' if none was given;
          * ``question_type``: the constant ``'lm'``.
    """
    messages = [{'role': 'system', 'content': sys_message['normal']}]
    user_queries = []
    lm_responses = []
    user_answer = ''
    # Initialised so 'choice' is well-defined even when max_turn is 0 and the
    # loop body never runs (previously an UnboundLocalError).
    prediction = ''

    question_text = _format_question(line)
    messages.append({'role': 'user', 'content': question_text})
    print(question_text)

    max_turns = config['max_turn']
    for turn in range(max_turns):
        # mimic human; on the final turn, switch to the "answer now" prompt.
        # messages[0] is created fresh per call, so this does not leak.
        if turn == max_turns - 1:
            messages[0]['content'] = sys_message['last_turn']
        prediction = _call_model1(messages)
        print(config['model1'], prediction)

        answer = _parse_answer(prediction)
        if answer:
            messages.append({'role': 'assistant', 'content': answer})
            user_answer = answer
            break

        # Otherwise treat the reply as a sub-question. endswith() also copes
        # with an empty reply, where prediction[-1] raised IndexError.
        if not prediction.endswith('?'):
            prediction += '?'
        user_queries.append(prediction)
        messages.append({'role': 'assistant', 'content': prediction})

        # model2 answers the sub-question alone, without the conversation history.
        prediction = openai_api.call_completion(prediction, config['model2']).strip()
        lm_responses.append(prediction)
        print(config['model2'], prediction)
        messages.append({'role': 'user', 'content': prediction})

    line['choice'] = prediction
    line['user_queries'] = user_queries
    line['lm_responses'] = lm_responses
    line['user_answer'] = user_answer
    line['question_type'] = 'lm'
    return line

# Run the simulated conversation for every question row (with a tqdm bar).
predictions = data.progress_apply(extract_line, axis=1)
# The row index is exported as the worker identifier.
predictions['worker_id'] = predictions.index

# Output path: conversation_<model1>_<model2>[_<persona>]_prompt-1.csv; the
# persona suffix is omitted for the default 'general' persona.
file_name = '../results/conversation_{}_{}{}_prompt-1.csv'.format(
    config['model1'],
    config['model2'],
    '' if config['persona'] == 'general' else '_' + config['persona'],
)
predictions.to_csv(file_name)

