import math
import time

import numpy as np
import torch
from deeprob.spn.algorithms.inference import log_likelihood, mpe
from deeprob.spn.algorithms.structure import marginalize
from loguru import logger
from tqdm import tqdm


def random_initialization(
    data, query_var_bool, unobs_var_bool, num_query_vars, num_unobs_vars
):
    """
    Return a copy of ``data`` with query and unobserved cells set to random 0/1 values.

    A single ``(n, num_query_vars + num_unobs_vars)`` matrix of random integers
    in {0, 1} is drawn. Row ``i``'s first ``num_query_vars`` entries fill the
    query cells of row ``i`` (in increasing column order). Its remaining entries
    fill the unobserved cells of that row.

    :param data: 2-D numpy array of shape (n, m); it is not modified.
    :param query_var_bool: boolean mask of query cells. It is either:
        per-example with shape (n, m), where every row has exactly
        ``num_query_vars`` True entries; or per-variable with shape (m,),
        applied to every row.
    :param unobs_var_bool: boolean mask of unobserved cells. It follows the same
        shape conventions as ``query_var_bool``, with ``num_unobs_vars`` True
        entries per row.
    :param num_query_vars: number of query cells per row.
    :param num_unobs_vars: number of unobserved cells per row.
    :return: a new array shaped like ``data`` with the masked cells randomized.
    :raises ValueError: if a mask cannot be broadcast to ``data.shape``, or its
        per-row True count does not match the corresponding ``num_*`` argument.
    """
    new_data_points = data.copy()
    # Broadcasting lets a per-variable (m,) mask stand for every row; an
    # (n, m) mask is used unchanged.
    query_mask = np.broadcast_to(
        np.asarray(query_var_bool, dtype=bool), new_data_points.shape
    )
    unobs_mask = np.broadcast_to(
        np.asarray(unobs_var_bool, dtype=bool), new_data_points.shape
    )

    random_values = np.random.randint(
        0, 2, size=(new_data_points.shape[0], num_query_vars + num_unobs_vars)
    )
    # Boolean-mask assignment visits the selected cells in row-major order and
    # accepts only a 0-d or 1-d value array, so flatten the blocks row by row.
    new_data_points[query_mask] = random_values[:, :num_query_vars].ravel()
    new_data_points[unobs_mask] = random_values[:, num_query_vars:].ravel()
    return new_data_points


def mpe_spn(prev_outputs, query_var_bool, evid_var_bool, library_spn_marginalized):
    """
    Run one batched MPE query against an SPN.

    The query array is shaped like ``prev_outputs["all_outputs_for_spn"]`` and
    built in three steps. Every cell is first set to -1. Cells selected by
    ``query_var_bool`` are then set to NaN; these are the values ``mpe`` is
    asked to fill in. Finally, cells selected by ``evid_var_bool`` are copied
    from ``prev_outputs["all_unprocessed_data"]``.

    :param prev_outputs: dict with at least ``"all_outputs_for_spn"`` (used only
        for its shape and dtype) and ``"all_unprocessed_data"`` (array-like
        source of the evidence values).
    :param query_var_bool: boolean mask of query cells; presumably the same
        shape as ``prev_outputs["all_outputs_for_spn"]`` — TODO confirm.
    :param evid_var_bool: boolean mask of evidence cells (same shape assumption).
    :param library_spn_marginalized: SPN passed directly to deeprob's ``mpe``.
    :return: whatever ``mpe`` returns for the prepared array.
    """
    template = prev_outputs["all_outputs_for_spn"]
    spn_input = np.zeros_like(template)
    spn_input.fill(-1)
    spn_input[query_var_bool] = np.nan

    unprocessed = np.array(prev_outputs["all_unprocessed_data"])
    spn_input[evid_var_bool] = unprocessed[evid_var_bool]

    logger.info(f"Array for spn {spn_input.shape}")
    return mpe(library_spn_marginalized, spn_input)


