import copy
from math import log

import numpy as np
import torch
import wandb
from data.dataloaders import (
    create_buckets_all_queries_per,
    create_buckets_one_query_per,
)
from deeprob.spn.algorithms.inference import log_likelihood, mpe
from deeprob.spn.algorithms.structure import marginalize
from get_model import init_model_and_optimizer
from loguru import logger
from torch import nn
from tqdm import tqdm


def get_learnable_embeddings(cfg, model, batch_data):
    """
    Turn every entry of the batch into a categorical state index and embed it.

    The index is accumulated from the three bucket masks: 'unobs' adds 0, 'evid' adds
    1 + the value stored in "data", and 'query' adds 3.

    :param cfg: Configuration object; only `cfg.model` is read here.
    :param model: Object exposing `embedding_layer` and `ip_embedding_size`.
    :param batch_data: Dictionary holding the "data", "query", "unobs" and "evid" tensors.
    :return: Output of `model.embedding_layer`; reshaped to (-1, model.ip_embedding_size)
        unless `cfg.model` is "transformer".
    """
    values = batch_data["data"].clone()
    query_mask = batch_data["query"]
    unobs_mask = batch_data["unobs"]
    evid_mask = batch_data["evid"]

    # (mask, contribution) pairs, accumulated in place onto a zero index tensor
    contributions = (
        (unobs_mask, 0),  # Index 0 for 'unobserved'
        (evid_mask, 1 + values.long()),  # 1 for 'evid=0', 2 for 'evid=1'
        (query_mask, 3),  # Index 3 for 'query'
    )
    state_index = torch.zeros_like(values, dtype=torch.long)
    for mask, contribution in contributions:
        state_index += mask.long() * contribution

    embedded = model.embedding_layer(state_index)
    if cfg.model == "transformer":
        return embedded
    # Flatten the embeddings
    return embedded.view(-1, model.ip_embedding_size)


def get_sin_embeddings(cfg, model, batch_data):
    """
    Embed a copy of the batch's "data" tensor with the model's embedding layer.

    :param cfg: Configuration object (not read by this function).
    :param model: Object exposing a callable `embedding_layer`.
    :param batch_data: Dictionary holding the "data" tensor.
    :return: Whatever `model.embedding_layer` returns for the cloned data.
    """
    data_copy = batch_data["data"].clone()
    embedded = model.embedding_layer(data_copy)
    return embedded


def get_buckets_for_iter(cfg, num_var_in_buckets, batch_data):
    """
    Build one evid/query/unobs bucket assignment and replicate it for every example in the batch.

    :param cfg: Configuration object; `cfg.use_single_model` selects the bucket factory and
        `cfg.task` is forwarded to `create_buckets_all_queries_per`.
    :param num_var_in_buckets: Forwarded to `create_buckets_one_query_per`.
    :param batch_data: Dictionary whose "initial" entry is a 2-D tensor (rows are examples,
        columns are variables, as unpacked from its shape below).
    :return: Dictionary mapping each bucket name to the single-example mask stacked
        once per example of the batch.
    """
    initial = batch_data["initial"]
    sample_size, num_vars = initial.shape
    device = initial.device
    if cfg.use_single_model:
        # NOTE(review): this branch forwards sample_size while the other forwards
        # num_vars - confirm against create_buckets_all_queries_per's signature.
        single_example_bucket = create_buckets_all_queries_per(
            sample_size, cfg.task, device
        )
    else:
        single_example_bucket = create_buckets_one_query_per(
            num_vars,
            num_var_in_buckets,
            device,
        )

    # A missing bucket becomes an all-False mask. It describes a single example, so it
    # needs one entry per variable and must live on the same device as the data.
    for bucket_name in ("evid", "query", "unobs"):
        if bucket_name not in single_example_bucket:
            single_example_bucket[bucket_name] = torch.zeros(
                num_vars, dtype=torch.bool, device=device
            )

    # Replicate every single-example mask once per example in the batch.
    return {
        key: torch.stack([mask] * sample_size)
        for key, mask in single_example_bucket.items()
    }


def pre_process_data(cfg, model, spn, num_var_in_buckets, batch_data, train=True):
    """
    Prepare a batch for the model: attach SPN features, optionally replace the buckets,
    embed the data and pack everything into a tuple.

    `batch_data` is modified in place ("dataSpn", possibly "evid"/"query"/"unobs", and
    possibly "data" are overwritten).

    :param cfg: Configuration object; `cfg.same_bucket_iter` is read here, the rest is
        forwarded to the helper functions.
    :param model: Model forwarded to `get_embeddings`.
    :param spn: Forwarded to `get_ip_from_spn`.
    :param num_var_in_buckets: Forwarded to `get_buckets_for_iter`.
    :param batch_data: Dictionary with at least "data", "initial", "evid", "query",
        "unobs" and "attention_mask".
    :param train: When True and `cfg.same_bucket_iter` is set, all examples of the batch
        share one freshly created bucket assignment; otherwise the buckets are kept.
    :return: Tuple (data, dataSpn, initial, evid, query, unobs, attention_mask).
    """
    # NOTE(review): the SPN features are computed from the buckets present *before* the
    # optional replacement below - confirm this ordering is intended.
    get_ip_from_spn(cfg, spn, batch_data)
    if train and cfg.same_bucket_iter:
        # Use the same bucket assignment for every data point of this batch
        shared_bucket = get_buckets_for_iter(cfg, num_var_in_buckets, batch_data)
        for bucket_name in ("evid", "query", "unobs"):
            batch_data[bucket_name] = shared_bucket[bucket_name]
    get_embeddings(cfg, model, batch_data)
    return (
        batch_data["data"],
        batch_data["dataSpn"],
        batch_data["initial"],
        batch_data["evid"],
        batch_data["query"],
        batch_data["unobs"],
        batch_data["attention_mask"],
    )


