import ast
import json

import pandas as pd
import yaml

# Run settings. The rest of this script reads the keys 'model1', 'model2',
# 'conv_persona', 'persona' and 'eval_model' from this mapping.
with open('config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

from tqdm import tqdm
# Register tqdm's `progress_apply` on pandas objects (used below when the
# per-worker groups are evaluated); the bar is labelled 'Progress'.
tqdm.pandas(desc='Progress')

import openai_api
import anthropic_api
from persona import persona_eval, persona_eval_sys

# System-role message sent as the first chat message to the evaluator model.
sys_message = "You are a helpful and precise assistant for checking the quality of the AI assistant's responses in conversations."

# Default evaluation instructions, appended after the conversations in the
# user message. One instruction per line, joined with newlines.
sys_prompt = '\n'.join([
    'Please evaluate the above conversations between user and AI assistant by using the following metrics:',
    'Fluency (5-point Likert): How clear (or fluent) were the responses from the AI Assistant?',
    'Helpfulness (5-point Likert): Independent of its fluency, how helpful was having access to the AI Assistant compared to not having access?',
    'Ease of interaction (5-point Likert): How easy was it to interact with the AI Assistant?',
    'Helpfulness (free-form): Why did you find the AI Assistant helpful or unhelpful?',
    'Please output each of the above metrics line-by-line.',
])
#if 'general' != config['persona']:
#    sys_prompt = '\n'.join([persona_eval[config['persona']], sys_prompt])

#data = pd.read_csv('../data/event_blocks.csv')
#workers = pd.read_csv('../results/accuracy_by_id.csv')
# Path of the conversation log for the configured model pair. The persona
# suffix is left out when the conversation persona is 'general'.
file_name = '../results/conversation_{model1}_{model2}{persona}_prompt-1.csv'.format(
    model1=config['model1'],
    model2=config['model2'],
    persona='' if config['conv_persona'] == 'general' else '_' + config['conv_persona'],
)
data = pd.read_csv(file_name)

# only 'lm' has interaction with language models
lm_rows = data[data['question_type'] == 'lm']
per_worker_frames = [frame for _, frame in lm_rows.groupby('worker_id')]
# Keep the first 60 workers (in groupby order), then regroup so that each
# worker's rows are evaluated together.
data = pd.concat(per_worker_frames[:60]).groupby('worker_id')

def extract_line(line):
    """Render one conversation row as the text block shown to the evaluator.

    The block contains, in order: the question with its four choices, the
    true answer, every user/assistant turn, and the text of the choice the
    user picked.

    Args:
        line: A row (``pd.Series``) with the fields ``question_text``,
            ``choice_a`` .. ``choice_d``, ``answer_text``, ``user_queries``,
            ``lm_responses`` and ``user_answer``. ``user_queries`` and
            ``lm_responses`` hold the string form of a Python list literal.

    Returns:
        The same row with an added ``model_message`` field.

    Raises:
        ValueError, SyntaxError: If ``user_queries`` or ``lm_responses`` is
            not a valid Python literal.
    """
    question_text = 'Question:\n{q}\nA.{a}\nB.{b}\nC.{c}\nD.{d}'.format(
        q=line['question_text'],
        a=line['choice_a'],
        b=line['choice_b'],
        c=line['choice_c'],
        d=line['choice_d'],
    )
    answer_golden = 'True Answer: {ans}'.format(ans=line['answer_text'])

    # The turn lists come from a CSV file, so parse them as literals only;
    # eval() would execute whatever expression the cell happens to contain.
    user_queries = ast.literal_eval(line['user_queries'])
    lm_responses = ast.literal_eval(line['lm_responses'])
    turns = [
        'User: {up}\nAI Assistant: {ar}'.format(up=query, ar=response)
        for query, response in zip(user_queries, lm_responses)
    ]
    conversation = '\n'.join(['Conversation:'] + turns)

    options = {
        'a': line['choice_a'],
        'b': line['choice_b'],
        'c': line['choice_c'],
        'd': line['choice_d'],
    }
    # Missing (NaN), non-string or out-of-range answers are shown as empty.
    raw_answer = line['user_answer']
    if isinstance(raw_answer, str) and raw_answer.lower() in options:
        user_answer = options[raw_answer.lower()]
    else:
        user_answer = ''
    answer_user = 'User Answer: {ans}'.format(ans=user_answer)

    line['model_message'] = '\n'.join(
        [question_text, answer_golden, conversation, answer_user]
    )
    return line

def extract(group):
    """Ask the configured evaluator model to rate one group of conversations.

    Every row of the group is rendered with ``extract_line``; the rendered
    blocks are joined, followed by the evaluation instructions, and sent to
    the model selected by ``config['eval_model']``.

    Args:
        group: The rows of one ``worker_id`` group.

    Returns:
        A ``pd.Series`` with ``worker_id``, ``question`` (question text of
        the group's first row), ``prediction`` (the model's reply) and
        ``no_of_turns`` (number of user queries in the group's first row).
    """
    messages = [{'role': 'system', 'content': sys_message}]

    # Choose the instructions locally instead of rebinding the module-level
    # `sys_prompt`, so a call never changes global state.
    if 'general' != config['persona']:
        eval_prompt = persona_eval_sys[config['persona']]
    else:
        eval_prompt = sys_prompt

    group = pd.DataFrame(group)
    group = group.apply(extract_line, axis=1)
    prompt = '\n\n'.join(group['model_message'].tolist())
    prompt = '\n\n'.join([prompt, eval_prompt])
    messages.append({'role': 'user', 'content': prompt})

    eval_model = config['eval_model']
    if eval_model.startswith('claude'):
        prediction = anthropic_api.call(messages)
    elif eval_model.startswith('gpt-'):
        prediction = openai_api.call_chat(messages, eval_model)
    else:
        # Any other model goes through the completion call, which takes one
        # string: system message and user message joined by a blank line.
        prompt = '\n\n'.join([messages[0]['content'], messages[1]['content']])
        prediction = openai_api.call_completion(prompt, eval_model)

    first_row = group.iloc[0]
    line = {
        'worker_id': first_row['worker_id'],
        'question': first_row['question_text'],
        'prediction': prediction,
        # Parsed as a literal: the cell comes from a CSV file and must not
        # be executed as code.
        'no_of_turns': len(ast.literal_eval(first_row['user_queries'])),
    }
    return pd.Series(line)

# Evaluate every worker group, showing a progress bar.
predictions = data.progress_apply(extract)

# Serialise through pandas so numpy scalars become plain JSON values, then
# parse each line back as JSON. The lines are JSON, not Python source:
# eval() raises NameError on null/true/false and leaves JSON-only escapes
# such as '\/' in the text as literal backslashes.
predictions = [
    json.loads(record)
    for record in predictions.to_json(orient='records', lines=True).splitlines()
]

# Output path: the evaluation persona is appended after the evaluator model,
# and the conversation persona is appended to model1 only when it differs
# from the evaluation persona. A 'general' persona adds no suffix.
persona_suffix = '_' + config['persona'] if 'general' != config['persona'] else ''
if config['persona'] == config['conv_persona']:
    conv_persona_suffix = ''
else:
    conv_persona_suffix = '_' + config['conv_persona'] if 'general' != config['conv_persona'] else ''
file_name = '../results/predictions_{model1}{conv_persona}_{model2}_{eval_model}{persona}_prompt-1.json'.format(
    model1=config['model1'],
    model2=config['model2'],
    eval_model=config['eval_model'],
    persona=persona_suffix,
    conv_persona=conv_persona_suffix,
)
with open(file_name, 'w') as f:
    json.dump(predictions, f, indent=2)

