import json
from dataclasses import dataclass

from tqdm import tqdm

from adapter import GPT3Adapter
from data import CausalDataset, MoralDataset, Example, AbstractDataset, JsonSerializable
from evaluator import AccuracyEvaluatorWithAmbiguity, CorrelationEvaluator, RMSEEvaluator, AuROCEvaluator
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, CausalAbstractJudgmentPrompt, MoralAbstractJudgmentPrompt, \
    JudgmentPrompt, AbstractJudgmentPrompt
from thought_as_text_translator import MoralTranslator, CausalTranslator, Translator


@dataclass
class ExperimentResult(JsonSerializable):
    """Metrics produced by one experiment run for a single engine.

    Instances are built by ``exp2_causal`` / ``exp2_moral`` from the outputs
    of the four evaluators, and ``produce_table2`` reads their ``json``
    attribute (presumably provided by ``JsonSerializable`` -- confirm in
    ``data``) when writing result files.
    """
    # First value returned by AccuracyEvaluatorWithAmbiguity.evaluate.
    acc: float
    # Second value returned by AccuracyEvaluatorWithAmbiguity.evaluate;
    # its two elements are printed next to the accuracy.
    conf_interval: tuple[float, float]
    # Pair returned by CorrelationEvaluator.evaluate; ``p`` is printed as
    # "p=" (presumably the p-value of ``r`` -- confirm in ``evaluator``).
    r: float
    p: float
    # Value returned by RMSEEvaluator.evaluate.
    rmse: float
    # Value returned by AuROCEvaluator.evaluate.
    auroc: float

def run_template_for_gpt3(cd: AbstractDataset, adapter: GPT3Adapter,
                          jp: JudgmentPrompt, ajp: AbstractJudgmentPrompt,
                          translator: Translator,
                          method: str = 'yesno'):
    """Score every example of ``cd`` with ``adapter`` and collect the results.

    Examples without annotated sentences are rendered with ``jp``. All other
    examples are first passed through ``translator.translate_example`` and
    the translated form is rendered with ``ajp``.

    Args:
        cd: Dataset to iterate over; each item is treated as an ``Example``.
        adapter: Object whose ``adapt(instance, method=...)`` produces the
            choice scores for one rendered prompt instance.
        jp: Prompt applied to examples that have no annotated sentences.
        ajp: Prompt applied to the translated form of annotated examples.
        translator: Provides ``translate_example`` for annotated examples.
        method: Forwarded unchanged to ``adapter.adapt``.

    Returns:
        A pair ``(all_choice_scores, all_label_dist)`` of equal-length lists,
        in dataset order: the adapter output and ``ex.answer_dist`` for each
        example.
    """
    all_choice_scores, all_label_dist = [], []
    ex: Example
    for ex in tqdm(cd):
        if len(ex.annotated_sentences) == 0:
            instance = jp.apply(ex)
        else:
            # Deliberately not named ``abs``: that would shadow the builtin.
            abstract_ex = translator.translate_example(ex)
            instance = ajp.apply(abstract_ex)

        choice_scores = adapter.adapt(instance, method=method)
        all_choice_scores.append(choice_scores)
        all_label_dist.append(ex.answer_dist)

    return all_choice_scores, all_label_dist

