import os

import numpy as np
import torch
from bn_class_log import BinaryBNModel
from deeprob.spn.algorithms.inference import mpe
from deeprob.spn.structure.io import load_spn_json
from get_spn_class_log import SPNModel
from loguru import logger
from made import MADE
from mn_class_log import BinaryMNModel
from tqdm import tqdm


def get_pgm_loss(cfg, device, num_outputs):
    """
    Build the frozen torch PGM selected by ``cfg.pgm``.

    Side effect: sets ``cfg.pgm_model_path`` to the path of the file that is loaded.

    :param cfg: config object; reads ``pgm``, ``pgm_model_directory``, ``dataset`` and,
        when ``pgm == "spn"``, ``percent_nodes_for_features``.
    :param device: device the torch model is placed on; MADE weights are mapped
        onto it when loaded.
    :param num_outputs: number of variables; used by the ``spn`` and ``made`` branches.
    :return: ``(torch_pgm, library_pgm)``. ``library_pgm`` is the SPN loaded with
        ``load_spn_json`` when ``pgm == "spn"`` and ``None`` for every other model.
    :raises NotImplementedError: if ``cfg.pgm`` is not one of spn, bn, mn, made.
    """
    if cfg.pgm == "spn":
        cfg.pgm_model_path = os.path.join(
            cfg.pgm_model_directory, f"{cfg.dataset}/spn.json"
        )
        torch_pgm = SPNModel(
            cfg.pgm_model_path,
            num_var=num_outputs,
            device=device,
            percent_nodes_for_features=cfg.percent_nodes_for_features,
        )
        library_pgm = load_spn_json(cfg.pgm_model_path)
    elif cfg.pgm == "bn":
        cfg.pgm_model_path = os.path.join(
            cfg.pgm_model_directory, f"{cfg.dataset}.uai"
        )
        torch_pgm = BinaryBNModel(
            cfg.pgm_model_path,
            device=device,
        )
        library_pgm = None
    elif cfg.pgm == "mn":
        cfg.pgm_model_path = os.path.join(
            cfg.pgm_model_directory, f"{cfg.dataset}.uai"
        )
        torch_pgm = BinaryMNModel(
            cfg.pgm_model_path,
            device=device,
        )
        library_pgm = None
    elif cfg.pgm == "made":
        cfg.pgm_model_path = os.path.join(
            cfg.pgm_model_directory, f"{cfg.dataset}/made.pt"
        )

        torch_pgm = MADE(
            num_outputs,
            [
                512,
                1024,
            ],
            num_outputs,
            1,
        ).to(device)
        # NOTE(review): rewrites every "spn" substring in the path (directory
        # included) to "made"; presumably MADE checkpoints live in a sibling
        # directory of the SPN ones — confirm before removing.
        cfg.pgm_model_path = cfg.pgm_model_path.replace("spn", "made")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        # (and avoids an extra host/device round trip otherwise).
        torch_pgm.load_state_dict(
            torch.load(cfg.pgm_model_path, map_location=device)
        )
        library_pgm = None
    else:
        raise NotImplementedError(
            f"{cfg.pgm} not implemented, use spn, made, bn, or mn"
        )
    logger.info(f"Loaded PGM model from {cfg.pgm_model_path}")
    # don't allow gradient updates for the PGM
    for param in torch_pgm.parameters():
        param.requires_grad = False
    return torch_pgm, library_pgm


def get_num_features(cfg, train_data, torch_pgm):
    """
    Count the input features contributed by the raw data and by the PGM.

    :param cfg: config object; only ``cfg.input_type`` is read.
    :param train_data: array-like whose ``shape[1]`` is the number of data columns;
        only read when the input type includes raw data.
    :param torch_pgm: model exposing ``num_sum_nodes_for_features``; only read when
        the input type includes SPN features.
    :return: ``(num_data_features, num_spn_feature)``; each entry is 0 when that
        feature source is not part of ``cfg.input_type``.
    """
    uses_data = cfg.input_type in ("data", "dataSpn")
    uses_spn = cfg.input_type in ("spn", "dataSpn")

    num_data_features = train_data.shape[1] if uses_data else 0
    num_spn_feature = torch_pgm.num_sum_nodes_for_features if uses_spn else 0
    return num_data_features, num_spn_feature