def get_ip_from_spn(cfg, spn, batch_data):
    """
    Store the SPN-derived input features of the batch under batch_data["dataSpn"].

    :param cfg: Configuration object; features are computed only when the lower-cased
        `cfg.input_type` contains "spn".
    :param spn: Object exposing `get_input_features(samples, query, unobs)`.
    :param batch_data: Dictionary with "initial", "query" and "unobs"; updated in place.
    :return: None. "dataSpn" is set to the features, or to None when they are not requested.
    """
    wants_spn_features = "spn" in cfg.input_type.lower()
    if not wants_spn_features:
        batch_data["dataSpn"] = None
        return

    # Hand the SPN a copy of the initial samples
    initial_copy = batch_data["initial"].clone()
    features = spn.get_input_features(
        initial_copy, batch_data["query"], batch_data["unobs"]
    )
    batch_data["dataSpn"] = features


def get_embeddings(cfg, model, batch_data):
    """
    Replace batch_data["data"] with its embedding, according to `cfg.embedding_type`.

    "continuousEmbed" uses the learnable embedding layer, "continuousSin" uses
    `get_sin_embeddings`; any other value leaves the batch untouched.

    :param cfg: Configuration object; `cfg.embedding_type` selects the embedding.
    :param model: Model forwarded to the selected embedding function.
    :param batch_data: Batch dictionary, updated in place.
    :return: None.
    """
    kind = cfg.embedding_type
    if kind == "continuousEmbed":
        # We use a learnable embedding layer to embed the data
        embed = get_learnable_embeddings
    elif kind == "continuousSin":
        embed = get_sin_embeddings
    else:
        return
    batch_data["data"] = embed(cfg, model, batch_data)


def train(
    cfg, model, spn, device, fabric, train_loader, optimizer, epoch, schedular, **kwargs
):
    """
    Train `model` for one pass over `train_loader` and log the average training loss.

    :param cfg: Configuration object. Read here: `add_gradient_clipping`, `grad_clip_norm`,
        `lr_scheduler`, `log_interval` and `dry_run`; it is also forwarded to `pre_process_data`.
    :param model: Model exposing `train()`, `train_iter(spn, *data_pack)` and `parameters()`.
    :param spn: Forwarded to `pre_process_data` and `model.train_iter`.
    :param device: Not used by this function; kept for interface compatibility.
    :param fabric: Object whose `backward(loss)` runs the backward pass
        (presumably a Lightning Fabric instance - confirm against the caller).
    :param train_loader: Iterable of batches that supports `len()`; its `dataset` must
        expose `num_var_in_buckets` and support `len()`.
    :param optimizer: Optimizer stepped once per batch.
    :param epoch: Current epoch number, used only for logging.
    :param schedular: Learning-rate scheduler or None; it is stepped after every batch
        only when `cfg.lr_scheduler == "OneCycleLR"`.
    :param kwargs: Ignored.
    :return: The training loss averaged over the batches processed in this call.
    :raises ValueError: If `train_loader` yields no batches.
    """
    model.train()
    train_loss = 0
    num_batches = 0
    num_var_in_buckets = train_loader.dataset.num_var_in_buckets
    for batch_idx, batch_data in enumerate(train_loader):
        data_pack = pre_process_data(cfg, model, spn, num_var_in_buckets, batch_data)
        optimizer.zero_grad()
        loss = model.train_iter(spn, *data_pack)
        fabric.backward(loss)
        # Clip only after the backward pass: before it the gradients were just zeroed,
        # so clipping would have no effect on the update.
        if cfg.add_gradient_clipping:
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
        optimizer.step()
        current_loss = loss.item()
        train_loss += current_loss
        num_batches += 1
        if schedular is not None and cfg.lr_scheduler == "OneCycleLR":
            # cyclic learning rate scheduler needs to be called after every batch
            schedular.step()

        if batch_idx % cfg.log_interval == 0:
            logger.info(
                f"Train Epoch: {epoch} [{batch_idx * len(data_pack[0])}/{len(train_loader.dataset)} "
                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\t Training Loss: {current_loss:.6f}"
            )

            # A dry run stops at the first logged batch
            if cfg.dry_run:
                break
    if num_batches == 0:
        raise ValueError(
            "train_loader yielded no batches; cannot compute an average training loss"
        )
    train_loss /= num_batches
    wandb.log({"Train": train_loss})
    logger.info("\nTrain set: Average loss: {:.4f}".format(train_loss))
    return train_loss


