import json
import random
import uuid
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch as t
from scipy.stats import pearsonr
from tqdm import tqdm

from hypo_interp.config import ExperimentConfig
from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit, CircuitEdge, EdgeType
from hypo_interp.utils.circuit_utils import (
    compute_actual_circuit_size,
    reformat_circuit_to_torch_index,
    sample_circuit_from_circuit,
    sample_inflated_circuit_from_circuit,
)
from hypo_interp.utils.test_utils import (
    free_memory,
    permutation_test,
    tail_test,
    wilcoxon_test,
)
from hypo_interp.utils.utils import load_with_pickle, save_with_pickle

################################################
#  Utility functions
#################################################


def _get_redundant_edge(
    inflated_circuit: Circuit,
    candidate_circuit: Circuit,
    edge_to_type: Dict[CircuitEdge, EdgeType],
) -> Tuple[CircuitEdge, bool]:
    """
    Returns one randomly chosen entry of the inflated circuit that is not
    in the candidate circuit.

    Both circuits are first converted with `reformat_circuit_to_torch_index`
    and stripped of entries whose edge type is `EdgeType.PLACEHOLDER`, so the
    returned entry is in the reformatted representation.

    Args:
        inflated_circuit: circuit expected to contain the candidate circuit.
        candidate_circuit: circuit that must be a subset of the inflated one.
        edge_to_type: maps the edge part of an entry (`entry[0]`) to its type.

    Raises:
        ValueError: if the candidate circuit is not a subset of the inflated
            circuit (after reformatting and placeholder removal).
        IndexError: if the inflated circuit has no entry beyond the candidate
            circuit, i.e. there is nothing to choose from.
    """
    # Make sure that both versions use torch index
    inflated_entries = reformat_circuit_to_torch_index(inflated_circuit)
    candidate_entries = reformat_circuit_to_torch_index(candidate_circuit)

    # Get rid of the placeholders
    inflated_entries = [
        entry
        for entry in inflated_entries
        if edge_to_type[entry[0]] != EdgeType.PLACEHOLDER
    ]
    candidate_entries = [
        entry
        for entry in candidate_entries
        if edge_to_type[entry[0]] != EdgeType.PLACEHOLDER
    ]

    inflated_set = set(inflated_entries)
    candidate_set = set(candidate_entries)

    # Check that the inflated circuit is a superset of the candidate circuit
    if not candidate_set.issubset(inflated_set):
        sections = [
            ("\nCandidate circuit:", candidate_set),
            ("\nInflated circuit:", inflated_set),
            ("\nInflated - candidate:", inflated_set.difference(candidate_set)),
            ("\nCandidate - inflated:", candidate_set.difference(inflated_set)),
        ]
        parts = ["The candidate circuit is not a subset of the inflated circuit"]
        for header, entries in sections:
            parts.append(header)
            parts.extend(f"\n\t{entry}" for entry in entries)
        raise ValueError("".join(parts))

    # Get the redundant edges
    redundant_edges = inflated_set.difference(candidate_set)
    if not redundant_edges:
        # random.choice would raise a bare IndexError here; keep the exception
        # type but say what actually went wrong.
        raise IndexError(
            "The inflated circuit has no redundant (non-placeholder) edge "
            "that is absent from the candidate circuit"
        )
    return random.choice(list(redundant_edges))


def _compute_knock_out_score(
    knocked_out_circuit: Circuit,
    task: MechInterpTask,
    test_executor=None,
    inflated_circuit: Circuit = None,
    per_prompt: bool = True,
    mean: bool = True,
):
    """
    Computes the difference in score between the inflated
    circuit and the knocked out circuit. If the inflated circuit
    is not provided then the test_executor should be provided and
    it is assumed the candidate circuit is the inflated circuit.
    """
    # Make sure that either the inflated circuit or the test executor is provided
    if inflated_circuit is None and test_executor is None:
        raise ValueError(
            "Either the inflated circuit or the test executor must be provided"
        )

    inflated_score = None
    knocked_out_score = None

    if inflated_circuit is None:
        inflated_score, _ = test_executor.compute_candidate_score(
            per_prompt=per_prompt, invert=False
        )
    else:
        task.set_circuit(inflated_circuit, invert=False)
        inflated_score, _ = task.score(per_prompt=per_prompt)

    task.set_circuit(knocked_out_circuit, invert=False)
    knocked_out_score, _ = task.score(per_prompt=per_prompt)

    eval_metric = task.eval_metric(
        original_score=inflated_score, candidate_score=knocked_out_score, use_mean=mean
    )
    return eval_metric