def init_embedding_sizes(cfg, num_nodes_in_graph):
    """
    Fill in or enlarge embedding sizes on ``cfg`` in place (returns ``None``).

    Each size is set to ``10 * num_nodes_in_graph`` when it is ``None``. It is also
    reset to that value when it is smaller than ``num_nodes_in_graph``, but only if
    the matching option is active: ``embedding_type == "continuousSin"`` for
    ``sin_embedding_size``, and ``model == "vae"`` for ``latent_dim``.
    """
    default_size = num_nodes_in_graph * 10

    sin_size = cfg.sin_embedding_size
    # Explicit grouping mirrors Python precedence: `and` binds tighter than `or`,
    # so a None size is always defaulted regardless of embedding_type.
    if sin_size is None or (
        sin_size < num_nodes_in_graph and cfg.embedding_type == "continuousSin"
    ):
        cfg.sin_embedding_size = default_size

    latent = cfg.latent_dim
    if latent is None or (latent < num_nodes_in_graph and cfg.model == "vae"):
        cfg.latent_dim = default_size


def get_true_mmap_torch(torch_spn, buckets, example, device):
    """
    Exhaustively find the best assignment of the query variables for one example.

    Every one of the ``2**q`` binary fillings of the query variables is scored
    with ``torch_spn.evaluate``. The highest-scoring filling is returned. The cost
    is exponential in ``q``, so this is only practical for a small number of
    query variables.

    :param torch_spn: model with an ``evaluate(batch)`` method returning one score
        per row (presumably a log-likelihood).
    :param buckets: dict whose ``"query"`` entry is a boolean mask over the variables
        (numpy array, list or torch tensor) marking the query variables.
    :param example: torch tensor holding one data point; non-query entries are kept
        as-is in every candidate.
    :param device: device on which the candidates are built and evaluated.
    :return: ``(best_query_fill, best_ll)`` as numpy arrays. ``best_query_fill`` holds
        the query values in mask order, as float32.
    """
    # Accept numpy/list/torch masks alike; a bool tensor on `device` can index
    # `data` directly and its sum gives the number of query variables.
    query_mask = torch.as_tensor(buckets["query"], dtype=torch.bool, device=device)
    num_possible_query = int(query_mask.sum().item())
    num_combinations = 2**num_possible_query
    data = example.to(device)

    # Row i is the binary expansion of i, least-significant bit first
    # (column j == bit j of i), built in one vectorized op instead of a
    # Python double loop with per-element device writes.
    row_ids = torch.arange(num_combinations, device=device).unsqueeze(1)
    bit_positions = torch.arange(num_possible_query, device=device)
    binary_vectors = ((row_ids >> bit_positions) & 1).float()

    # One candidate row per query filling, evidence values copied from `data`.
    all_data_points = data.repeat(num_combinations, 1)
    # index_put requires matching dtypes, so cast to the data's dtype
    # (e.g. integer-valued examples).
    all_data_points[:, query_mask] = binary_vectors.to(all_data_points.dtype)

    # Scores are only read, never differentiated; no_grad also keeps
    # `.numpy()` below from failing on a grad-tracking result.
    with torch.no_grad():
        ll_score = torch_spn.evaluate(all_data_points)

    best_ll, max_index = torch.max(ll_score, dim=0)
    best_query_fill = binary_vectors[max_index]
    return best_query_fill.cpu().numpy(), best_ll.cpu().numpy()