def exp2_causal(engine: str='text-davinci-002'):
    """Run the causal experiment for ``engine`` and report its metrics.

    The causal dataset is scored twice, once per prompt template: the first
    template with the 'yesno' method and the second with 'multiple_choice'.
    Scores and label distributions from both passes are pooled before the
    accuracy, correlation, RMSE and AuROC evaluators are applied. The metrics
    are printed and returned as an ``ExperimentResult``.
    """
    dataset = CausalDataset()
    accuracy_eval = AccuracyEvaluatorWithAmbiguity()
    correlation_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    gpt3 = GPT3Adapter(engine=engine)

    # (template path, scoring method) for each pass over the dataset.
    passes = (
        ("./prompts/exp1_causal_prompt.jinja", 'yesno'),
        ("./prompts/exp1_causal_prompt_2.jinja", 'multiple_choice'),
    )

    pooled_scores, pooled_labels = [], []
    for template_path, scoring_method in passes:
        # Both prompt flavours are built from the same template file.
        scores, labels = run_template_for_gpt3(
            dataset, gpt3,
            CausalJudgmentPrompt(template_path),
            CausalAbstractJudgmentPrompt(template_path),
            translator=CausalTranslator(),
            method=scoring_method,
        )
        pooled_scores.extend(scores)
        pooled_labels.extend(labels)

    acc, conf_interval = accuracy_eval.evaluate(pooled_scores, pooled_labels)
    r, p = correlation_eval.evaluate(pooled_scores, pooled_labels)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_labels)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_labels)

    print()
    print(f"engine: {engine}")
    print(f"Causal Abstract Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Causal Correlation: {r:.4f} (p={p:.4f})")
    print(f"Causal RMSE: {rmse:.4f}")
    print(f"Causal AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)

def exp2_moral(engine="text-davinci-002"):
    """Run the moral experiment for ``engine`` and report its metrics.

    The moral dataset is scored twice, once per prompt template: the first
    template with the 'yesno' method and the second with 'multiple_choice'.
    Scores and label distributions from both passes are pooled before the
    accuracy, correlation, RMSE and AuROC evaluators are applied. The metrics
    are printed and returned as an ``ExperimentResult``.
    """
    dataset = MoralDataset()

    accuracy_eval = AccuracyEvaluatorWithAmbiguity()
    correlation_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    gpt3 = GPT3Adapter(engine=engine)

    # (template path, scoring method) for each pass over the dataset.
    passes = (
        ("./prompts/exp1_moral_prompt.jinja", 'yesno'),
        ("./prompts/exp1_moral_prompt_2.jinja", 'multiple_choice'),
    )

    pooled_scores, pooled_labels = [], []
    for template_path, scoring_method in passes:
        # Both prompt flavours are built from the same template file.
        scores, labels = run_template_for_gpt3(
            dataset, gpt3,
            MoralJudgmentPrompt(template_path),
            MoralAbstractJudgmentPrompt(template_path),
            translator=MoralTranslator(),
            method=scoring_method,
        )
        pooled_scores.extend(scores)
        pooled_labels.extend(labels)

    acc, conf_interval = accuracy_eval.evaluate(pooled_scores, pooled_labels)
    r, p = correlation_eval.evaluate(pooled_scores, pooled_labels)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_labels)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_labels)

    print()
    print(f"engine: {engine}")
    print(f"Moral Abstract Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Moral Correlation: {r:.4f} (p={p:.4f})")
    print(f"Moral RMSE: {rmse:.4f}")
    print(f"Moral AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)

def produce_table2():
    """Run both experiments on every engine and save the metrics as JSON.

    For the moral experiment and then the causal experiment, each engine is
    evaluated in turn and the ``json`` attribute of its result is stored
    under the engine name. Each experiment's mapping is written to its own
    file under ``../../results/`` (relative to the working directory) as soon
    as that experiment finishes, so the moral file exists before the causal
    runs start.
    """
    engines = ["text-babbage-001", 'text-curie-001', 'text-davinci-002']
    experiments = (
        (exp2_moral, '../../results/exp2_moral_full_result.json'),
        (exp2_causal, '../../results/exp2_causal_full_result.json'),
    )

    for run_experiment, out_path in experiments:
        result = {}
        for engine in engines:
            er = run_experiment(engine=engine)
            result[engine] = er.json

        # Context manager guarantees the handle is flushed and closed, rather
        # than leaving that to garbage collection.
        with open(out_path, 'w') as out_file:
            json.dump(result, out_file, indent=2)


if __name__ == "__main__":
    # Script entry point: run both experiments and write their result files.
    produce_table2()