def train_dual_network(
    cfg,
    model,
    teacher_model,
    pgm,
    library_pgm,
    device,
    fabric,
    train_loader,
    optimizer,
    teacher_optimizer,
    database_for_student,
    bucket_database_for_student,
    best_loss_database,
    epoch,
    schedular,
):
    """
    Run one epoch of teacher/student training.

    For every batch the teacher is reset (copied from the student or re-initialised),
    trained on the batch until `cfg.tot_train_dn` iterations or convergence, and then
    evaluated. Its thresholded outputs replace the stored solutions wherever its
    per-example loss beats `best_loss_database`; `get_best_sol_spn` is then called with
    `library_pgm` and the same databases. Finally the student is trained for
    `cfg.student_train_iter_dn` iterations on the stored solutions of the batch.

    :param cfg: Configuration object (reads `copy_student_to_teacher_dn`, `tot_train_dn`,
        `student_train_iter_dn`, `add_gradient_clipping`, `grad_clip_norm`, `threshold`,
        `lr_scheduler`, `log_interval`, `dry_run`).
    :param model: Student model exposing `supervised_train_iter`.
    :param teacher_model: Teacher model exposing `train_iter`, `supervised_validate_iter`,
        `initialize_weights` and `load_state_dict`.
    :param pgm: Forwarded to `pre_process_data` and to the model iteration methods.
    :param library_pgm: Forwarded to `get_best_sol_spn`.
    :param device: Device on which the teacher outputs tensor is created.
    :param fabric: Object whose `backward(loss)` runs the backward pass.
    :param train_loader: Iterable of batch dictionaries containing an "index" entry.
    :param optimizer: Optimizer of the student model.
    :param teacher_optimizer: Optimizer of the teacher model.
    :param database_for_student: Best solutions found so far, indexed by example; updated in place.
    :param bucket_database_for_student: Buckets belonging to the stored solutions; updated in place.
    :param best_loss_database: Best loss per example; updated in place.
    :param epoch: Current epoch number, used only for logging.
    :param schedular: Learning-rate scheduler or None; stepped per batch only for "OneCycleLR".
    :return: The student training loss averaged over the processed batches.
    :raises ValueError: If `train_loader` yields no batches.
    """
    model.train()
    student_train_losses = 0
    teacher_train_losses = 0
    teacher_val_losses = 0
    num_batches = 0
    num_var_in_buckets = train_loader.dataset.num_var_in_buckets

    for batch_idx, batch_data in enumerate(train_loader):
        if cfg.copy_student_to_teacher_dn:
            # copy the parameters of the student model to the teacher model
            teacher_model.load_state_dict(model.state_dict())
        else:
            teacher_model.initialize_weights()
        teacher_model.train()
        ex_idx = batch_data["index"]
        data_pack = pre_process_data(
            cfg, teacher_model, pgm, num_var_in_buckets, batch_data
        )
        # Early-stopping state for the teacher on this batch
        loss_history = []
        convergence_iter = 0
        convergence_threshold = 1e-3

        # Train the teacher model on this batch
        this_teacher_loss = 0
        teacher_iters_run = 0
        for _ in range(cfg.tot_train_dn):
            teacher_optimizer.zero_grad()
            teacher_loss = teacher_model.train_iter(
                pgm,
                *data_pack,
            )
            fabric.backward(teacher_loss)
            # Clip after backward so the freshly computed gradients are the ones clipped
            if cfg.add_gradient_clipping:
                nn.utils.clip_grad_norm_(teacher_model.parameters(), cfg.grad_clip_norm)
            teacher_optimizer.step()
            teacher_loss_value = teacher_loss.item()
            this_teacher_loss += teacher_loss_value
            teacher_iters_run += 1
            # Update loss history and check for convergence
            convergence_iter, converged = has_converged(
                loss_history,
                convergence_iter,
                teacher_loss_value,
                convergence_threshold,
            )
            if converged:
                print("Convergence detected, stopping training")
                break
        # Average over the iterations that actually ran (early stopping may cut the loop short)
        this_teacher_loss /= teacher_iters_run
        teacher_train_losses += this_teacher_loss

        # get outputs from the teacher model that will be used as pseudo labels for the student model
        with torch.no_grad():
            teacher_model.eval()
            (
                teacher_val_loss,
                output_for_pgm,
                initial_data,
                buckets,
            ) = teacher_model.supervised_validate_iter(
                pgm,
                *data_pack,
                return_mean=False,
            )
            teacher_val_loss_value = teacher_val_loss.mean().item()
            teacher_val_losses += teacher_val_loss_value
            # Binarise the teacher outputs at cfg.threshold
            output_for_pgm = torch.tensor(output_for_pgm, device=device)
            output_for_pgm = torch.where(output_for_pgm > cfg.threshold, 1.0, 0.0)
            # Add the outputs to the database if loss is better for this batch
            assert (
                teacher_val_loss.shape == best_loss_database[ex_idx].shape
            ), "Shapes do not match for loss comparison"
            mask_for_loss = teacher_val_loss < best_loss_database[ex_idx]
            update_best_solutions_from_mask(
                database_for_student,
                bucket_database_for_student,
                best_loss_database,
                output_for_pgm,
                buckets,
                teacher_val_loss,
                mask_for_loss,
                ex_idx,
            )

            # get best solutions from spn
            get_best_sol_spn(
                library_pgm,
                database_for_student,
                bucket_database_for_student,
                best_loss_database,
                ex_idx,
                initial_data,
                output_for_pgm,
                buckets,
            )

        # Train the student on the best stored solutions for this batch
        model.train()
        pseudo_labels = database_for_student[ex_idx]
        this_batch_loss = 0
        for _ in range(cfg.student_train_iter_dn):
            optimizer.zero_grad()
            stud_loss = model.supervised_train_iter(pgm, pseudo_labels, *data_pack)
            fabric.backward(stud_loss)
            # Clip after backward so the freshly computed gradients are the ones clipped
            if cfg.add_gradient_clipping:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
            optimizer.step()
            this_batch_loss += stud_loss.item()
        this_batch_loss /= cfg.student_train_iter_dn
        student_train_losses += this_batch_loss
        num_batches += 1
        if schedular is not None and cfg.lr_scheduler == "OneCycleLR":
            # cyclic learning rate scheduler needs to be called after every batch
            schedular.step()

        if batch_idx % cfg.log_interval == 0:
            logger.info(
                f"Train Epoch: {epoch} [{batch_idx * len(data_pack[0])}/{len(train_loader.dataset)} "
                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\t Student Training Loss: {(this_batch_loss):.6f}, \t Teacher Training Loss: {this_teacher_loss:.6f}, \t Teacher Validation Loss: {teacher_val_loss_value:.6f}"
            )

            # A dry run stops at the first logged batch
            if cfg.dry_run:
                break
    if num_batches == 0:
        raise ValueError(
            "train_loader yielded no batches; cannot compute average losses"
        )
    student_train_losses /= num_batches
    teacher_train_losses /= num_batches
    teacher_val_losses /= num_batches
    wandb.log(
        {
            "Student Train": student_train_losses,
            "Teacher Train": teacher_train_losses,
            "Teacher Val": teacher_val_losses,
        }
    )
    logger.info(
        "\nTrain set: Average Student loss: {:.4f}, Teacher Training Loss: {:.4f}, Teacher Validation Loss: {:.4f}".format(
            student_train_losses, teacher_train_losses, teacher_val_losses
        )
    )
    return student_train_losses


