import numpy as np
import pytest
import torch as t

from hypo_interp.config import ExperimentConfig
from hypo_interp.tasks import (
    GreaterThanTask,
    InductionTask,
    IoITask,
    TracrProportionTask,
    TracrReverseTask,
)
from hypo_interp.tasks.docstring.task import DocstringTask
from hypo_interp.test_executor import TestExecutor

# Device string handed to every task and ExperimentConfig built in this module.
DEVICE = "cpu"
# Turn off autograd globally at import time; nothing in this module calls
# backward(), so gradient tracking is not needed by the tests themselves.
t.set_grad_enabled(False)


# The Tracr tasks are deliberately left out of this parametrization: they are
# covered by test_minimality_tracr_works, which makes stronger assertions.
@pytest.mark.parametrize(
    "task_name, task_class, zero_ablation, num_examples",
    [
        ("InductionTask", InductionTask, True, 10),
        ("GreaterThanTask", GreaterThanTask, False, 10),
        ("IoITask", IoITask, False, 10),
        ("DocstringTask", DocstringTask, False, 10),
    ],
)
@pytest.mark.slow
def test_minimality_main(task_name, task_class, zero_ablation, num_examples):
    """
    Smoke-test the minimality test on the canonical circuit of each task.

    The test passes when ``TestExecutor.test_minimality`` runs to completion
    and hands back a result object; no statistical property of the result is
    checked here.

    Args:
        task_name: Human-readable label for the parametrized case. It is not
            used inside the test body.
        task_class: Task class to instantiate.
        zero_ablation: Forwarded to the task constructor.
        num_examples: Forwarded to the task constructor.
    """
    # Small settings so the run stays as cheap as possible.
    base_distribution_size_minimality = 10
    num_edge_to_test_minimality = 2
    quantile = 0.6

    task = task_class(
        zero_ablation=zero_ablation, device=DEVICE, num_examples=num_examples
    )

    config = ExperimentConfig(
        device=DEVICE,
        base_distribution_size_minimality=base_distribution_size_minimality,
        num_edge_to_test_minimality=num_edge_to_test_minimality,
    )
    handler = TestExecutor(
        config=config, task=task, candidate_circuit=task.canonical_circuit
    )
    results = handler.test_minimality(quantile=quantile)

    # Reaching this line already proves the run did not raise. Also make sure
    # a result was actually returned instead of asserting a constant.
    assert results is not None


@pytest.mark.parametrize(
    "task_name, task_class",
    [
        ("TracrProportionTask", TracrProportionTask),
        ("TracrReverseTask", TracrReverseTask),
    ],
)
def test_minimality_tracr_works(task_name, task_class):
    """
    Check that the canonical Tracr circuits clearly pass the minimality test.

    Unlike the smoke test on the other tasks, the outcome is asserted here.
    The original author's rationale is that Tracr models are compiled, so
    their structure is specific enough to expect the null hypothesis to be
    rejected decisively, i.e. with a very low p-value.

    The test requires the mean empirical quantile to exceed 0.9 and the mean
    p-value to stay below 0.05.
    """
    min_mean_quantile = 0.9
    max_mean_p_value = 0.05

    tracr_task = task_class(device=DEVICE, zero_ablation=False)

    executor = TestExecutor(
        config=ExperimentConfig(
            device=DEVICE,
            base_distribution_size_minimality=10,
            num_edge_to_test_minimality=2,
        ),
        task=tracr_task,
        candidate_circuit=tracr_task.canonical_circuit,
    )
    outcome = executor.test_minimality(quantile=0.6)

    # Average each reported statistic over all tested edges.
    quantiles = np.array(outcome["empirical-quantile"])
    avg_quantile = quantiles.mean()
    p_values = np.array(outcome["p-value"])
    avg_p_value = p_values.mean()

    assert avg_quantile > min_mean_quantile
    assert avg_p_value < max_mean_p_value
