import numpy as np
from deeprob.spn.algorithms.inference import log_likelihood, mpe
from deeprob.spn.algorithms.structure import marginalize
from loguru import logger
from methods import (mle_spn, mle_spn_torch, mpe_spn, random_initialization,
                     sequential_spn, sequential_spn_torch)
from tqdm import tqdm


def _initial_assignment(
    args,
    data,
    prev_outputs,
    query_var_bool,
    unobs_var_bool,
    evid_var_bool,
    library_spn,
    torch_spn,
):
    """Build the starting assignment selected by ``args.initialization``."""
    strategy = args.initialization
    if strategy == "random":
        return random_initialization(
            data,
            query_var_bool,
            unobs_var_bool,
            np.sum(query_var_bool[0]),
            np.sum(unobs_var_bool[0]),
        )
    if strategy == "spn_approx":
        # Reuse a previously computed MPE assignment when the caller has one.
        if "mpe_output" in prev_outputs:
            mpe_output = prev_outputs["mpe_output"]
        else:
            mpe_output = mpe_spn(
                prev_outputs, query_var_bool, evid_var_bool, library_spn
            )
        initial_data_points = data.copy()
        initial_data_points[query_var_bool] = mpe_output[query_var_bool]
        return initial_data_points
    if strategy == "nn_approx":
        # Threshold the values in ``data`` to hard 0/1 assignments.
        initial_data_points = data.copy()
        initial_data_points[initial_data_points >= 0.5] = 1
        initial_data_points[initial_data_points < 0.5] = 0
        return initial_data_points
    if strategy == "mle":
        return mle_spn_torch(prev_outputs, query_var_bool, evid_var_bool, torch_spn)
    if strategy == "sequential":
        return sequential_spn_torch(
            prev_outputs, query_var_bool, evid_var_bool, torch_spn
        )
    raise NotImplementedError(f"Unknown initialization strategy: {strategy!r}")


def hill_climbing_search(
    args,
    data,
    prev_outputs,
    query_var_bool,
    unobs_var_bool,
    evid_var_bool,
    library_spn,
    torch_spn,
    buckets,
):
    """Refine an initial assignment by repeated single-variable flips.

    An initial assignment is built according to ``args.initialization`` and
    scored with ``log_likelihood``. Unless ``args.no_local_search`` is set, up
    to ``args.num_samples`` passes of ``sample_one_iter_vectorized`` are then
    applied, each pass starting from the result of the previous one.

    Args:
        args: Options object; reads ``initialization`` (one of ``"random"``,
            ``"spn_approx"``, ``"nn_approx"``, ``"mle"``, ``"sequential"``),
            ``no_local_search`` and ``num_samples``.
        data: Array that is copied/thresholded to form the starting point for
            some strategies (assumed to be examples x variables -- TODO
            confirm against the caller).
        prev_outputs: Mapping forwarded to the initialization helpers; an
            ``"mpe_output"`` entry is reused by the ``"spn_approx"`` strategy.
        query_var_bool: Per-example boolean mask of the query variables.
        unobs_var_bool: Per-example boolean mask of unobserved variables;
            must be all-False.
        evid_var_bool: Per-example boolean mask forwarded to the
            initialization helpers.
        library_spn: SPN passed to ``log_likelihood`` and ``mpe_spn``.
        torch_spn: SPN passed to the torch-based initialization helpers.
        buckets: Unused by this function; kept for interface compatibility.

    Returns:
        A 5-tuple ``(initial_data_points, initial_ll_scores, output_samples,
        output_ll_scores, library_spn)``. When ``args.no_local_search`` is
        set, the output pair is the initial pair itself.

    Raises:
        NotImplementedError: If any unobserved variable is present
            (marginalization is not implemented) or if
            ``args.initialization`` is not a known strategy.
        AssertionError: If the initial assignment contains NaN values.
    """
    if np.asarray(unobs_var_bool).sum() != 0:
        # Marginalizing unobserved variables out of the SPN is not supported,
        # so fail loudly instead of continuing without a usable SPN.
        raise NotImplementedError("Unobserved variables are not allowed for MPE")
    library_spn_marginalized = library_spn

    initial_data_points = _initial_assignment(
        args,
        data,
        prev_outputs,
        query_var_bool,
        unobs_var_bool,
        evid_var_bool,
        library_spn_marginalized,
        torch_spn,
    )
    assert not np.isnan(
        initial_data_points
    ).any(), "Initial data points contain NaN values"
    initial_ll_scores = log_likelihood(
        library_spn_marginalized, initial_data_points, return_results=False
    )

    if args.no_local_search:
        return (
            initial_data_points,
            initial_ll_scores,
            initial_data_points,
            initial_ll_scores,
            library_spn_marginalized,
        )

    # The running best state is created once so that every pass continues from
    # the result of the previous pass (and so it exists when num_samples == 0).
    best_sample = initial_data_points.copy()
    best_ll_score = initial_ll_scores.copy()
    for pass_idx in tqdm(range(args.num_samples)):
        previous_ll_score = best_ll_score.copy()
        # Mutates best_sample / best_ll_score in place.
        sample_one_iter_vectorized(
            library_spn_marginalized, query_var_bool, best_sample, best_ll_score
        )
        # Flips are only accepted on a strict score increase, so unchanged
        # scores mean no flip was accepted. Assuming log_likelihood is
        # deterministic, further passes would repeat the same no-op.
        if np.array_equal(previous_ll_score, best_ll_score):
            logger.info(f"Local search converged after {pass_idx + 1} passes")
            break

    output_samples = np.array(best_sample)
    output_ll_scores = np.array(best_ll_score)

    return (
        initial_data_points,
        initial_ll_scores,
        output_samples,
        output_ll_scores,
        library_spn_marginalized,
    )