def get_best_sol_spn(
    library_pgm,
    database_for_student,
    bucket_database_for_student,
    best_loss_database,
    ex_idx,
    all_unprocessed_data,
    all_outputs_for_spn,
    all_buckets,
):
    """
    Replace stored solutions with the library PGM's MPE solutions wherever those have a lower loss.

    Does nothing when `library_pgm` is None.

    :param library_pgm: Model handed to `get_solutions_from_library_pgm`, or None.
    :param database_for_student: Stored solutions; updated in place.
    :param bucket_database_for_student: Stored buckets per key; updated in place.
    :param best_loss_database: Stored best losses; updated in place.
    :param ex_idx: Indices of the current examples inside the databases.
    :param all_unprocessed_data: Forwarded to `get_solutions_from_library_pgm`.
    :param all_outputs_for_spn: Forwarded to `get_solutions_from_library_pgm`.
    :param all_buckets: Buckets of the current examples.
    :return: None.
    """
    if library_pgm is None:
        return

    target_device = database_for_student.device
    root_ll_spn, mpe_output = get_solutions_from_library_pgm(
        library_pgm, all_unprocessed_data, all_outputs_for_spn, all_buckets
    )
    # Loss is the negative log likelihood, so a smaller value means a better solution
    library_loss = -torch.tensor(root_ll_spn).to(
        device=target_device, dtype=torch.double
    )
    # True where the library solution beats the best loss stored for that example
    improved = library_loss < best_loss_database[ex_idx]
    mpe_tensor = torch.tensor(mpe_output, device=target_device)
    update_best_solutions_from_mask(
        database_for_student,
        bucket_database_for_student,
        best_loss_database,
        mpe_tensor,
        all_buckets,
        library_loss,
        improved,
        ex_idx,
    )


def update_best_solutions_from_mask(
    database_for_student,
    bucket_database_for_student,
    best_loss_database,
    all_outputs_for_spn,
    all_buckets,
    loss,
    mask_for_loss,
    example_indices,
):
    """
    Overwrite the stored loss, solution and buckets of every masked example.

    For each store, the rows at `example_indices` are read, the rows selected by
    `mask_for_loss` are replaced with the new values, and the rows are written back.

    :param database_for_student: Stored solutions; updated in place.
    :param bucket_database_for_student: Mapping from bucket key to stored buckets; updated in place.
    :param best_loss_database: Stored best losses; updated in place.
    :param all_outputs_for_spn: Candidate solutions for the current examples.
    :param all_buckets: Mapping from bucket key to the buckets of the current examples.
    :param loss: Candidate losses for the current examples.
    :param mask_for_loss: Boolean mask over the current examples marking which to overwrite.
    :param example_indices: Positions of the current examples inside the stores.
    :return: None.
    """

    def _overwrite_selected(store, replacement, selector):
        # read rows -> patch the selected ones -> write the rows back
        rows = store[example_indices]
        rows[selector] = replacement[selector]
        store[example_indices] = rows

    _overwrite_selected(best_loss_database, loss, mask_for_loss)
    _overwrite_selected(database_for_student, all_outputs_for_spn, mask_for_loss)

    # Buckets are indexed as [mask, :]
    bucket_selector = (mask_for_loss, slice(None))
    for key in all_buckets:
        _overwrite_selected(
            bucket_database_for_student[key], all_buckets[key], bucket_selector
        )


def get_solutions_from_library_pgm(
    library_pgm, all_unprocessed_data, all_outputs_for_spn, all_buckets
):
    """
    Run MPE with the library PGM on the current examples and score the resulting assignments.

    The array handed to `mpe` has the shape of `all_outputs_for_spn` and holds NaN at
    query positions, the original data at evidence positions and -1.0 everywhere else.

    :param library_pgm: Model passed to `mpe` and `log_likelihood`.
    :param all_unprocessed_data: Tensor with the original data values.
    :param all_outputs_for_spn: Only its shape is used.
    :param all_buckets: Mapping with boolean "query" and "evid" tensors.
    :return: Tuple (log likelihood of the MPE output, MPE output).
    """
    # Everything is moved to the CPU and converted to NumPy before touching the PGM
    query_mask = all_buckets["query"].cpu().numpy()
    evid_mask = all_buckets["evid"].cpu().numpy()
    observed_values = all_unprocessed_data.detach().cpu().numpy()

    pgm_input = np.full(all_outputs_for_spn.shape, -1.0, dtype=np.float64)
    # Query entries become NaN (presumably the "to be inferred" marker for mpe - confirm);
    # evidence entries are written afterwards.
    pgm_input[query_mask] = np.nan
    pgm_input[evid_mask] = observed_values[evid_mask]

    completed = mpe(library_pgm, pgm_input)
    root_ll = log_likelihood(library_pgm, completed, return_results=False)
    return root_ll, completed