def get_true_mmap_np(root_spn, buckets, example):
    """
    Exhaustively find the best assignment of the query variables for one example.

    Each of the ``2**q`` binary fillings of the query variables is scored with
    deeprob's ``log_likelihood``. The highest-scoring filling is returned. The cost
    is exponential in ``q``.

    :param root_spn: deeprob SPN root node passed to ``log_likelihood``.
    :param buckets: dict whose ``"query"`` entry is a boolean mask over the variables
        marking the query variables.
    :param example: 1-D numpy array with one data point; non-query entries are kept
        as-is in every candidate.
    :return: ``(best_query_fill, best_ll)``. ``best_query_fill`` is a float64 vector of
        query values in mask order. ``best_ll`` is ``ll_score[argmax]``.
    """
    from deeprob.spn.algorithms.inference import log_likelihood

    query_mask = buckets["query"]
    num_possible_query = int(np.sum(query_mask))
    num_combinations = 2**num_possible_query

    # Row i is the binary expansion of i, least-significant bit first
    # (column j == bit j of i), as float64 — vectorized instead of a double loop.
    binary_vectors = (
        (np.arange(num_combinations)[:, np.newaxis] >> np.arange(num_possible_query))
        & 1
    ).astype(float)

    # One candidate row per query filling, evidence copied from `example`.
    all_data_points = np.repeat(example[np.newaxis, :], num_combinations, axis=0)
    all_data_points[:, query_mask] = binary_vectors

    # Single SPN evaluation (the former second pass on an int-cast copy was
    # never used and only doubled the cost).
    ll_score = log_likelihood(
        root_spn,
        all_data_points,
        return_results=False,
    )

    max_index = np.argmax(ll_score)
    best_ll = ll_score[max_index]
    best_query_fill = binary_vectors[max_index]

    return best_query_fill, best_ll


def mmap_for_dataset(cfg, root_spn, buckets, dataset, device="cuda"):
    """
    Run ``get_true_mmap_np`` on every row of ``dataset`` and stack the results.

    :param cfg: config object; when ``cfg.debug`` is truthy, only the first 10 rows
        are processed.
    :param root_spn: deeprob SPN root node forwarded to ``get_true_mmap_np``.
    :param buckets: dict with a boolean ``"query"`` mask, forwarded unchanged.
    :param dataset: 2-D numpy array; each row is one example.
    :param device: not used by this function (the numpy path ignores it).
    :return: ``(labels, ll_scores)``, the stacked best query fillings and their
        scores, one entry per processed row.
    """
    per_row = []
    for row_idx in tqdm(range(dataset.shape[0])):
        # Debug runs stop after the first 10 rows.
        if cfg.debug and row_idx == 10:
            break
        per_row.append(get_true_mmap_np(root_spn, buckets, dataset[row_idx]))

    labels = np.stack([fill for fill, _ in per_row])
    ll_scores = np.stack([score for _, score in per_row])
    return labels, ll_scores


def get_spn_mpe_output(library_spn, outputs_for_spn, query_bucket):
    """
    Complete the query variables of ``outputs_for_spn`` with deeprob's ``mpe``.

    The query positions are set to NaN, the marker for the values to infer, on a
    private floating-point copy. The caller's array is left untouched.

    :param library_spn: deeprob SPN root node passed to ``mpe``.
    :param outputs_for_spn: array-like of variable assignments; the non-query
        entries act as evidence.
    :param query_bucket: any numpy index (boolean mask or index array) selecting
        the query entries.
    :return: whatever ``mpe`` returns for the NaN-marked copy.
    """
    # Copy: the original aliased the caller's array and overwrote it with NaNs.
    array_for_spn = np.array(outputs_for_spn, copy=True)
    # NaN cannot be stored in an integer array, so promote to float if needed.
    if not np.issubdtype(array_for_spn.dtype, np.floating):
        array_for_spn = array_for_spn.astype(float)
    array_for_spn[query_bucket] = np.nan
    mpe_output = mpe(library_spn, array_for_spn)
    return mpe_output