def mle_spn(
    prev_outputs,
    query_var_bool,
    evid_var_bool,
    library_spn_marginalized,
    max_time_per_sample=0,
):
    """
    Predict every query variable independently with a per-variable MPE query.

    For each example and each query variable ``v`` of that example, the SPN is
    marginalized to the scope ``[v] + evidence`` and deeprob's ``mpe`` fills in
    ``v`` given that example's evidence values. Query variables are predicted
    one at a time, not jointly, so no query conditions on another.

    :param prev_outputs: dict with ``"all_outputs_for_spn"`` (2-D array giving the
        shape and dtype of the result) and ``"all_unprocessed_data"`` (2-D
        array-like supplying the evidence values).
    :param query_var_bool: 2-D boolean mask (examples x variables) of query cells.
    :param evid_var_bool: 2-D boolean mask (examples x variables) of evidence cells.
    :param library_spn_marginalized: SPN handed to deeprob's ``marginalize``.
    :param max_time_per_sample: accepted for API compatibility; this function
        does not currently enforce any time limit.
    :return: array shaped like ``prev_outputs["all_outputs_for_spn"]``. It holds
        the evidence values, the MPE prediction at each query cell, and -1
        everywhere else.
    """
    template = prev_outputs["all_outputs_for_spn"]
    # Accept plain lists as well as arrays, as ``mpe_spn`` does.
    unprocessed = np.asarray(prev_outputs["all_unprocessed_data"])

    final_outputs = np.full_like(template, -1)
    final_outputs[evid_var_bool] = unprocessed[evid_var_bool]

    for ex_idx in tqdm(range(template.shape[0])):
        query_var_idx = np.flatnonzero(query_var_bool[ex_idx])
        evidence_var_idx = np.flatnonzero(evid_var_bool[ex_idx])
        evidence_value = unprocessed[ex_idx][evidence_var_idx]
        evidence_scope = evidence_var_idx.tolist()
        # All-(-1) row reused (copied) for every query of this example.
        blank_row = np.full_like(template[ex_idx], -1).reshape(1, -1)

        for var_idx in query_var_idx:
            array_for_spn = blank_row.copy()
            array_for_spn[:, var_idx] = np.nan
            array_for_spn[:, evidence_var_idx] = evidence_value
            # Keep the scope to this query variable plus the evidence.
            spn_mle_marginalized = marginalize(
                library_spn_marginalized, keep_scope=[var_idx] + evidence_scope
            )
            mpe_output = mpe(spn_mle_marginalized, array_for_spn)
            # Take the scalar element: assigning a shape-(1,) array to a single
            # cell is deprecated since NumPy 1.25.
            final_outputs[ex_idx, var_idx] = mpe_output[0, var_idx]
    return final_outputs