def sample_one_iter(library_spn, query_var_bool, all_best_sample, all_best_ll_score):
    """Run one greedy flip pass over every example, one example at a time.

    For each row, every column marked in that row of ``query_var_bool`` is
    flipped in turn (``x -> 1 - x``) on a copy of the row's current best
    assignment. The flipped row is scored with ``log_likelihood`` and adopted
    as the new current best when its score is strictly higher, so later flips
    in the same row build on earlier accepted ones.

    ``all_best_sample`` and ``all_best_ll_score`` are updated in place; nothing
    is returned.
    """
    row_count = all_best_sample.shape[0]
    for row in tqdm(range(row_count)):
        current = all_best_sample[row]
        current_ll = all_best_ll_score[row]
        # Column indices of this row's query variables.
        candidate_columns = np.nonzero(query_var_bool[row])[0]
        for column in candidate_columns:
            # Work on a copy shaped as a single-row batch for scoring.
            candidate = current.copy().reshape(1, -1)
            candidate[:, column] = 1 - candidate[:, column]
            candidate_ll = log_likelihood(
                library_spn, candidate, return_results=False
            )
            if candidate_ll > current_ll:
                current, current_ll = candidate, candidate_ll
        # Write the row's final state back into the shared arrays.
        all_best_sample[row] = current
        all_best_ll_score[row] = current_ll


def sample_one_iter_vectorized(
    library_spn, query_var_bool, all_best_sample, all_best_ll_score
):
    """Run one greedy flip pass over all examples, batched by flip position.

    At flip position ``k`` the ``k``-th query variable of every example (in
    ascending column order) is flipped (``x -> 1 - x``) on a copy of the
    current best batch. The batch is scored with ``log_likelihood`` and each
    example whose score strictly increased adopts its flipped row. Accepted
    flips carry over to later flip positions within the same pass.

    Examples may have different numbers of query variables; an example with
    fewer than ``k + 1`` of them is left untouched at position ``k``. An empty
    batch or a mask with no query variables is a no-op.

    Args:
        library_spn: SPN passed to ``log_likelihood``.
        query_var_bool: 2-D mask with one row per example; nonzero entries
            mark query variables.
        all_best_sample: Current best assignments, updated in place.
        all_best_ll_score: Current best scores, updated in place.

    Returns:
        None. Results are written into ``all_best_sample`` and
        ``all_best_ll_score``.
    """
    query_mask = np.asarray(query_var_bool, dtype=bool)
    num_examples = query_mask.shape[0]
    if num_examples == 0:
        return

    query_counts = query_mask.sum(axis=1)
    # A stable sort on the negated mask lists each row's query columns first,
    # in ascending column order, followed by the non-query columns.
    query_order = np.argsort(~query_mask, axis=1, kind="stable")
    row_indices = np.arange(num_examples)

    for flip_idx in range(int(query_counts.max())):
        # Rows that still have a query variable at this flip position.
        active_rows = row_indices[query_counts > flip_idx]
        flip_columns = query_order[active_rows, flip_idx]

        flipped_data_points = all_best_sample.copy()
        flipped_data_points[active_rows, flip_columns] = (
            1 - flipped_data_points[active_rows, flip_columns]
        )

        flipped_ll_scores = log_likelihood(
            library_spn, flipped_data_points, return_results=False
        )

        better_indices = np.where(flipped_ll_scores > all_best_ll_score)[0]
        # Only rows that were actually flipped may be updated.
        better_indices = better_indices[query_counts[better_indices] > flip_idx]
        all_best_sample[better_indices] = flipped_data_points[better_indices]
        all_best_ll_score[better_indices] = flipped_ll_scores[better_indices]
