# SLOWER than base
import json

import torch
import torch.jit as jit
from project_utils.profiling import pytorch_profile
from torch import nn


def preprocess_links(data):
    """Build CSR-style child offsets from the SPN link list.

    Each link's ``"target"`` is treated as the parent node of that link. The
    offsets are prefix sums of per-target link counts, so link positions
    ``edge_child_start_index[i]`` up to ``edge_child_start_index[i + 1]``
    (exclusive) are the links whose target is node ``i``.

    :param data: Parsed SPN JSON with ``"nodes"`` and ``"links"`` lists; each
        link needs an integer ``"target"`` node index.
    :return: ``(edge_parent, edge_child_start_index)`` where ``edge_parent[j]``
        is the target of link ``j`` (shape ``(num_links,)``) and
        ``edge_child_start_index`` has shape ``(num_nodes + 1,)`` and starts at 0.
    :raises IndexError: If a link targets a node index outside ``[0, num_nodes)``.
    """
    num_nodes = len(data["nodes"])

    # Gather all targets at once instead of writing tensor elements one by one
    # in a Python loop.
    edge_parent = torch.tensor(
        [link["target"] for link in data["links"]], dtype=torch.long
    )
    # Negative targets would otherwise wrap around silently and corrupt the
    # offsets; too-large targets would overflow the offset table.
    if edge_parent.numel() > 0 and (
        edge_parent.min() < 0 or edge_parent.max() >= num_nodes
    ):
        raise IndexError(f"Link target out of range for {num_nodes} nodes")

    # Slot 0 stays 0; slot i + 1 holds the number of links targeting nodes 0..i.
    edge_child_start_index = torch.zeros(num_nodes + 1, dtype=torch.long)
    edge_child_start_index[1:] = torch.cumsum(
        torch.bincount(edge_parent, minlength=num_nodes), 0
    )

    return edge_parent, edge_child_start_index


def get_distributions(data, spn_class):
    """Build one node module per entry of ``data["nodes"]``.

    As a side effect, each node's ``"id"`` is appended to
    ``spn_class.node_types[<node class>]``.

    :param data: Parsed SPN JSON. Every entry of ``data["nodes"]`` needs
        ``"class"`` and ``"id"``; Bernoulli nodes also need ``["params"]["p"]``
        and Sum nodes need ``"weights"``. ``"scope"`` is optional (defaults to
        empty, which a Bernoulli node cannot use).
    :param spn_class: The model under construction; its ``node_types``,
        ``device``, ``eps`` and ``_get_children_indices`` are used.
    :return: Dict mapping each node's position in ``data["nodes"]`` to its
        module, inserted in file order.
    :raises NotImplementedError: If a node's class is not Bernoulli, Product or Sum.
    """
    all_distributions = {}
    # Nodes are built in file order; SPNModel evaluates them in reverse
    # (highest index first).
    for idx, node_dict in enumerate(data["nodes"]):
        node_class = node_dict["class"]
        # Validate before indexing spn_class.node_types; otherwise an unknown
        # class surfaces as a KeyError and this error is never reached.
        if node_class not in ("Bernoulli", "Product", "Sum"):
            raise NotImplementedError(f"Unknown node class {node_class}")
        spn_class.node_types[node_class].append(node_dict["id"])
        scope = torch.tensor(
            [int(s) for s in node_dict.get("scope", [])], dtype=torch.long
        )
        if node_class == "Bernoulli":
            all_distributions[idx] = BernoulliLeaf(
                idx, node_dict["params"]["p"], scope[0], spn_class.device
            )
        elif node_class == "Product":
            children_indices = spn_class._get_children_indices(idx)
            all_distributions[idx] = ProductNode(idx, scope, children_indices)
        else:  # "Sum"
            children_indices = spn_class._get_children_indices(idx)
            all_distributions[idx] = SumNode(
                idx,
                node_dict["weights"],
                scope,
                children_indices,
                spn_class.device,
                spn_class.eps,
            )
    return all_distributions