@torch.no_grad()
def validate(cfg, model, spn, device, test_loader, best_loss, counter):
    """
    Evaluate `model` on `test_loader`, log the average loss and update the early-stopping state.

    :param cfg: Configuration object; `cfg.log_interval` is read here and the object is
        forwarded to `pre_process_data`.
    :param model: Model exposing `eval()` and `validate_iter(...)`.
    :param spn: Forwarded to `pre_process_data` and `model.validate_iter`.
    :param device: Not used by this function; kept for interface compatibility.
    :param test_loader: Iterable of batches that supports `len()`; its `dataset` must support `len()`.
    :param best_loss: Best validation loss seen so far.
    :param counter: Number of consecutive validations without improvement.
    :return: Tuple (best_loss, test_loss, counter, all_unprocessed_data, all_nn_outputs,
        all_outputs_for_spn, all_buckets). `best_loss` is replaced by `test_loss` and
        `counter` reset to 0 when the average loss improves by more than 0.001;
        otherwise `counter` is incremented. The four collections are created empty here
        and handed to `model.validate_iter` for every batch.
    :raises ValueError: If `test_loader` yields no batches.
    """
    model.eval()
    test_loss = 0
    num_batches = 0
    # Minimum change in the validation loss to be considered as improvement
    min_delta = 0.001
    all_unprocessed_data = []
    all_nn_outputs = []
    all_outputs_for_spn = []
    all_buckets = {"evid": [], "query": [], "unobs": []}
    # Buckets are never regenerated during validation (train=False below)
    num_var_in_buckets = None
    for batch_idx, batch_data in enumerate(test_loader):
        data_pack = pre_process_data(
            cfg, model, spn, num_var_in_buckets, batch_data, train=False
        )
        # Compute validation loss
        loss = model.validate_iter(
            spn,
            all_unprocessed_data,
            all_nn_outputs,
            all_outputs_for_spn,
            all_buckets,
            *data_pack,
        )
        current_loss = loss.item()
        test_loss += current_loss
        num_batches += 1
        if batch_idx % cfg.log_interval == 0:
            logger.info(
                f"Validation Epoch: [{batch_idx * len(data_pack[0])}/{len(test_loader.dataset)} "
                f"({100.0 * batch_idx / len(test_loader):.0f}%)]\t Validations Loss: {current_loss:.6f}"
            )

    if num_batches == 0:
        raise ValueError(
            "test_loader yielded no batches; cannot compute an average validation loss"
        )
    # Calculate average test loss
    test_loss /= num_batches
    logger.info(f"\nTest set: Average loss: {test_loss:.4f}")
    wandb.log({"validation_loss": test_loss})
    # Early stopping check and best loss update
    if test_loss < best_loss - min_delta:
        best_loss = test_loss  # Update best loss if improved
        counter = 0  # Reset counter on improvement
    else:
        counter += 1  # Increment counter on no improvement
    return (
        best_loss,
        test_loss,
        counter,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_spn,
        all_buckets,
    )


def prepare_data_pack_standard(data_packs):
    """
    Insert a new dimension at position 1 into every tensor of the pack.

    :param data_packs: Iterable of tensors and/or None entries.
    :return: List with each tensor unsqueezed at dim 1; None entries are kept as None.
    """
    expanded = []
    for each in data_packs:
        expanded.append(None if each is None else each.unsqueeze(1))
    return expanded


def prepare_data_pack_with_duplication(cfg, data_packs, idx):
    """
    Pick example `idx` from every tensor of the pack and repeat it into a batch.

    :param cfg: Configuration object; `cfg.test_batch_size` determines the number of copies.
    :param data_packs: Iterable of tensors and/or None entries.
    :param idx: Index of the example to duplicate.
    :return: List where each tensor entry is the selected row repeated `batch_size` times
        along a new leading dimension; None entries are kept as None.
    """
    requested = cfg.test_batch_size
    # NOTE(review): a requested size of 512 or more falls back to 128 copies - confirm intended.
    batch_size = requested if requested < 512 else 128

    duplicated = []
    for each in data_packs:
        if each is None:
            duplicated.append(None)
            continue
        row = each[idx].unsqueeze(0)
        duplicated.append(row.repeat(batch_size, 1))
    return duplicated


def update_loss_history(loss_history, loss, convergence_threshold):
    """
    Append `loss` to the rolling history and report whether the recent losses are flat.

    :param loss_history: List of recent losses; modified in place and capped at 5 entries
        once it has grown beyond that.
    :param loss: Newest loss value.
    :param convergence_threshold: Maximum absolute distance from the oldest kept loss.
    :return: False while 5 or fewer values have been collected; afterwards True only if every
        kept loss is strictly within `convergence_threshold` of the oldest kept loss.
    """
    loss_history.append(loss)
    if len(loss_history) <= 5:
        return False

    # Drop the oldest value so the window keeps a constant size
    loss_history.pop(0)
    reference = loss_history[0]
    for past_loss in loss_history:
        if not abs(past_loss - reference) < convergence_threshold:
            return False
    return True


def has_converged(
    loss_history,
    convergence_iter,
    loss,
    convergence_threshold,
    early_stopping_patience=5,
):
    """
    Record `loss` and update the count of consecutive "flat" checks.

    :param loss_history: Rolling loss history, updated through `update_loss_history`.
    :param convergence_iter: Number of consecutive flat checks so far.
    :param loss: Newest loss value.
    :param convergence_threshold: Threshold forwarded to `update_loss_history`.
    :param early_stopping_patience: Number of consecutive flat checks that counts as converged.
    :return: Tuple (updated counter, converged flag); the counter resets to 0 as soon as
        the history is not flat.
    """
    window_is_flat = update_loss_history(loss_history, loss, convergence_threshold)
    if not window_is_flat:
        return 0, False

    streak = convergence_iter + 1
    return streak, streak >= early_stopping_patience


def perturb_weights(model, perturbation_level=0.05):
    """
    Perturbs the weights of the model slightly to help escape local minima.

    Gaussian noise scaled by `perturbation_level` is added in place to every parameter.

    :param model: Object exposing `parameters()` that yields tensors.
    :param perturbation_level: Scale applied to the standard-normal noise.
    :return: None.
    """
    with torch.no_grad():
        for param in model.parameters():
            # randn_like matches the parameter's device and dtype; noise from
            # torch.randn(size) is created on the default device and fails for
            # parameters that live elsewhere.
            noise = torch.randn_like(param) * perturbation_level
            param.add_(noise)


