import torch as t
from acdc.ioi.ioi_dataset import IOIDataset

from hypo_interp.tasks import IoITask

# Device string handed to `IoITask(device=...)` and to the label tensors built
# in the tests below.
DEVICE = "cpu"


def logit_diff_metric(
    logits, correct_labels, wrong_labels, return_one_element: bool = True
) -> t.Tensor:
    """Return the negated logit difference between correct and wrong labels.

    Only the last position along the second axis of ``logits`` is read. For
    each row ``i`` along the first axis, the values at
    ``logits[i, -1, correct_labels[i]]`` and ``logits[i, -1, wrong_labels[i]]``
    are gathered and compared.

    Args:
        logits: Tensor indexed as ``logits[row, position, label]``.
        correct_labels: Index into the last axis of ``logits`` for the correct
            answer of each row (assumed to hold one entry per row).
        wrong_labels: Index into the last axis of ``logits`` for the wrong
            answer of each row (assumed to hold one entry per row).
        return_one_element: When true, return a single value built from the
            means over all rows; when false, return one value per row.

    Returns:
        ``-(mean(correct) - mean(incorrect))`` as a scalar tensor when
        ``return_one_element`` is true, otherwise the flattened per-row
        tensor ``-(correct - incorrect)``.
    """
    # Named `row_indices` rather than `range` so the builtin is not shadowed.
    row_indices = t.arange(len(logits))
    correct_logits = logits[row_indices, -1, correct_labels]
    incorrect_logits = logits[row_indices, -1, wrong_labels]

    # Note: negative sign so we minimize
    # TODO de-duplicate with docstring/utils.py `raw_docstring_metric`
    if return_one_element:
        return -(correct_logits.mean() - incorrect_logits.mean())
    return -(correct_logits - incorrect_logits).view(-1)


def test_score_consistency():
    """Check that ``task.score()`` matches a score recomputed by hand.

    The complete circuit is scored through ``IoITask``. The logits it returns
    are then passed to ``logit_diff_metric`` together with labels rebuilt from
    a freshly constructed ``IOIDataset``, and the two scores must be equal.
    """
    num_examples = 1

    task = IoITask(num_examples=num_examples, device=DEVICE)
    task.set_circuit(task.complete_circuit)
    scores, logits = task.score()

    # hard code score function
    ioi_dataset = IOIDataset(prompt_type="ABBA", N=2, nb_templates=num_examples, seed=0)
    toks = ioi_dataset.toks
    seq_len = toks.shape[1]
    assert seq_len == 16, f"Well, I thought ABBA #1 was 16 not {seq_len} tokens long..."

    n_rows = 2 * num_examples
    # Correct labels: the token in the final column of each selected row.
    labels = toks.long()[:n_rows, seq_len - 1]
    # Wrong labels: the dataset's `s_tokenIDs` entries for the same rows.
    wrong_labels = t.as_tensor(
        ioi_dataset.s_tokenIDs[:n_rows], dtype=t.long, device=DEVICE
    )

    # Only the first `num_examples` rows are used as the validation labels.
    scores_new = logit_diff_metric(
        logits,
        labels[:num_examples],
        wrong_labels[:num_examples],
        return_one_element=True,
    )
    assert scores_new == scores, "the scores should be the same"


def test_per_prompt():
    """Check that ``score(per_prompt=True)`` yields one score per example."""
    expected_count = 2

    task = IoITask(num_examples=expected_count, device=DEVICE)
    task.set_circuit(task.complete_circuit)

    # The second returned value (the logits) is not needed for this check.
    per_prompt_scores, _ = task.score(per_prompt=True)
    assert len(per_prompt_scores) == expected_count, "the per_prompt should be turned on"
