from functools import partial

import torch as t
from acdc.acdc_utils import logit_diff_metric
from acdc.ioi.utils import get_all_ioi_things, get_ioi_true_edges

from hypo_interp.tasks.ioi.ioi_dataset import IOIDataset as original_IOIDataset
from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit

##################
# Helper functions
##################


def get_true_circuit(model) -> Circuit:
    """
    Load the true IOI circuit from acdc.ioi.utils.

    Calls ``get_ioi_true_edges`` with the given model and returns the
    resulting mapping flattened into a list of ``(key, value)`` pairs, in
    the mapping's own iteration order.
    """
    edge_map = get_ioi_true_edges(model=model)
    return [(edge, present) for edge, present in edge_map.items()]


#################
# Main Class
#################


class IoITask(MechInterpTask):
    """
    IOI task from the paper.

    Wraps the model, datasets and validation metric returned by
    ``acdc.ioi.utils.get_all_ioi_things`` in the ``MechInterpTask`` interface.
    """

    def __init__(
        self,
        zero_ablation: bool = False,
        device: str = "cuda",
        num_examples: int = 2,
        metric_name: str = "logit_diff",
    ):
        """
        zero_ablation:
            Forwarded to the parent class constructor.
        device:
            Device forwarded to the parent class and to
            ``get_all_ioi_things``.
        num_examples:
            Number of examples to use in the dataset.
        metric_name:
            Name of the validation metric, forwarded to
            ``get_all_ioi_things``.
        """
        super().__init__(zero_ablation=zero_ablation, device=device)

        # load in a tl_model and grab some data
        all_things = get_all_ioi_things(
            num_examples=num_examples, device=device, metric_name=metric_name
        )

        # Init abstract class attributes
        self._validation_metric = all_things.validation_metric
        # The ablation data is the patch data in IOI, though test_patch_data
        # could be used as well.
        self._ablate_dataset = all_things.validation_patch_data
        self._base_dataset = all_things.validation_data
        self._validation_labels = all_things.validation_labels
        self.num_examples = num_examples

        self._experiment = self._make_experiment(
            base_dataset=self._base_dataset,
            ablate_dataset=self._ablate_dataset,
            model=all_things.tl_model,
            validation_metric=self._validation_metric,
            zero_ablation=self._zero_ablation,
            use_pos_embed=self.use_pos_embed,
        )

        # Other attributes relevant for score
        self._validation_mask = all_things.validation_mask
        self._validate_attributes()
        print("********************** task initialized")

    def score(
        self, use_acdc: bool = True, per_prompt: bool = True
    ) -> "tuple[t.Tensor, t.Tensor]":
        """
        Returns the score of the current circuit, together with the logits.

        use_acdc:
            If True, score with the validation metric obtained from
            ``get_all_ioi_things``. If False, rebuild the labels from a
            freshly constructed ``IOIDataset`` and score with
            ``logit_diff_metric``.
        per_prompt:
            Passed to the metric inverted, as
            ``return_one_element=not per_prompt``.

        Returns the pair ``(scores, logits)``, where ``logits`` is the model
        output on the base dataset.
        """
        logits = self._experiment.model(self._base_dataset, return_type="logits")

        if use_acdc:
            scores = self._validation_metric(
                logits,
                return_one_element=not per_prompt,
            )
            return scores, logits

        ioi_dataset = original_IOIDataset(
            prompt_type="ABBA",
            N=self.num_examples * 2,
            nb_templates=1,
            seed=0,
        )

        # Only the first `num_examples` rows of the 2 * num_examples generated
        # prompts are used as validation labels.
        seq_len = ioi_dataset.toks.shape[1]
        # Place the labels on the same device as the logits rather than a
        # hard-coded "cuda", so that this path also works when the task was
        # created with a non-CUDA device.
        validation_labels = ioi_dataset.toks.long()[
            : self.num_examples, seq_len - 1
        ].to(logits.device)
        validation_wrong_labels = t.as_tensor(
            ioi_dataset.s_tokenIDs[: self.num_examples],
            dtype=t.long,
            device=logits.device,
        )

        scores = logit_diff_metric(
            logits,
            validation_labels,
            validation_wrong_labels,
            return_one_element=not per_prompt,
        )
        return scores, logits

    @property
    def _canonical_circuit(self) -> Circuit:
        """
        The true IOI circuit for the experiment's model, as returned by
        ``get_true_circuit``.
        """
        circuit: Circuit = get_true_circuit(self._experiment.model)
        return circuit