def train_and_validate_single_example(
    cfg,
    model,
    spn,
    fabric,
    loader,
    optimizer,
    best_loss,
    counter,
    library_spn,
    device,
    num_data_features,
    num_spn_feature,
    num_outputs,
    num_query_variables,
):
    """
    Fine-tune a model separately on every example of `loader.dataset` and validate it on that example.

    The state of the `model` passed in is snapshotted and loaded into a freshly created
    model before each example. The `model` and `optimizer` arguments are then replaced by
    the objects returned from `init_model_and_optimizer`.

    Side effect: `cfg.same_bucket_iter` is forced to False when it was set.

    :param cfg: Configuration object (reads `same_bucket_iter`, `duplicate_example_train_on_test`,
        `only_test_train_on_test_set`, `num_init_train_on_test`, `num_iter_train_on_test`,
        `dual_network`, `train_on_test_set_scheduler`, `early_stopping_patience`, `debug`).
    :param model: Model whose state dict is used as the starting point for every example.
    :param spn: Forwarded to `pre_process_data` and the model iteration methods.
    :param fabric: Object whose `backward(loss)` runs the backward pass.
    :param loader: Data loader; `loader.dataset[:]` must return the whole dataset as one batch
        and `loader.dataset.num_var_in_buckets` must exist.
    :param optimizer: Replaced by the optimizer returned from `init_model_and_optimizer`.
    :param best_loss: Returned unchanged.
    :param counter: Returned unchanged.
    :param library_spn: Forwarded to `init_model_and_optimizer`.
    :param device: Forwarded to `init_model_and_optimizer`.
    :param num_data_features: Forwarded to `init_model_and_optimizer`.
    :param num_spn_feature: Forwarded to `init_model_and_optimizer`.
    :param num_outputs: Forwarded to `init_model_and_optimizer`.
    :param num_query_variables: Forwarded to `init_model_and_optimizer`.
    :return: Tuple (best_loss, test_loss, counter, all_unprocessed_data, all_nn_outputs,
        all_outputs_for_spn, all_buckets) where `test_loss` is the mean validation loss
        over all examples.
    :raises ValueError: If the dataset contains no examples.
    """
    all_unprocessed_data, all_nn_outputs, all_outputs_for_spn = [], [], []
    all_buckets = {"evid": [], "query": [], "unobs": []}
    test_loss = 0
    # Save the initial model state outside the loop to avoid repeated deep copying
    model_state = copy.deepcopy(model.state_dict())
    num_var_in_buckets = loader.dataset.num_var_in_buckets
    dataset = loader.dataset[:]
    if cfg.same_bucket_iter:
        cfg.same_bucket_iter = False
        logger.info("Same bucket iter is not allowed for train and validate")
    data_packs = pre_process_data(cfg, model, spn, num_var_in_buckets, dataset)
    if not cfg.duplicate_example_train_on_test:
        data_packs = prepare_data_pack_standard(data_packs)

    if not cfg.only_test_train_on_test_set:
        assert (
            cfg.num_init_train_on_test == 1
        ), "When a model is trained, we should not take multiple initializations of the NN"
    models_and_optimizers = init_model_and_optimizer(
        cfg,
        library_spn,
        device,
        fabric,
        num_data_features,
        num_spn_feature,
        num_outputs,
        num_query_variables=num_query_variables,
        run_type="test",
    )
    if cfg.dual_network:
        # student model is used for testing
        model, optimizer, _, _ = models_and_optimizers
    else:
        model, optimizer = models_and_optimizers

    convergence_threshold = 1e-3
    use_step_lr = cfg.train_on_test_set_scheduler == "StepLR"

    num_examples = len(data_packs[0])
    if num_examples == 0:
        raise ValueError("The dataset contains no examples to train and validate on")

    # go over each example in the dataset
    for idx in tqdm(range(num_examples)):
        # NOTE(review): a new scheduler is created per example, but the optimizer itself
        # (its state and current learning rate) is not re-created - confirm this is intended.
        if use_step_lr:
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                # StepLR needs a positive step size, even for fewer than 5 iterations
                step_size=max(1, cfg.num_iter_train_on_test // 5),
                gamma=0.8,
            )
        model.load_state_dict(model_state)
        model.train()
        if cfg.duplicate_example_train_on_test:
            # make every tensor 2d
            data_pack = prepare_data_pack_with_duplication(cfg, data_packs, idx)
        else:
            data_pack = [each[idx] if each is not None else None for each in data_packs]
        for _ in range(cfg.num_init_train_on_test):
            # Convergence tracking starts fresh for every initialization so that losses
            # from a previous run cannot trigger an early stop of this one.
            loss_history = []
            convergence_iter = 0
            if cfg.only_test_train_on_test_set:
                # Initialize the weights of the model in for each new initialization
                model.initialize_weights()
            # train the model on the test set for a few iterations
            for train_iter in range(cfg.num_iter_train_on_test):
                optimizer.zero_grad()
                loss = model.train_iter(spn, *data_pack)
                fabric.backward(loss)
                optimizer.step()
                if use_step_lr:
                    scheduler.step()
                convergence_iter, converged = has_converged(
                    loss_history,
                    convergence_iter,
                    loss.item(),
                    convergence_threshold,
                    cfg.early_stopping_patience,
                )
                if converged:
                    print("Convergence detected, stopping training")
                    break
                if cfg.debug and train_iter == cfg.num_iter_train_on_test - 1:
                    logger.info(f"Example {idx}, Iter {train_iter}, Loss {loss.item()}")
        with torch.no_grad():
            model.eval()
            loss = model.validate_iter(
                spn,
                all_unprocessed_data,
                all_nn_outputs,
                all_outputs_for_spn,
                all_buckets,
                *data_pack,
            )
            test_loss += loss.item()
    # Average test loss calculation
    test_loss /= num_examples
    logger.info(f"\nTest set: Average loss: {test_loss:.4f}")
    wandb.log({"test_loss": test_loss})
    return (
        best_loss,
        test_loss,
        counter,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_spn,
        all_buckets,
    )


# def train_and_validate_batch(
#     cfg,
#     model,
#     spn,
#     fabric,
#     loader,
#     optimizer,
#     best_loss,
#     counter,