def mle_spn_torch(
    prev_outputs,
    query_var_bool,
    evid_var_bool,
    torch_spn,
    device="cuda" if torch.cuda.is_available() else "cpu",
    max_time=60,
):
    """
    Batched, independent 0/1 prediction of each query variable with a torch SPN.

    Every SPN input starts from the same base row: evidence values filled in,
    all other cells NaN. For the ``idx``-th query variable of every example,
    two batches are built, one setting that variable to 0 and one setting it
    to 1. Both are scored in a single ``torch_spn.evaluate`` call. The variable
    takes the value with the higher score, and ties go to 0.

    :param prev_outputs: dict with ``"all_unprocessed_data"`` (2-D array-like of
        examples x variables) supplying evidence values.
    :param query_var_bool: 2-D boolean mask of query cells; every example must
        have the same number of query variables.
    :param evid_var_bool: 2-D boolean mask of evidence cells.
    :param torch_spn: object whose ``evaluate(x)`` returns one score per row of
        ``x``; presumably log-likelihoods — TODO confirm.
    :param device: torch device used for all tensors.
    :param max_time: wall-clock budget in seconds, checked after each query
        variable.
    :return: numpy array shaped like the data. It holds evidence values and the
        0/1 predictions, with NaN elsewhere. If the time budget runs out, NaN
        cells are instead filled from ``all_unprocessed_data``.
    :raises ValueError: if examples have differing numbers of query variables.
    """
    start_time = time.time()
    query_var_bool = torch.as_tensor(query_var_bool, dtype=torch.bool, device=device)
    evid_var_bool = torch.as_tensor(evid_var_bool, dtype=torch.bool, device=device)
    all_unprocessed_data = torch.as_tensor(
        prev_outputs["all_unprocessed_data"], device=device
    )

    final_outputs = torch.full_like(all_unprocessed_data, torch.nan, device=device)
    final_outputs[evid_var_bool] = all_unprocessed_data[evid_var_bool]
    # Shared starting point for every SPN query; never modified in the loop.
    base_for_spn = final_outputs.clone()
    base_for_spn[query_var_bool] = torch.nan

    num_examples = query_var_bool.size(0)
    per_row_counts = query_var_bool.sum(dim=1)
    # An explicit check: view(n, -1) would crash on zero queries and could
    # silently misalign rows whose query counts differ.
    if num_examples and not bool(torch.all(per_row_counts == per_row_counts[0])):
        raise ValueError(
            "every example must have the same number of query variables"
        )
    num_query_vars = int(per_row_counts[0]) if num_examples else 0
    # (num_examples, num_query_vars): column index of each row's query variables.
    query_vars_each_example = torch.where(query_var_bool)[1].view(
        num_examples, num_query_vars
    )
    row_indices = torch.arange(num_examples, device=device)

    for idx in tqdm(range(num_query_vars)):
        query_var_this_iter = query_vars_each_example[:, idx]
        array_for_spn_0 = base_for_spn.clone()
        array_for_spn_1 = base_for_spn.clone()
        array_for_spn_0[row_indices, query_var_this_iter] = 0
        array_for_spn_1[row_indices, query_var_this_iter] = 1

        # One evaluate call for both candidate values.
        stacked_array_for_spn = torch.cat([array_for_spn_0, array_for_spn_1], dim=0)
        with torch.no_grad():
            ll_scores_stacked = torch_spn.evaluate(stacked_array_for_spn.float())

        # Flatten so both (2n,) and (2n, 1) score shapes split into per-row scores.
        ll_scores_0, ll_scores_1 = torch.split(
            ll_scores_stacked.reshape(-1), num_examples
        )

        # Ties go to 0; rows where a score is NaN are left untouched.
        choose_0 = ll_scores_0 >= ll_scores_1
        choose_1 = ll_scores_1 > ll_scores_0
        final_outputs[row_indices[choose_0], query_var_this_iter[choose_0]] = 0
        final_outputs[row_indices[choose_1], query_var_this_iter[choose_1]] = 1

        if time.time() - start_time > max_time:
            print("Time limit exceeded, breaking out of loop")
            # NOTE(review): still-NaN cells include query variables that were
            # never predicted; they are filled from all_unprocessed_data. If that
            # array holds ground-truth values, this leaks them into the
            # prediction — confirm this is intended.
            filled = torch.where(
                torch.isnan(final_outputs), all_unprocessed_data, final_outputs
            )
            return filled.cpu().numpy()
    return final_outputs.cpu().numpy()