class SPNModel(nn.Module):
    """Sum-product network loaded from JSON and evaluated bottom-up in log space.

    Nodes are evaluated from the highest index down to index 0, whose value
    :meth:`evaluate` returns as the root. This order presumes every child has a
    larger index than its parent in the JSON file -- TODO confirm with the exporter.
    """

    def __init__(
        self, json_file, num_var, device, percent_nodes_for_features, approx=False
    ):
        """
        Initialize the SPNModel with configuration from a JSON file.

        :param json_file: Path to a JSON file with ``"nodes"`` and ``"links"`` lists.
        :param num_var: Number of variables in the model.
        :param device: Device to perform computations ('cpu' or 'cuda:0').
        :param percent_nodes_for_features: Fraction of nodes used for NN input
            features. Sum/Bernoulli nodes whose index is strictly greater than
            ``num_nodes - int(percent_nodes_for_features * num_nodes)`` are the
            ones returned by :meth:`get_input_features`.
        :param approx: Boolean flag for approximation method usage (defaults to
            False). Only stored here; nothing in this class reads it.
        """
        super(SPNModel, self).__init__()
        with open(json_file) as f:
            data = json.load(f)

        self.num_var = num_var
        self.num_nodes_in_spn: int = len(data["nodes"])
        self.all_nodes = torch.arange(self.num_nodes_in_spn, dtype=torch.long)
        self.approx = approx
        self.eps = 1e-6
        self.min_node_idx_for_features = self.num_nodes_in_spn - int(
            percent_nodes_for_features * self.num_nodes_in_spn
        )
        self.device = device
        # Filled with node ids by get_distributions, converted to tensors below.
        self.node_types = {"Sum": [], "Product": [], "Bernoulli": []}
        # Child offsets must exist before get_distributions runs, because it
        # calls self._get_children_indices for every Sum/Product node.
        edge_parent, edge_child_start_index = preprocess_links(data)
        self.edge_parent = edge_parent.to(device)
        self.edge_child_start_index = edge_child_start_index.to(device)

        distributions = get_distributions(data, self)
        self.all_distributions = nn.ModuleList(list(distributions.values()))

        for node_type in self.node_types:
            self.node_types[node_type] = torch.tensor(
                self.node_types[node_type], dtype=torch.long, device=self.device
            )

        # Sum and Bernoulli node ids, highest index first.
        sum_bernoulli, _ = torch.sort(
            torch.cat((self.node_types["Sum"], self.node_types["Bernoulli"]), dim=-1),
            descending=True,
        )
        # Keep ids strictly greater than min_node_idx_for_features; since the ids
        # are sorted descending, the kept ones form a prefix.
        sum_nodes_for_features = sum_bernoulli[
            sum_bernoulli > self.min_node_idx_for_features
        ]
        self.num_sum_nodes_for_features = len(sum_nodes_for_features)
        if self.num_sum_nodes_for_features > 0:
            # Last entry of the descending prefix is the smallest feature index.
            self.lowest_sum_node_idx_for_features = sum_nodes_for_features[-1].item()
        else:
            # No feature nodes (e.g. percent_nodes_for_features == 0): use a bound
            # above every index so get_input_features evaluates nothing, instead
            # of failing with IndexError on an empty tensor.
            self.lowest_sum_node_idx_for_features = self.num_nodes_in_spn
        self.node_types["SumBernoulli"] = sum_bernoulli[
            : self.num_sum_nodes_for_features
        ]

    def _get_children_indices(self, node_idx: int):
        """
        Return the node ids of ``node_idx``'s children.

        Uses the offsets built by ``preprocess_links``: the children are nodes
        ``edge_child_start_index[node_idx] + 1`` through
        ``edge_child_start_index[node_idx + 1]`` inclusive, i.e. link position
        plus one (assumes child ids are assigned contiguously in link order --
        TODO confirm).
        """
        start_idx = self.edge_child_start_index[node_idx]
        end_idx = self.edge_child_start_index[node_idx + 1]
        return self.all_nodes[start_idx + 1 : end_idx + 1]

    @torch.no_grad()
    def get_input_features(self, x, query_mask, unobs_mask):
        """
        Get the input features of the SPN model.

        :param x: Tensor representing input data (batch_size, input_size).
        :param query_mask: Index/mask into ``x`` marking query variables; they
            are set to -1 in a copy of ``x`` (presumably a boolean tensor shaped
            like ``x`` -- confirm with callers).
        :param unobs_mask: Index/mask into ``x`` marking unobserved variables;
            they are set to NaN in the same copy.
        :return: Tensor of shape (batch_size, num_sum_nodes_for_features) with
            the log-space values of the selected Sum/Bernoulli nodes, columns in
            descending node-index order.
        """
        # One row per node, one column per batch element.
        node_values = torch.zeros(
            (self.num_nodes_in_spn, x.size(0)), device=self.device
        )
        # Query vars become -1 and unobserved vars NaN, the encoding the
        # Bernoulli leaves interpret. Clone so the caller's tensor is untouched.
        x = x.clone()
        x[query_mask] = -1
        x[unobs_mask] = float("nan")

        # Evaluate from the highest index down to the lowest feature node;
        # nodes below it are not part of the features.
        for node_idx in range(
            self.num_nodes_in_spn - 1, self.lowest_sum_node_idx_for_features - 1, -1
        ):
            node_values[node_idx] = self._evaluate_node(node_idx, x, node_values)

        features = node_values[self.node_types["SumBernoulli"]]
        return torch.t(features)

    def evaluate(self, x):
        """
        Evaluate the SPN model on input data.

        :param x: Tensor representing input data (batch_size, input_size);
            NaN entries are treated as unobserved by the Bernoulli leaves.
        :return: Log-space value at the root (node 0), one entry per batch row.
        :raises ValueError: If ``x`` contains -1, the query-variable marker,
            which this method does not accept.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if torch.any(x == -1):
            raise ValueError("Input cannot contain -1s")
        # One row per node, one column per batch element; every row is written
        # by the loop below.
        node_values = torch.empty(
            (self.num_nodes_in_spn, x.size(0)), device=self.device
        )
        for node_idx in reversed(range(self.num_nodes_in_spn)):
            node_values[node_idx] = self._evaluate_node(node_idx, x, node_values)
        return node_values[0]

    def _get_node_type_and_children_indices(self, node_idx: int):
        """
        Return the children indices of a node.

        Despite the name, no node type is returned; this simply delegates to
        :meth:`_get_children_indices`.

        :param node_idx: Index of the node.
        :return: Tensor of children indices.
        """
        return self._get_children_indices(node_idx)

    def _evaluate_node(self, node_idx: int, x, function_values_at_each_index):
        """
        Evaluate a single node in the SPN.

        :param node_idx: Index of the current node.
        :param x: Input data tensor (only read by Bernoulli leaves).
        :param function_values_at_each_index: Tensor of node values indexed by
            node id along dim 0 (read by Sum/Product nodes).
        :return: Evaluated log-space value of the node.
        :raises NotImplementedError: If the node module is of an unknown type.
        """
        node = self.all_distributions[node_idx]
        if isinstance(node, (ProductNode, SumNode)):
            # Inner nodes gather their children's rows from the value table.
            return node(function_values_at_each_index)
        if isinstance(node, BernoulliLeaf):
            return node(x)
        raise NotImplementedError(f"Unknown node type {type(node).__name__}")


import torch
import torch.nn as nn


class ProductNode(nn.Module):
    """Product node of an SPN.

    Node values are kept in log space, so the product over the children is an
    element-wise sum of their rows.
    """

    def __init__(self, node_idx, scope, children_indices):
        super().__init__()
        self.node_idx = node_idx
        self.scope = scope
        # Row indices into the node-value table handed to forward().
        self.children_indices = children_indices

    def forward(self, function_values):
        """
        Combine the children's log-values.

        :param function_values: Tensor of node values indexed by node id along dim 0.
        :return: The children's rows summed over dim 0 (a product in log space).
        """
        rows = function_values[self.children_indices]
        return rows.sum(dim=0)


class SumNode(nn.Module):
    """Weighted sum node of an SPN, evaluated in log space with logsumexp."""

    def __init__(self, node_idx, weights, scope, children_indices, device, eps=1e-6):
        super().__init__()
        self.node_idx = node_idx
        # Plain tensor attribute, not an nn.Parameter or registered buffer:
        # it is not trained and not moved by Module.to().
        self.weights = torch.FloatTensor(weights).to(device)
        # eps keeps log() finite for zero weights; the trailing axis lets the
        # per-child log-weights broadcast over the batch columns (assuming
        # ``weights`` is a flat list, one entry per child).
        self.log_weights = (self.weights + eps).log().unsqueeze(1)
        self.scope = scope
        self.children_indices = children_indices
        self.eps = eps

    def forward(self, function_values):
        """
        Compute ``log(sum_k (w_k + eps) * exp(v_k))`` over this node's children.

        :param function_values: Tensor of node values indexed by node id along dim 0.
        :return: The weighted log-sum-exp of the children's rows, reduced over dim 0.
        """
        gathered = function_values[self.children_indices]
        return torch.logsumexp(gathered + self.log_weights, dim=0)


# The BernoulliLeaf class is a PyTorch module that evaluates a Bernoulli random
# variable in log space (to avoid underflow), treating NaN inputs as unobserved values.
class BernoulliLeaf(torch.nn.Module):
    """Bernoulli leaf of an SPN, evaluated in log space.

    Encoding of the leaf's input value: NaN means unobserved (the leaf returns
    the log of the total mass ``(p + eps) + (1 - p + eps)``); -1 means a query
    variable (the leaf returns the log-probability of the more likely outcome);
    any other value is treated as the probability that the variable is 1.
    """

    def __init__(self, node_idx, p, scope, device):
        super(BernoulliLeaf, self).__init__()
        self.node_idx = node_idx
        self.weights = torch.tensor([p], dtype=torch.float, device=device)
        self.eps = 1e-6
        # Column of the input this leaf reads.
        self.scope = scope
        self.device = device

        # Precompute operations
        self.precomputed_values = self.precompute()

    def precompute(self):
        """Precompute the log-space constants used by :meth:`forward`."""
        log_weights = torch.log(self.weights + self.eps).to(self.device)
        log_one_minus_weights = torch.log(1 - self.weights + self.eps).to(self.device)

        # Value returned for unobserved (NaN) inputs: log of the total mass.
        result_for_nan = torch.logsumexp(
            torch.stack([log_weights, log_one_minus_weights]), dim=0
        )

        # Value returned for query (-1) inputs: the more likely outcome.
        if self.weights > 0.5:
            result_for_query = log_weights
        else:
            result_for_query = log_one_minus_weights

        return {
            "log_weights": log_weights,
            "log_one_minus_weights": log_one_minus_weights,
            "result_for_nan": result_for_nan,
            "result_for_query": result_for_query,
        }

    def forward(self, x):
        """
        Evaluate the leaf on a batch.

        :param x: 1-D (single example), 2-D (batch, num_var), or 3-D input
            whose last axis holds (1 - x, x) per variable, e.g. softmax output.
        :return: 1-D tensor of log-probabilities, one entry per batch row.
        :raises ValueError: If ``x`` has an unsupported number of dimensions.
        """
        input_rank = x.ndimension()
        if input_rank == 1:
            # A single example rather than a batch: add a batch axis.
            x = x.unsqueeze(0)
        elif input_rank not in (2, 3):
            raise ValueError(
                f"BernoulliLeaf expects a 1-D, 2-D or 3-D input, got {input_rank}-D"
            )
        # Selecting one column drops a dimension and returns a *view* of the
        # caller's tensor, so everything below must stay out-of-place.
        x = x[:, self.scope]

        unobs_indices = torch.isnan(x)
        # Out-of-place replacement of NaNs. Writing into the view would alter the
        # caller's input, so later leaves on the same variable would see 0
        # instead of NaN.
        x = x.masked_fill(unobs_indices, 0)

        if x.ndimension() == 1:
            # 2-D input: x holds the probability that the variable is 1.
            log_prob_one_minus_x = torch.log(1 - x + self.eps)
            log_prob_x = torch.log(x + self.eps)
            query_values = x
        else:
            # 3-D input (e.g. softmax output): index 0 is 1 - x, index 1 is x.
            log_prob_one_minus_x = torch.log(x[:, 0] + self.eps)
            log_prob_x = torch.log(x[:, 1] + self.eps)
            # Both entries are NaN together, so the first one suffices.
            unobs_indices = unobs_indices[:, 0]
            query_values = x[:, 0]

        result = torch.logsumexp(
            torch.stack(
                [
                    log_prob_x + self.precomputed_values["log_weights"],
                    log_prob_one_minus_x
                    + self.precomputed_values["log_one_minus_weights"],
                ]
            ),
            dim=0,
        )

        # Unobserved entries take the precomputed total mass.
        result_in_log_space = torch.where(
            unobs_indices, self.precomputed_values["result_for_nan"], result
        )
        # Query entries (-1) take the log-probability of the more likely outcome.
        return torch.where(
            query_values != -1,
            result_in_log_space,
            self.precomputed_values["result_for_query"],
        )