# ):
#     """
#     The `train_and_validate` function is used to train and validate a model on a given dataset. It
#     returns the best loss, along with other outputs that can be used for further analysis.
#     """
#     all_unprocessed_data = []
#     all_nn_outputs = []
#     all_outputs_for_spn = []
#     all_buckets = {"evid": [], "query": [], "unobs": []}
#     test_loss = 0
#     # Save the initial model state outside the loop to avoid repeated deep copying
#     model_state = copy.deepcopy(model.state_dict())
#     assert (
#         cfg.same_bucket_iter is False
#     ), "Same bucket iter is not allowed for train and validate"
#     num_var_in_buckets = loader.dataset.num_var_in_buckets
#     for batch_idx, (batch_data) in enumerate(tqdm(loader)):
#         # Use the same bucket for all the data points in the batch and copy and stack the bucket for each data point
#         optimizer = select_optimizer(model, cfg.test_optimizer, cfg.test_lr, cfg.test_weight_decay)
#         init_bucket = {key: batch_data[key] for key in ["evid", "query", "unobs"]}
#         this_batch_bucket = {
#             key: init_bucket[key][0]
#             .clone()
#             .unsqueeze(0)
#             .repeat(batch_data["data"].shape[0], 1)
#             for key in init_bucket
#         }
#         for key in ["evid", "query", "unobs"]:
#             batch_data[key] = this_batch_bucket[key]
#         data_pack = pre_process_data(cfg, model, spn, num_var_in_buckets, batch_data)
#         model.load_state_dict(model_state)
#         model.train()
#         # make every tensor 2d
#         loss_history = []
#         convergence_threshold = 1e-4
#         convergence_iter = 0
#         # train the model on the test set for a few iterations
#         for iter in range(cfg.num_iter_train_on_test):
#             # train the model on the test set for a few iterations
#             optimizer.zero_grad()
#             loss = model.train_iter(spn, *data_pack)
#             fabric.backward(loss)
#             optimizer.step()
#             loss_history.append(loss.item())
#             if len(loss_history) > 20:
#                 if all(
#                     abs(loss - loss_history[-20]) < convergence_threshold
#                     for loss in loss_history[-20:]
#                 ):
#                     convergence_iter += 1
#                 else:
#                     convergence_iter = 0

#                 if convergence_iter >= 20:
#                     break
#         model.eval()
#         with torch.no_grad():
#             loss = model.validate_iter(
#                 spn,
#                 all_unprocessed_data,
#                 all_nn_outputs,
#                 all_outputs_for_spn,
#                 all_buckets,
#                 *data_pack,
#             )
#             test_loss += loss.item()
#     # Average test loss calculation
#     test_loss /= batch_idx + 1
#     logger.info(f"\nTest set: Average loss: {test_loss:.4f}")
#     wandb.log({"test_loss": test_loss})
#     # All output for SPN is the output of the NN after adding the values of the evidence - which can be used to calculate log liklihood of the spn
#     return (
#         best_loss,
#         test_loss,
#         counter,
#         all_unprocessed_data,
#         all_nn_outputs,
#         all_outputs_for_spn,
#         all_buckets,
#     )


def threshold_array(arr, threshold, value_less, value_more):
    """
    Binarise an array around a threshold.

    :param arr: Array to compare against the threshold.
    :param threshold: Cut-off value; elements equal to it count as "less".
    :param value_less: Replacement for elements that are less than or equal to the threshold.
    :param value_more: Replacement for all other elements.
    :return: A new array holding `value_less` where `arr <= threshold` and `value_more` elsewhere.
    """
    at_or_below = arr <= threshold
    return np.where(at_or_below, value_less, value_more)


def remove_nan_rows(arr):
    """
    Drop every row of a 2-D array that contains at least one NaN.

    :param arr: The array to filter; it is reduced along axis 1 and indexed with a boolean row mask
    :return: the rows of `arr` in which no entry is NaN, in their original order.
    """
    row_has_nan = np.isnan(arr).any(axis=1)
    rows_to_keep = np.logical_not(row_has_nan)
    return arr[rows_to_keep]


def evaluate_nn(cfg, spn, pytorch_spn, all_outputs_for_spn):
    """
    Score binarized network outputs with both the library SPN and the PyTorch SPN.

    The outputs are first binarized with `threshold_array`: entries `<= cfg.threshold` become 0 and
    all other entries become 1. The binarized assignment is then scored twice.

    :param cfg: Configuration object; the attributes read here are `threshold`, `pgm` and `device`
    :param spn: The library SPN passed to `log_likelihood`; only used when `cfg.pgm == "spn"`
    :param pytorch_spn: Object exposing `evaluate(tensor)`; it receives the binarized outputs as a
    float64 tensor on `cfg.device`
    :param all_outputs_for_spn: Array-like of outputs, one row per example (converted with
    `np.array`)
    :return: a tuple `(root_ll, root_ll_our_spn)`. `root_ll` is the result of the library
    `log_likelihood` when `cfg.pgm == "spn"`, otherwise an array of zeros with one entry per row.
    `root_ll_our_spn` is the result of `pytorch_spn.evaluate` as a numpy array.
    """
    binarized = threshold_array(np.array(all_outputs_for_spn), cfg.threshold, 0, 1)

    if cfg.pgm == "spn":
        root_ll = log_likelihood(
            spn,
            binarized,
            return_results=False,
        )
    else:
        # No library model is available for the other PGM types; return a placeholder.
        root_ll = np.zeros(binarized.shape[0])

    spn_input = torch.FloatTensor(binarized).to(cfg.device).to(torch.float64)
    # The result is detached right away, so there is no point in recording an autograd graph
    # for this pure evaluation pass.
    with torch.no_grad():
        root_ll_our_spn = pytorch_spn.evaluate(spn_input)
    root_ll_our_spn = root_ll_our_spn.detach().cpu().numpy()

    return root_ll, root_ll_our_spn