def sequential_spn(
    prev_outputs, query_var_bool, evid_var_bool, library_spn_marginalized
):
    """
    Greedily assign binary query variables one at a time using SPN log-likelihoods.

    For each example, every step scores each still-unassigned query variable
    ``q``. Values 0 and 1 are both evaluated with ``log_likelihood``, under the
    SPN marginalized to ``assigned queries + evidence + [q]``. The variable
    whose better value has the highest log-likelihood is fixed to that value.
    It then joins the conditioning set for the next step. Ties go to the lower
    value and the lower variable index.

    :param prev_outputs: dict with ``"all_outputs_for_spn"`` (2-D array giving the
        shape and dtype of the result) and ``"all_unprocessed_data"`` (2-D
        array-like supplying the evidence values).
    :param query_var_bool: 2-D boolean mask (examples x variables) of query cells.
    :param evid_var_bool: 2-D boolean mask (examples x variables) of evidence cells.
    :param library_spn_marginalized: SPN handed to deeprob's ``marginalize``.
    :return: array shaped like ``prev_outputs["all_outputs_for_spn"]``. It holds
        the evidence values, the greedily chosen values at query cells, and -1
        everywhere else.
    """
    template = prev_outputs["all_outputs_for_spn"]
    unprocessed = np.asarray(prev_outputs["all_unprocessed_data"])

    final_outputs = np.full_like(template, -1)
    final_outputs[evid_var_bool] = unprocessed[evid_var_bool]

    for ex_idx in tqdm(range(template.shape[0])):
        # (1, m) view onto this example's row: writes land in final_outputs.
        current_row = final_outputs[ex_idx : ex_idx + 1]
        evidence_scope = np.flatnonzero(evid_var_bool[ex_idx]).tolist()
        remaining = np.flatnonzero(query_var_bool[ex_idx]).tolist()
        assigned = []

        while remaining:
            scope_denom = sorted(assigned) + evidence_scope
            best_scores = np.empty(len(remaining))
            best_values = np.empty(len(remaining))
            for pos, q_idx in enumerate(remaining):
                spn_mle_marginalized = marginalize(
                    library_spn_marginalized, keep_scope=scope_denom + [q_idx]
                )
                candidate = current_row.copy()
                ll_scores_this_query = []
                for var_val in (0, 1):
                    candidate[0, q_idx] = var_val
                    ll_scores_this_query.append(
                        log_likelihood(spn_mle_marginalized, candidate)
                    )
                ll_scores_this_query = np.array(ll_scores_this_query)
                best_scores[pos] = np.max(ll_scores_this_query)
                best_values[pos] = np.argmax(ll_scores_this_query)

            # Choose only among the remaining queries. Comparing against a finite
            # sentinel for every feature could pick a non-query variable when all
            # scores are very low or -inf (e.g. zero-probability evidence).
            winner = int(np.argmax(best_scores))
            chosen_var = remaining.pop(winner)
            current_row[0, chosen_var] = best_values[winner]
            assigned.append(chosen_var)
    return final_outputs


import torch
from tqdm import tqdm


