import json
import os
import re

import pandas as pd
import yaml

# Load the run configuration. This file reads the keys 'model1', 'model2',
# 'persona' and 'max_turn' from it. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's locale default.
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

from tqdm import tqdm
# Turn on tqdm's pandas integration, labelled 'Progress'.
tqdm.pandas(desc='Progress')

# API wrappers that are always imported.
import openai_api
import anthropic_api
# The llama2 / zephyr wrappers are imported only when the first six characters
# of either configured model name equal the corresponding prefix. This must
# run after `config` has been loaded above.
if 'llama-' in [config['model1'][:6], config['model2'][:6]]:
    import llama2_api
if 'zephyr' in [config['model1'][:6], config['model2'][:6]]:
    import zephyr_api
# Mapping from persona name to the persona prompt prefix used for model1.
from persona import persona_inter

# Load the sampled questions. JSON text is UTF-8, so the encoding is stated
# explicitly instead of relying on the platform's locale default.
with open('../data/ambig/sampled_150.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# System prompts for model1, the simulated user. Both start with the persona
# text selected by config['persona'].
#   'normal'    - used on every turn except the final one.
#   'last_turn' - swapped in by extract_line on the final allowed turn.
# NOTE: the prompt wording (including its spelling) is part of the experiment
# and is reproduced verbatim.
model1_sys_message = {
    'normal': persona_inter[config['persona']] + (
        ' If you receieve a short response (a few tokens and no multiple'
        ' options inside) that answers the question without ambiguous,'
        ' please output "This is the answer I want." Otherwise, please reply'
        ' to questions from the assistant and clarify your question.'
    ),
    'last_turn': persona_inter[config['persona']] + (
        ' Please ask the last question. Otherwise, please output'
        ' "This is the answer I want."'
    ),
}

# System prompt for model2, the assistant that is expected to ask clarifying
# questions before giving a short final answer.
model2_sys_message = (
    'You are a helpful assistant. Questions may be ambiguous and vague on'
    ' details, such as mentioning a name shared by multiple songs or media'
    ' or including multiple sub-names. If you find these problems, please'
    ' do not answer it and ask a question, including options, to the user'
    ' to clarify them until you think the question is clear and accurate.'
    ' At last, please answer the question in a short format (a few tokens).'
)

def _simulated_user_reply(model1_messages):
    """Return model1's (the simulated user's) next message, stripped.

    Dispatches on the prefix of config['model1']: 'claude' and 'gpt-' models
    get the chat history as-is; any other model gets the history flattened
    into a single completion prompt.
    """
    model_name = config['model1']
    if model_name.startswith('claude'):
        return anthropic_api.call(model1_messages).strip()
    if model_name.startswith('gpt-'):
        return openai_api.call_chat(model1_messages, model_name).strip()

    # Completion-style fallback. The first two messages (system prompt and
    # the "Please input your question:" request) form the header; the rest
    # alternate between the simulated user ('You:') and the assistant.
    header = '\n\n'.join([model1_messages[0]['content'], model1_messages[1]['content']])
    transcript_parts = []
    for idx, message in enumerate(model1_messages[2:]):
        speaker = 'You:' if idx % 2 == 0 else 'Assistant:'
        transcript_parts.append('\n' + ' '.join([speaker, message['content'].strip()]) + '\n')
    # Finish with an open 'You: ' cue so the model continues as the user.
    transcript = ''.join(transcript_parts) + 'You: '
    prompt = '\n\n'.join([header, transcript])
    return openai_api.call_completion(prompt, model_name).strip()


def _assistant_reply(model2_messages, latest_query):
    """Return model2's (the assistant's) response, stripped.

    Chat-capable backends receive the whole model2 history. The completion
    fallback only receives the system prompt plus the latest user query.
    """
    model_name = config['model2']
    if model_name.startswith('claude'):
        return anthropic_api.call(model2_messages).strip()
    if model_name.startswith('gpt-'):
        return openai_api.call_chat(model2_messages, model_name).strip()
    if model_name.startswith('llama-'):
        return llama2_api.call(model2_messages).strip()
    if model_name.startswith('zephyr'):
        return zephyr_api.call(model2_messages).strip()
    prompt = model2_sys_message + 'Question: ' + latest_query
    return openai_api.call_completion(prompt, model_name).strip()


def extract_line(line):
    """Run one simulated clarification dialogue for a dataset entry.

    model1 plays the user and model2 the assistant. Turn 0 sends
    line['question'] to model2; on later turns model1 reacts to model2's
    previous response. The dialogue stops early when model1's message
    contains "This is the answer I want" (case-insensitive), otherwise after
    config['max_turn'] turns.

    Args:
        line: dict with at least the keys 'question' and 'annotations'.
            It is modified in place.

    Returns:
        The same ``line`` dict with these keys added:
            'choice'       - the last text generated: model1's acceptance
                             message if the dialogue ended early, otherwise
                             model2's final response.
            'user_queries' - every message sent to model2, in order.
            'lm_responses' - every model2 response, in order.
            'user_answer'  - the last model2 response.

    Raises:
        ValueError: if config['max_turn'] is less than 1, since no response
            would be produced.
    """
    max_turns = config['max_turn']
    if max_turns < 1:
        # Fail before any API call instead of with an IndexError at the end.
        raise ValueError(
            "config['max_turn'] must be at least 1, got {!r}".format(max_turns)
        )

    model1_messages = [{'role': 'system', 'content': model1_sys_message['normal']}]
    model2_messages = [{'role': 'system', 'content': model2_sys_message}]
    user_queries = []
    lm_responses = []
    print(model1_sys_message['normal'])

    # Compiled once; matched against each simulated-user message.
    acceptance = re.compile(r'This is the answer I want[.]?', re.I)

    for turn in range(max_turns):
        if turn == 0:
            # Seed the dialogue: from model1's point of view it was asked for
            # a question and answered with the dataset question.
            model1_messages.append({'role': 'user', 'content': 'Please input your question:'})
            prediction = line['question']
            model1_messages.append({'role': 'assistant', 'content': prediction})
            print(json.dumps(line['annotations'], indent=2))
            print(config['model1'], prediction)
        else:
            if turn == max_turns - 1:
                # Final allowed turn: switch model1 to the 'last_turn' prompt.
                model1_messages[0]['content'] = model1_sys_message['last_turn']
            prediction = _simulated_user_reply(model1_messages)
            print(config['model1'], prediction)
            model1_messages.append({'role': 'assistant', 'content': prediction})
            if acceptance.search(prediction) is not None:
                # The simulated user accepted the previous response.
                break

        user_queries.append(prediction)
        model2_messages.append({'role': 'user', 'content': prediction})

        prediction = _assistant_reply(model2_messages, prediction)
        lm_responses.append(prediction)
        print(config['model2'], prediction)
        # Each model sees the other's output as a 'user' message.
        model1_messages.append({'role': 'user', 'content': prediction})
        model2_messages.append({'role': 'assistant', 'content': prediction})

    line['choice'] = prediction
    line['user_queries'] = user_queries
    line['lm_responses'] = lm_responses
    # Whether the loop ended by acceptance or by exhausting the turns, the
    # answer is the most recent model2 response. Turn 0 always produces one.
    line['user_answer'] = lm_responses[-1]
    return line

# Resolve the output path and create its directory BEFORE the conversation
# loop, so a missing '../results' folder fails immediately rather than after
# every dialogue has already been generated.
file_name = '../results/ambig_conversation_{model1}_{model2}{persona}_prompt-1.json'.format(
    model1=config['model1'],
    model2=config['model2'],
    # The persona suffix is omitted for the default 'general' persona.
    persona='_' + config['persona'] if 'general' != config['persona'] else '',
)
os.makedirs(os.path.dirname(file_name), exist_ok=True)

# Run one simulated dialogue per dataset entry and tag it with the entry id.
results = []
for d in tqdm(data):
    predictions = extract_line(d)
    predictions['worker_id'] = d['id']
    results.append(predictions)

with open(file_name, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

