import json
import re

import pandas as pd
import yaml

# Run settings are read once from config.yml in the working directory.
# Keys used by this script: 'model1', 'model2', 'persona', 'max_turn'.
with open('config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

from tqdm import tqdm
# Hook tqdm into pandas; per the tqdm documentation this registers the
# progress_* method variants (e.g. progress_apply) with the description
# 'Progress'.
# NOTE(review): nothing in the visible script calls those methods -- confirm
# whether this registration is still needed.
tqdm.pandas(desc='Progress')

import openai_api
import anthropic_api
# The llama2 / zephyr backends are imported only when the 6-character prefix
# of either configured model name selects them.
# NOTE(review): presumably this avoids requiring their dependencies for other
# runs -- confirm.
if 'llama-' in [config[key][:6] for key in ('model1', 'model2')]:
    import llama2_api
if 'zephyr' in [config[key][:6] for key in ('model1', 'model2')]:
    import zephyr_api
from persona import persona_inter

# Load the sampled HotpotQA examples from disk.
with open('../data/hotpotqa/sampled_150.json', 'r') as f:
    hotpotqa = json.load(f)

# Build one record per example with the keys 'id', 'context', 'question' and
# 'answer'. The 'context' text contains only the paragraphs whose titles are
# named in the example's supporting facts.
data = []
for d in hotpotqa:
    # Each context entry is indexed as [title, sentence strings].
    context = {c[0]: c[1] for c in d['context']}

    # De-duplicate the supporting-fact titles while keeping their order of
    # first appearance. A plain set() would order the titles differently from
    # run to run (string hashing is randomized per process), which made the
    # generated prompt text non-reproducible.
    sfs = list(dict.fromkeys(sf[0] for sf in d['supporting_facts']))

    doc = 'Context:\n' + ''.join(
        'Title: ' + sf + '\n' + ''.join(context[sf]) + '\n\n'
        for sf in sfs
    )

    data.append({
        'id': d['_id'],
        'context': doc,
        'question': 'Question: ' + d['question'].strip(),
        'answer': d['answer'],
    })

# System prompts for the question-asking model (model1):
#   'normal'    - used on every turn except the final one
#   'last_turn' - swapped in on the final turn to force a short answer
sys_message = {}
# Earlier wording of 'normal' that mentioned the given contexts, kept for reference:
#    'normal': 'You are trying to answer the given question with given contexts. Please ask sub-questions to an assistant to approach answers step-by-step. In each turn, please only ask one sub-question to the assistant. In the sub-questions, please include all necessary information, such as the question. After one turn of conversation, if you know the answer, please output a short and concise answer (only few tokens) by following the format "So, the answer is: <answer>."',
sys_message['normal'] = 'You are trying to answer the given question. Please ask sub-questions to approach answers. In each turn, please only ask one sub-question to the assistant. In the sub-questions, please include all necessary information, such as the question. After one turn of conversation, if you know the answer, please output a short and concise answer (only few tokens) by following the format "So, the answer is: <answer>."'
sys_message['last_turn'] = 'Please output a short and concise answer (only few tokens) to the question by following the format "So, the answer is: <answer>.'

# Any persona other than 'general' gets its description prepended to the
# normal prompt, separated by a single space.
if config['persona'] != 'general':
    sys_message['normal'] = persona_inter[config['persona']] + ' ' + sys_message['normal']

def _ask_questioner(messages):
    """Return model1's next utterance (whitespace-stripped) for `messages`.

    `messages` is the chat history as a list of {'role', 'content'} dicts,
    with the system prompt first and the original question second.
    The backend is chosen from the prefix of config['model1'].
    Assumes every backend call returns a string -- TODO confirm for
    anthropic_api.call when used with the full message history.
    """
    model = config['model1']
    if 'claude' == model[:6]:
        # Stripped like every other backend so the trailing-'?' check and the
        # answer regex see the same kind of text regardless of model.
        return anthropic_api.call(messages).strip()
    if 'gpt-' == model[:4]:
        return openai_api.call_chat(messages, model).strip()

    # Any other model is treated as a plain completion model: flatten the
    # chat into one text prompt. From the third message on, turns alternate
    # between model1 ('You:') and model2 ('Assistant:').
    header = '\n\n'.join([messages[0]['content'], messages[1]['content']])
    turns = []
    for idx, message in enumerate(messages[2:]):
        speaker = 'You:' if idx % 2 == 0 else 'Assistant:'
        turns.append('\n' + speaker + ' ' + message['content'].strip() + '\n')
    # The prompt ends with an open 'You: ' for the model to continue.
    conversation = ''.join(turns) + 'You: '
    prompt = '\n\n'.join([header, conversation])
    return openai_api.call_completion(prompt, model).strip()


def _ask_responder(prompt):
    """Return model2's reply (whitespace-stripped) to a single-turn `prompt`.

    The backend is chosen from the prefix of config['model2']; unknown
    prefixes fall back to the OpenAI completion endpoint.
    """
    model = config['model2']
    model2_messages = [{'role': 'user', 'content': prompt}]
    if 'claude' == model[:6]:
        return anthropic_api.call(model2_messages).strip()
    if 'gpt-' == model[:4]:
        return openai_api.call_chat(model2_messages, model).strip()
    if 'llama-' == model[:6]:
        return llama2_api.call(model2_messages).strip()
    if 'zephyr' == model[:6]:
        return zephyr_api.call(model2_messages).strip()
    return openai_api.call_completion(prompt, model).strip()


def extract_line(line):
    """Run one model1/model2 conversation for an example and record it.

    model1 sees only the question and asks one sub-question per turn.
    model2 answers each sub-question with the example's context prepended.
    The loop stops when model1 emits 'the answer is ...' or after
    config['max_turn'] turns; on the final turn model1's system prompt is
    replaced by sys_message['last_turn'].

    Args:
        line: dict with at least the keys 'question' and 'context'.

    Returns:
        The same `line` dict, mutated in place, with these keys added:
        'choice'       - the last model output produced (model1's final answer
                         message if it answered, otherwise model2's last reply;
                         '' if no turn ran)
        'user_queries' - model1's sub-questions, in order
        'lm_responses' - model2's replies, in order
        'user_answer'  - the text after 'the answer is'; if that is empty,
                         the last sub-question; '' if there is none
    """
    messages = [{'role': 'system', 'content': sys_message['normal']}]
    user_queries = []
    lm_responses = []
    user_answer = ''
    # Pre-bound so that 'choice' is defined even when max_turn <= 0.
    prediction = ''

    # The questioner only sees the question; the context goes to model2.
    messages.append({'role': 'user', 'content': line['question']})

    max_turns = config['max_turn']
    for turn in range(max_turns):
        # On the final turn, force model1 to commit to an answer.
        if turn == max_turns - 1:
            messages[0]['content'] = sys_message['last_turn']
        prediction = _ask_questioner(messages)

        # Final answer, e.g. "So, the answer is: ABCD"
        matches = re.search(r'the answer is[:]? (.*)', prediction)
        if matches is not None:
            user_answer = matches.group(1)
            messages.append({'role': 'assistant', 'content': user_answer})
            break

        # Otherwise treat the output as a sub-question. endswith() is used
        # instead of indexing so an empty reply cannot raise IndexError.
        if not prediction.endswith('?'):
            prediction += '?'
        user_queries.append(prediction)
        messages.append({'role': 'assistant', 'content': prediction})

        # model2 answers the sub-question with the supporting context.
        prediction = _ask_responder(line['context'] + 'Question: ' + prediction)
        lm_responses.append(prediction)
        messages.append({'role': 'user', 'content': prediction})

    # With no usable answer, fall back to the last sub-question. The extra
    # guard covers the case where no sub-question was ever recorded.
    if '' == user_answer and user_queries:
        user_answer = user_queries[-1]

    line['choice'] = prediction
    line['user_queries'] = user_queries
    line['lm_responses'] = lm_responses
    line['user_answer'] = user_answer
    return line

# Run the conversation for every example. extract_line returns the dict it
# was given, so each result also carries the example's original fields.
results = []
for example in tqdm(data):
    record = extract_line(example)
    record['worker_id'] = example['id']
    results.append(record)

# The output name encodes both model names, plus the persona unless it is
# the default 'general'.
file_name = '../results/hotpot_conversation_{model1}_{model2}{persona}_prompt-2.json'.format(
    model1=config['model1'],
    model2=config['model2'],
    persona='' if 'general' == config['persona'] else '_' + config['persona'],
)
with open(file_name, 'w') as out_file:
    json.dump(results, out_file, indent=2)