def sequential_spn_torch(
    prev_outputs,
    query_var_bool,
    evid_var_bool,
    torch_spn,
    device=(
        "cuda" if torch.cuda.is_available() else "cpu"
    ),  # Automatically use GPU if available
    max_time=60,
):
    """
    Batched greedy assignment of binary query variables with a torch SPN.

    Each outer step assigns exactly one query variable per example. Every
    still-unassigned query variable of each example is tried at 0 and 1, with
    both batches scored in one ``torch_spn.evaluate`` call. The other
    unassigned cells stay NaN. Each variable keeps its better value (ties go
    to 0), and per example the variable with the highest such score is fixed
    (ties go to the lowest column index). Its value is visible in later steps.

    :param prev_outputs: dict with ``"all_unprocessed_data"`` (2-D array-like of
        examples x variables) supplying evidence values.
    :param query_var_bool: 2-D boolean mask of query cells; every example must
        have the same number of query variables. The caller's mask is not
        modified.
    :param evid_var_bool: 2-D boolean mask of evidence cells.
    :param torch_spn: object whose ``evaluate(x)`` returns one score per row of
        ``x``; presumably log-likelihoods — TODO confirm.
    :param device: torch device used for all tensors.
    :param max_time: wall-clock budget in seconds, checked after each outer step.
    :return: numpy array shaped like the data. It holds evidence values and the
        assigned 0/1 values, with NaN elsewhere. If the time budget runs out,
        NaN cells are instead filled from ``all_unprocessed_data``.
    :raises ValueError: if examples have differing numbers of query variables.
    """

    def _query_columns(mask):
        """Return the (num_rows, k) column indices of True cells; all rows must share k."""
        num_rows = mask.size(0)
        counts = mask.sum(dim=1)
        if num_rows and not bool(torch.all(counts == counts[0])):
            raise ValueError(
                "every example must have the same number of query variables"
            )
        per_row = int(counts[0]) if num_rows else 0
        return torch.where(mask)[1].view(num_rows, per_row)

    start_time = time.time()
    # Clone: torch.as_tensor shares memory with a CPU NumPy bool array, and this
    # mask is consumed (set False) below. It must not clobber the caller's mask.
    remaining_query = torch.as_tensor(
        query_var_bool, dtype=torch.bool, device=device
    ).clone()
    evid_var_bool = torch.as_tensor(evid_var_bool, dtype=torch.bool, device=device)
    all_unprocessed_data = torch.as_tensor(
        prev_outputs["all_unprocessed_data"], device=device
    )

    final_outputs = torch.full_like(all_unprocessed_data, torch.nan, device=device)
    final_outputs[evid_var_bool] = all_unprocessed_data[evid_var_bool]

    num_examples = remaining_query.size(0)
    row_indices = torch.arange(num_examples, device=device)
    total_query_vars = _query_columns(remaining_query).size(1)
    neg_inf = float("-inf")

    for _ in tqdm(range(total_query_vars)):
        remaining_cols = _query_columns(remaining_query)
        if remaining_cols.size(1) == 0:
            break
        step_scores = []
        step_take_one = []
        for col in range(remaining_cols.size(1)):
            query_var_this_iter = remaining_cols[:, col]
            array_for_spn_0 = final_outputs.clone()
            array_for_spn_1 = final_outputs.clone()
            array_for_spn_0[row_indices, query_var_this_iter] = 0
            array_for_spn_1[row_indices, query_var_this_iter] = 1
            stacked_array_for_spn = torch.cat(
                [array_for_spn_0, array_for_spn_1], dim=0
            )
            with torch.no_grad():
                ll_scores_stacked = torch_spn.evaluate(stacked_array_for_spn.float())
            ll_scores_0, ll_scores_1 = torch.split(
                ll_scores_stacked.reshape(-1), num_examples
            )
            # Treat NaN scores as the worst possible so they never win.
            ll_scores_0 = ll_scores_0.masked_fill(torch.isnan(ll_scores_0), neg_inf)
            ll_scores_1 = ll_scores_1.masked_fill(torch.isnan(ll_scores_1), neg_inf)
            take_one = ll_scores_1 > ll_scores_0  # ties go to 0
            step_scores.append(torch.where(take_one, ll_scores_1, ll_scores_0))
            step_take_one.append(take_one)

        # Argmax only over the remaining query columns, so every row always fixes
        # exactly one still-unassigned query variable. Even all--inf scores
        # therefore keep the per-row query counts in lockstep.
        best_col = torch.argmax(torch.stack(step_scores, dim=1), dim=1)
        chosen_vars = remaining_cols[row_indices, best_col]
        chosen_values = torch.stack(step_take_one, dim=1)[row_indices, best_col]
        final_outputs[row_indices, chosen_vars] = chosen_values.to(final_outputs.dtype)
        remaining_query[row_indices, chosen_vars] = False

        if time.time() - start_time > max_time:
            print("Time limit exceeded, breaking out of loop")
            # NOTE(review): still-NaN cells include query variables that were
            # never assigned; they are filled from all_unprocessed_data. If that
            # array holds ground-truth values, this leaks them into the
            # prediction — confirm this is intended.
            filled = torch.where(
                torch.isnan(final_outputs), all_unprocessed_data, final_outputs
            )
            return filled.cpu().numpy()

    return final_outputs.cpu().numpy()  # Move the result back to CPU if needed