#################################################
# Main class
#################################################
class TestExecutor:
    """
    Runs statistical tests on a candidate circuit for a given task.

    Offers a faithfulness tail test against random circuits, a minimality
    test based on single-edge knock-outs, a two-sample (Wilcoxon)
    faithfulness test and a sufficiency (independence) test. Saved outputs
    go under `<config.save_path>/<task class name>/`.
    """

    def __init__(
        self,
        task: MechInterpTask,
        config: ExperimentConfig,
        candidate_circuit: Circuit,
    ):
        """
        task: the task used to set circuits and score them
        config: the experiment config (its `task_name` is overwritten with
            the class name of `task`)
        candidate_circuit: the circuit under test
        """
        self.task = task
        self.config = config
        self.candidate_circuit = candidate_circuit

        # generate a unique id for this experiment
        self.uid = "executor_" + uuid.uuid4().hex[:8]

        self.config.task_name = self.task.__class__.__name__

        # Make output dir
        self.save_path = Path(self.config.save_path) / self.config.task_name
        self.save_path.mkdir(parents=True, exist_ok=True)

        # Set scores path: when one is configured, precomputed random-circuit
        # scores are loaded from it instead of being recomputed
        self.load_scores = config.scores_path is not None
        self.scores_path = Path(config.scores_path) if self.load_scores else None

        # Per-instance cache for compute_original_score / compute_candidate_score
        self._score_cache: Dict = {}

        # Set seed
        if config.seed is not None:
            t.manual_seed(config.seed)
            random.seed(config.seed)

    def _cached_score(self, kind: str, per_prompt: bool, invert: bool):
        """
        Scores the complete circuit (kind == "original") or the candidate
        circuit (any other kind) once per (kind, per_prompt, invert) and
        returns the cached `(scores, logits)` on later calls.

        The cache lives on the instance rather than in a class-level
        `lru_cache`, so it is released together with the executor.
        """
        # setdefault keeps this usable on instances built without __init__
        cache = self.__dict__.setdefault("_score_cache", {})
        key = (kind, per_prompt, invert)
        if key not in cache:
            if kind == "original":
                circuit = self.task.complete_circuit
            else:
                circuit = self.candidate_circuit
            self.task.set_circuit(circuit, invert=invert)
            scores, logits = self.task.score(per_prompt=per_prompt)
            cache[key] = (scores, logits)
        return cache[key]

    def _save_results(self, output_dir: Path, results: Dict) -> None:
        """
        Writes the experiment config and `results` as JSON into `output_dir`
        (created if missing).
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        self.config.save_as_json(output_dir / "experiment_config.json")
        # Serialize before opening so that a serialization error does not
        # leave a truncated results file behind
        payload = json.dumps(results, indent=4)
        with open(output_dir / "results.json", "w") as f:
            f.write(payload)

    def compute_original_score(self, per_prompt=False, invert=False):
        """
        Returns `(scores, logits)` of the task's complete circuit.

        The result is cached per (per_prompt, invert); a cache hit does not
        touch the circuit currently set on the task.
        """
        return self._cached_score("original", per_prompt, invert)

    def compute_candidate_score(self, per_prompt=False, invert=False):
        """
        Returns `(scores, logits)` of the candidate circuit.

        The result is cached per (per_prompt, invert); the cache is not
        invalidated if `self.candidate_circuit` is reassigned later.
        """
        return self._cached_score("candidate", per_prompt, invert)

    def test_faithfulness(
        self,
        quantile=0.1,
        alpha=0.05,
        per_prompt=True,
        use_mean=False,
    ) -> Tuple[float, float, List[float], Dict]:
        """
        Compares the eval metric of the candidate circuit against the eval
        metrics of `config.num_random_circuits` random circuits with a tail
        test, and saves the config and results as JSON.

        - use_mean=False is a harder test, it takes the element wise
          difference first and then takes the mean
        - use_mean=True is an easier test, it takes the mean first and then
          takes the element wise difference

        Returns:
        - p_val: the p-value of the tail test.
        - real_eval_metric: eval metric of the candidate circuit w.r.t. the
          original circuit, as returned by `task.eval_metric`.
        - simulated_eval_metrics (List[float]): the same metric for each
          random circuit.
        - results (Dict): the summary that was written to `results.json`.

        Raises:
            ValueError: if the number of simulated metrics does not match
                `config.num_random_circuits`, or if their mean is below 1e-2.
        """
        output_dir = (
            self.save_path / "faithfulness" / f"invert_{self.config.invert}" / self.uid
        )
        output_dir.mkdir(parents=True, exist_ok=True)

        original_score, _ = self.compute_original_score(per_prompt=per_prompt)
        candidate_score, _ = self.compute_candidate_score(
            per_prompt=per_prompt, invert=self.config.invert
        )
        real_eval_metric = self.task.eval_metric(
            original_score, candidate_score, use_mean=use_mean
        )

        # Size of the random circuits: a proportion of the complete circuit
        # when configured, otherwise the size of the canonical circuit
        if self.config.random_proportion is not None:
            complete_circuit_size = compute_actual_circuit_size(
                self.task.complete_circuit, use_pos_embed=self.task.use_pos_embed
            )
            circuit_size = complete_circuit_size * self.config.random_proportion
        else:
            circuit_size = compute_actual_circuit_size(
                self.task.canonical_circuit, use_pos_embed=self.task.use_pos_embed
            )

        simulated_eval_metrics: List[float] = []
        for i in tqdm(range(self.config.num_random_circuits)):
            if self.load_scores:
                # NOTE: this unpickles files from config.scores_path; only
                # point it at trusted data
                saved = load_with_pickle(self.scores_path / f"random_{i}.pkl")
                random_score = saved["score"]
            else:
                random_circuit: Circuit = sample_circuit_from_circuit(
                    circuit=self.task.complete_circuit,
                    minimum_number_of_edges=circuit_size,
                    use_pos_embed=self.task.use_pos_embed,
                    seed=i,  # For determinism
                )
                self.task.set_circuit(random_circuit, invert=self.config.invert)
                random_score, logits = self.task.score(per_prompt=per_prompt)

                if self.config.save_scores:
                    scores_path = output_dir / "scores"
                    scores_path.mkdir(parents=True, exist_ok=True)
                    save = {"score": random_score, "random_circuit": random_circuit}
                    save_with_pickle(save, scores_path / f"random_{i}.pkl")

                free_memory(tensors_to_delete=[logits])  # can't delete the random score

            eval_metric = self.task.eval_metric(
                original_score, random_score, use_mean=use_mean
            )
            simulated_eval_metrics.append(eval_metric.item())
            print(f"Eval metric: {eval_metric.item()}")

        if len(simulated_eval_metrics) != self.config.num_random_circuits:
            raise ValueError(
                "Number of random circuits does not match number of simulated eval metrics"
            )

        if np.mean(simulated_eval_metrics) < 1e-2:
            raise ValueError("Simulated eval metrics are all 0")

        p_val, empirical_quantile = tail_test(
            base_distribution=np.array(simulated_eval_metrics),
            target=real_eval_metric.item(),
            quantile=quantile,
            alternative="less",  # because we want to see if the candidate circuit is more faithful than the random circuits, which incurs less loss
            direction="target>base",  # the claim we will make is on the target loss more than the base loss, with probability
            return_empirical_quantile=True,
        )

        if p_val < alpha:
            print("The candidate circuit is faithful")

        results = {
            "test": "faithfulness",
            "p-value": p_val,
            "empirical-quantile": empirical_quantile,
            "test-quantile": quantile,
            "use-mean": use_mean,
            "per-prompt": per_prompt,
        }
        self._save_results(output_dir, results)
        return p_val, real_eval_metric, simulated_eval_metrics, results

    def test_minimality(self, quantile=0.6, per_prompt=True, use_mean=True) -> Dict:
        """
        Tests each (non-placeholder) edge of the candidate circuit by
        comparing the effect of knocking it out against a base distribution
        obtained by knocking out a redundant edge from inflated circuits.

        At most `config.num_edge_to_test_minimality` candidate edges are
        tested (all of them when it is None); placeholder edges are skipped
        and do not count towards that limit.

        Returns:
            A dict with one p-value and one empirical quantile per tested
            edge, plus the test parameters.
        """
        base_distribution_size = self.config.base_distribution_size_minimality
        num_edge_to_test = self.config.num_edge_to_test_minimality
        edge_to_type: Dict = self.task.edge_to_type

        # First we compute the scores for the inflated circuits
        base_knockout_scores = []
        for i in range(base_distribution_size):
            inflated_circuit = sample_inflated_circuit_from_circuit(
                circuit_to_inflate=self.candidate_circuit,
                complete_circuit=self.task.complete_circuit,
                use_pos_embed=self.task.use_pos_embed,
                seed=i,
                inflate_size=1,
            )
            edge_to_rm = _get_redundant_edge(
                inflated_circuit, self.candidate_circuit, edge_to_type
            )
            # NOTE(review): edge_to_rm comes from the torch-index reformatted
            # circuit while inflated_circuit is filtered as sampled; this
            # presumably relies on both representations comparing equal —
            # TODO confirm
            knocked_out_circuit = [
                edge for edge in inflated_circuit if edge != edge_to_rm
            ]
            base_knockout_scores.append(
                _compute_knock_out_score(
                    inflated_circuit=inflated_circuit,
                    knocked_out_circuit=knocked_out_circuit,
                    task=self.task,
                    per_prompt=per_prompt,
                    test_executor=self,
                    mean=use_mean,
                )
            )

        # Now we compute the scores for the edges that we care about
        real_knockout_scores = []
        for edge in tqdm(self.candidate_circuit):
            # Stop once exactly num_edge_to_test edges have been scored
            if (
                num_edge_to_test is not None
                and len(real_knockout_scores) >= num_edge_to_test
            ):
                break
            # The edge is a tuple of (edge, bool) so we need to index into it
            # to get the actual edge
            if (
                edge_to_type[edge[0]] == EdgeType.PLACEHOLDER
            ):  # We don't care about placeholders we will get rid of them anyways
                continue
            knocked_out_circuit = [e for e in self.candidate_circuit if e[0] != edge[0]]
            real_knockout_scores.append(
                _compute_knock_out_score(
                    knocked_out_circuit=knocked_out_circuit,
                    task=self.task,
                    test_executor=self,
                    per_prompt=per_prompt,
                    mean=use_mean,
                )
            )

        # Finally we compute the p-value
        pvals = []
        empirical_quantiles = []
        for real_knockout_score in tqdm(real_knockout_scores):
            p_val, empirical_quantile = tail_test(
                base_distribution=np.array(base_knockout_scores),
                target=real_knockout_score.item(),
                quantile=quantile,
                alternative="greater",
                direction="target>base",
                return_empirical_quantile=True,
            )
            pvals.append(p_val)
            empirical_quantiles.append(empirical_quantile)

        return {
            "p-value": pvals,
            "empirical-quantile": empirical_quantiles,
            "test-quantile": quantile,
            "per-prompt": per_prompt,
            "use-mean": use_mean,
        }

    def test_faithfulness_two_sample(self):
        """
        Runs a Wilcoxon test between the per-prompt scores of the candidate
        circuit and of the original circuit, saves the config and results
        as JSON and returns the p-value.

        Raises:
            ValueError: if the two score sequences differ in length.
        """
        per_prompt = True
        original_score, _ = self.compute_original_score(per_prompt=per_prompt)
        candidate_score, _ = self.compute_candidate_score(per_prompt=per_prompt)
        if len(original_score) != len(candidate_score):
            raise ValueError("original score and candidate score have different length")

        p_val = wilcoxon_test(target=candidate_score, base=original_score)

        results = {
            "p-value": p_val,
            "test": "faithfulness_two_sample",
        }
        self._save_results(self.save_path / "wilcoxon" / self.uid, results)
        return p_val

    def test_sufficiency(
        self, label_score=None, num_permutations=1000
    ) -> Tuple[float, float, List[float]]:
        """
        Takes the complement of the candidate circuit and computes its
        per-prompt score. Then, runs a permutation test against the label
        score; when no label score is given, the per-prompt score of the
        original circuit is used. The Pearson correlation between the two
        scores is saved alongside the test results.

        Returns: p-value, hsic test statistic, simulated statistics
        """
        self.task.set_circuit(self.candidate_circuit, invert=True)
        # to run the sufficiency test, we need to use per_prompt=True
        complement_score, _ = self.task.score(per_prompt=True)

        # in many cases, the label is the original circuit
        if label_score is None:
            label_score, _ = self.compute_original_score(per_prompt=True)

        perm_res = permutation_test(
            complement_score, label_score, num_permutations=num_permutations
        )
        # .cpu() so the conversion also works for scores living on an
        # accelerator; it is a no-op for CPU tensors
        pearson_res = pearsonr(
            complement_score.detach().cpu().numpy(),
            label_score.detach().cpu().numpy(),
        )

        hsic = perm_res["hsic"]
        p_value = perm_res["p_value"]
        simulated_statistics = perm_res["simulated_statistics"]

        results = {
            "p-value": p_value,
            "hsic": hsic,
            "test": "indepedence_test",
            "pearson": pearson_res,
        }
        self._save_results(self.save_path / "independence" / self.uid, results)

        return p_value, hsic, simulated_statistics
