import networkx as nx
import torch
from torch.distributions import Independent, Normal

import causal_nf.utils.io as causal_io
from causal_nf.datasets.ihdp import IHDPDataset
from causal_nf.preparators.tabular_preparator import TabularPreparator
from causal_nf.sem_equations import sem_dict
from causal_nf.utils.io import dict_to_cn
from causal_nf.utils.scalers import StandardTransform


class IHDPPreparator(TabularPreparator):
    """Tabular preparator for the IHDP benchmark (``name="ihdp"``, ``task="modeling"``).

    The causal-graph metadata (adjacency, node count, intervenable indices) is
    taken from the ``"ihdp"`` SEM in ``sem_dict``; the data splits themselves
    are loaded through ``IHDPDataset``.
    """

    # Feature width hard-coded by ``intervene`` and ``log_prob``.
    # NOTE(review): presumably equal to ``self.num_nodes`` -- confirm before
    # replacing one with the other.
    _NUM_FEATURES = 27
    # The only column index accepted as the treatment by ``compute_ate`` and
    # ``compute_counterfactual``.
    _TREATMENT_INDEX = 25

    def __init__(self, add_noise, **kwargs):
        """Build the preparator.

        Args:
            add_noise: Forwarded to every loaded ``IHDPDataset`` via
                ``set_add_noise`` in ``_split_dataset``.
            **kwargs: Forwarded unchanged to ``TabularPreparator.__init__``.
        """
        self.dataset = None
        self.add_noise = add_noise
        sem_fn = sem_dict["ihdp"](sem_name="dummy")

        # Stored as a bound callable; ``diameter``/``longest_path_length``
        # call it with a boolean flag.
        self.adjacency = sem_fn.adjacency
        self.num_nodes = len(sem_fn.functions)
        self.intervention_index_list = sem_fn.intervention_index_list()

        # All attributes above are assigned before the parent constructor,
        # presumably because it may invoke hooks (e.g. ``_split_dataset``)
        # that read them -- TODO confirm against TabularPreparator.
        super().__init__(name="ihdp", task="modeling", **kwargs)

    @classmethod
    def params(cls, dataset):
        """Collect constructor kwargs from a dataset config.

        Args:
            dataset: A config object exposing ``add_noise``, or a plain dict
                (converted with ``dict_to_cn`` first).

        Returns:
            dict: ``{"add_noise": ...}`` merged with
            ``TabularPreparator.params(dataset)``.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        my_dict = {"add_noise": dataset.add_noise}
        my_dict.update(TabularPreparator.params(dataset))
        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Instantiate the preparator (or a subclass) from a dataset config."""
        # Use ``cls`` (not the hard-coded class name) so subclasses that
        # override ``params`` are honored.
        my_dict = cls.params(dataset)
        return cls(**my_dict)

    def _x_dim(self):
        """Return the number of SEM nodes (``len(sem_fn.functions)``)."""
        return self.num_nodes

    def get_intervention_list(self):
        """List the interventions ``t=1`` and ``t=0`` for every intervenable index.

        Returns:
            list[dict]: One dict per (index, value) pair with keys
            ``"name"``, ``"value"`` and ``"index"``, in that order.
        """
        int_list = []
        for i in self.intervention_index_list:
            for name, value in (("t=1", 1.0), ("t=0", 0.0)):
                int_list.append({"name": name, "value": round(value, 2), "index": i})
        return int_list

    def diameter(self):
        """Return the diameter of the undirected version of the causal graph.

        NetworkX raises if the undirected graph is not connected.
        """
        # NOTE(review): the meaning of the boolean passed to ``adjacency`` is
        # defined by the SEM class -- confirm.
        adjacency = self.adjacency(True).numpy()
        # ``nx.from_numpy_matrix`` was removed in NetworkX 3.0;
        # ``from_numpy_array`` is the equivalent replacement.
        G = nx.from_numpy_array(adjacency, create_using=nx.Graph)
        return nx.diameter(G)

    def longest_path_length(self):
        """Return the longest directed path length of the causal DAG as an int.

        Edge weights are the adjacency-matrix entries (NetworkX default
        ``weight="weight"``).
        """
        adjacency = self.adjacency(False).numpy()
        G = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
        longest_path_length = nx.algorithms.dag.dag_longest_path_length(G)
        return int(longest_path_length)

    def get_ate_list(self):
        """List the ATE contrasts (a=1 vs. b=0) for every intervenable index.

        Returns:
            list[dict]: One dict per index with keys ``"name"``, ``"a"``,
            ``"b"`` and ``"index"``, in that order.
        """
        return [
            {"name": "1_0", "a": round(1.0, 2), "b": round(0.0, 2), "index": i}
            for i in self.intervention_index_list
        ]

    def intervene(self, index, value, shape):
        """Return random standard-normal samples with column ``index`` set to ``value``.

        Args:
            index: Column to overwrite.
            value: Value written into that column.
            shape: Output shape. A 1-element shape ``(n,)`` is expanded to
                ``(n, 27)``.

        Returns:
            torch.Tensor: Placeholder samples. The other columns are pure
            ``torch.randn`` noise, not draws from the interventional
            distribution.
        """
        if len(shape) == 1:
            shape = (shape[0], self._NUM_FEATURES)
        x_int = torch.randn(shape)
        x_int[:, index] = value
        return x_int

    def compute_ate(self, index, a, b, num_samples=1000):
        """Return the ground-truth ATE averaged over all loaded splits.

        The ATE is computed as ``mean(mu_1 - mu_0)`` over the concatenated
        ``mu_1``/``mu_0`` tensors of every dataset in ``self.datasets``.
        ``num_samples`` is accepted for interface compatibility but unused.

        Raises:
            AssertionError: if ``index`` is not the treatment index, or the
                contrast is not ``a=1.0`` vs. ``b=0.0``.
        """
        assert index == self._TREATMENT_INDEX, f"Index is not the treatment index: {index}"
        assert a == 1.0, f"a is not 1.0: {a}"
        assert b == 0.0, f"b is not 0.0: {b}"

        # ``self.datasets`` is presumably populated by the parent class from
        # ``_split_dataset`` -- TODO confirm.
        mu1 = torch.cat([dataset.mu_1 for dataset in self.datasets], dim=0)
        mu0 = torch.cat([dataset.mu_0 for dataset in self.datasets], dim=0)

        ate = mu1 - mu0
        return ate.mean(0)

    def compute_counterfactual(self, x_factual, index, value):
        """Placeholder counterfactual: random normals with the treatment column set.

        A warning is printed on every call because true counterfactuals are
        not implemented for IHDP.

        Raises:
            AssertionError: if ``index`` is not the treatment index or
                ``value`` is not 0.0/1.0.
        """
        assert index == self._TREATMENT_INDEX, f"Index is not the treatment index: {index}"
        assert value in [0.0, 1.0], f"Value is not 0.0 or 1.0: {value}"

        x_cf = torch.randn_like(x_factual)
        x_cf[:, index] = value

        causal_io.print_warning("This is not implemented")

        return x_cf

    def log_prob(self, x):
        """Log-density of ``x`` under an independent 27-dim standard normal.

        The last dimension of ``x`` is treated as the event dimension (size 27).
        The result has one value per leading (batch) element.
        """
        px = Independent(
            Normal(
                torch.zeros(self._NUM_FEATURES),
                torch.ones(self._NUM_FEATURES),
            ),
            1,
        )
        return px.log_prob(x)

    def _loss(self, loss):
        """Map a loss name to the training objective; only forward KL is supported."""
        if loss in ["default", "forward"]:
            return "forward"
        raise NotImplementedError(f"Wrong loss {loss}")

    def _split_dataset(self, dataset_raw):
        """Load one ``IHDPDataset`` per configured split.

        ``dataset_raw`` is ignored; data is read from ``self.root``. The first
        split's dataset is also kept as ``self.dataset`` (used by
        ``post_process`` and ``feature_names``).

        Returns:
            list: One prepared ``IHDPDataset`` per entry of ``self.split``.
        """
        datasets = []

        # Only the position in ``self.split`` is used; the split value itself
        # is not read here.
        for i, _split_s in enumerate(self.split):
            dataset = IHDPDataset(
                root_dir=self.root, split=self.split_names[i], seed=self.k_fold
            )
            dataset.prepare_data()
            dataset.set_add_noise(self.add_noise)
            if i == 0:
                self.dataset = dataset
            datasets.append(dataset)

        return datasets

    def _get_dataset(self, num_samples, split_name):
        """Not supported: IHDP data is loaded from disk, not sampled."""
        raise NotImplementedError

    def get_scaler(self, fit=True):
        """Build the feature scaler and, for std scaling, a ``StandardTransform``.

        Args:
            fit: If True, fit the scaler on the training features (over
                ``dims_scaler``). When ``self.scale`` is ``"default"``/``"std"``,
                ``self.scaler_transform`` is also set from the per-column
                mean/std. Otherwise ``self.scaler_transform`` stays None.

        Returns:
            The scaler, also stored on ``self.scaler``.
        """
        scaler = self._get_scaler()
        self.scaler_transform = None
        if fit:
            x = self.get_features_train()
            scaler.fit(x, dims=self.dims_scaler)
            if self.scale in ["default", "std"]:
                self.scaler_transform = StandardTransform(
                    shift=x.mean(0), scale=x.std(0)
                )
                print("scaler_transform", self.scaler_transform)

        self.scaler = scaler
        return scaler

    def get_scaler_info(self):
        """Describe the scaler type; only standard scaling is supported."""
        if self.scale in ["default", "std"]:
            return [("std", None)]
        raise NotImplementedError

    @property
    def dims_scaler(self):
        """Dimensions over which the scaler is fit (the sample axis)."""
        return (0,)

    def _get_dataset_raw(self):
        """No raw dataset: splits are loaded directly in ``_split_dataset``."""
        return None

    def _transform_dataset_pre_split(self, dataset_raw):
        """Identity: no pre-split transform is applied."""
        return dataset_raw

    def post_process(self, x, inplace=False):
        """Snap the binary columns of ``x`` back to their valid discrete range.

        The columns listed in ``self.dataset.binary_dims`` are floored and then
        clamped to ``[binary_min_values, binary_max_values]``.

        Args:
            x: Tensor whose last dimension indexes features.
            inplace: If False (default), a clone of ``x`` is modified instead
                of ``x``.

        Returns:
            The post-processed tensor.
        """
        if not inplace:
            x = x.clone()
        dims = self.dataset.binary_dims
        min_values = self.dataset.binary_min_values
        max_values = self.dataset.binary_max_values

        x[..., dims] = x[..., dims].floor().float()
        x[..., dims] = torch.clamp(x[..., dims], min=min_values, max=max_values)

        return x

    def feature_names(self, latex=False):
        """Return the dataset's column names; ``latex`` is accepted but ignored."""
        return self.dataset.column_names

    def _plot_data(
        self,
        batch=None,
        title_elem_idx=None,
        batch_size=None,
        df=None,
        hue=None,
        **kwargs,
    ):
        """Delegate plotting to the parent with a small-caps dataset-name title.

        NOTE(review): extra ``**kwargs`` are accepted but NOT forwarded to the
        parent -- confirm this is intentional.
        """
        title = f"\\textsc{{{self.name.upper()}}}"

        return super()._plot_data(
            batch=batch,
            title_elem_idx=title_elem_idx,
            batch_size=batch_size,
            df=df,
            title=title,
            hue=hue,
        )