def get_ll_scores(
    cfg,
    root_spn,
    torch_spn,
    all_unprocessed_data,
    all_outputs_for_spn,
    all_buckets,
    mpe_output,
    root_ll_spn,
    device,
):
    """
    Compute log-likelihood scores for the network outputs and for the MPE baseline.

    When `cfg.pgm == "spn"` and no usable MPE solution is supplied (`mpe_output` is None or all
    zeros), the MPE assignment and its log-likelihood are computed with the library SPN. In every
    other case the supplied `mpe_output` / `root_ll_spn` are used as-is.

    :param cfg: Configuration object; the attributes read here are `pgm`, `device`, `debug` and,
    for `made`/`mn`/`bn`, `num_test_examples` (plus whatever `evaluate_nn` reads)
    :param root_spn: The library SPN used for `mpe` and `log_likelihood`
    :param torch_spn: Object exposing `evaluate(tensor)`, used to score the MPE assignment and the
    network outputs
    :param all_unprocessed_data: Array-like with the original data, one row per example; evidence
    values are copied from it when the MPE input is built
    :param all_outputs_for_spn: Array-like with the network outputs, one row per example
    :param all_buckets: Mapping with the keys `"query"`, `"evid"` and `"unobs"`; each value marks,
    per example, which variables belong to that bucket
    :param mpe_output: Precomputed MPE assignment, or None / all zeros to request computation
    (computation is only possible when `cfg.pgm == "spn"`)
    :param root_ll_spn: Precomputed log-likelihood of `mpe_output`; used when the MPE solution is
    not recomputed
    :param device: Unused; tensors are placed on `cfg.device`. Kept for interface compatibility
    :raises ValueError: if the MPE solution would have to be computed while unobserved variables
    are present, or if no MPE solution is supplied for a non-SPN model
    :return: a tuple `(root_ll_our_nn, root_ll_nn, root_ll_spn, mpe_output)`.
    """
    # For MMAP each example would need its own marginalized SPN, so no marginalization is
    # performed here and saved outputs are expected instead.
    has_unobserved = np.array(all_buckets["unobs"][0]).sum() != 0
    # Always bind the SPN handed to the calls below. Previously this name was left undefined
    # whenever unobserved variables were present, which raised UnboundLocalError later on.
    root_spn_marginalized = root_spn

    needs_mpe = mpe_output is None or np.all(mpe_output == 0)
    if needs_mpe and cfg.pgm == "spn":
        if has_unobserved:
            # Unobserved variables would be fed to the SPN as the placeholder -1 because the
            # per-example marginalization is not implemented; fail with a clear message.
            raise ValueError(
                "Cannot compute the MPE solution when unobserved variables are present; "
                "pass precomputed mpe_output and root_ll_spn instead"
            )
        logger.info("Calculating MPE")
        array_for_spn = np.full_like(all_outputs_for_spn, -1)
        # Force boolean masks: with 0/1 integer buckets numpy would otherwise treat the values
        # as row indices instead of a per-element selection.
        query_bucket = np.asarray(all_buckets["query"], dtype=bool)
        evid_bucket = np.asarray(all_buckets["evid"], dtype=bool)
        # Query variables are marked as NaN for the library call.
        array_for_spn[query_bucket] = np.nan
        all_unprocessed_data = np.array(all_unprocessed_data)
        array_for_spn[evid_bucket] = all_unprocessed_data[evid_bucket]
        logger.info(f"Array for spn {array_for_spn.shape}")
        mpe_output = mpe(root_spn_marginalized, array_for_spn)
        root_ll_spn = log_likelihood(
            root_spn_marginalized,
            mpe_output,
            return_results=False,
        )
    else:
        if mpe_output is None:
            raise ValueError(
                "mpe_output must be provided when cfg.pgm is not 'spn'; "
                "the MPE solution can be computed only for SPNs"
            )
        logger.info("Using precomputed MPE")

    if cfg.pgm in ["made", "mn", "bn"]:
        mpe_output = mpe_output[: cfg.num_test_examples]
        root_ll_spn = root_ll_spn[: cfg.num_test_examples]
    if cfg.pgm in ["mn", "bn"]:
        # NOTE(review): this rescaling is overwritten further down for "mn"/"bn" by the
        # torch_spn score; kept to preserve the existing behavior - confirm which is intended.
        root_ll_spn = root_ll_spn * log(10)

    root_ll_nn, root_ll_our_nn = evaluate_nn(
        cfg, root_spn_marginalized, torch_spn, all_outputs_for_spn=all_outputs_for_spn
    )

    mpe_output = torch.FloatTensor(mpe_output).to(cfg.device).to(torch.float64)
    with torch.no_grad():
        root_ll_our_spn = torch_spn.evaluate(mpe_output)
    root_ll_our_spn = root_ll_our_spn.detach().cpu().numpy()
    if cfg.pgm in ["mn", "bn"]:
        root_ll_spn = root_ll_our_spn

    logger.info(f"Root ll SPN our implementation {np.mean(root_ll_our_spn)}")
    logger.info(f"Root ll SPN {np.mean(root_ll_spn)}")
    logger.info(f"Root ll NN {np.mean(root_ll_our_nn)}")
    logger.info(f"Root ll nn {root_ll_nn.shape}")
    logger.info(f"Root ll spn {root_ll_spn.shape}")
    mpe_output = mpe_output.detach().cpu().numpy()

    if cfg.debug:
        # Best-effort side-by-side dump of the scores; a failure here (e.g. mismatched shapes)
        # must not abort the evaluation, but it should not be silently hidden either.
        try:
            joint_scores = np.vstack((root_ll_spn, root_ll_our_nn))
        except Exception as err:
            logger.debug(f"Could not stack ll scores for debug output: {err}")
        else:
            print(joint_scores)
    return root_ll_our_nn, root_ll_nn, root_ll_spn, mpe_output
