import copy
import random

import torch
from loguru import logger
from torch.distributions import Categorical
from torch.distributions.categorical import Categorical
from torch.utils.data import Dataset


def create_buckets_one_query_per(n, num_in_buckets, device="cuda"):
    """
    Randomly split the indices ``0 .. n-1`` into evidence/query/unobserved masks.

    A single random permutation of the indices is drawn and consumed from the
    front: the first ``num_in_buckets[0]`` entries go to ``"evid"``, the next
    ``num_in_buckets[1]`` to ``"query"`` and the next ``num_in_buckets[2]`` to
    ``"unobs"``. Indices left over after the three slices belong to no bucket.

    Args:
        n (int): Number of variables to distribute.
        num_in_buckets (list): Sizes of the buckets, in the order
            ``[evid, query, unobs]``. A size of 0 leaves that mask all False.
        device: Device on which the boolean masks are allocated.

    Returns:
        dict: ``{"evid": mask, "query": mask, "unobs": mask}`` where every mask
        is a boolean tensor of length ``n``.

    Example:
        >>> buckets = create_buckets_one_query_per(10, [4, 1, 3], device="cpu")
        >>> [int(buckets[name].sum()) for name in ("evid", "query", "unobs")]
        [4, 1, 3]
    """
    bucket_names = ("evid", "query", "unobs")

    # One shuffle of all indices; consecutive slices of it are disjoint,
    # so no index can end up in two buckets.
    shuffled = torch.randperm(n)

    masks = {
        name: torch.zeros(n, dtype=torch.bool, device=device)
        for name in bucket_names
    }

    offset = 0
    for name, count in zip(bucket_names, num_in_buckets):
        if count != 0:
            upper = offset + count
            masks[name][shuffled[offset:upper]] = True
            offset = upper

    return masks


def get_bucket_from_unique_set(unique_buckets, num_unique_buckets, task, device="cuda"):
    """
    Pick one bucket assignment at random from a precomputed collection.

    Args:
        unique_buckets: Mapping with ``"evid"``, ``"query"`` and ``"unobs"``
            entries; each entry is indexed by row and every row is converted
            with ``torch.from_numpy`` (so rows are expected to be NumPy arrays).
        num_unique_buckets (int): Number of rows to choose from.
        task (str): ``"mpe"`` always takes row 0 of ``"unobs"``; ``"mmap"``
            takes the same random row as evidence/query. For any other value
            no ``"unobs"`` entry is returned.
        device: Device the resulting boolean masks are moved to.

    Returns:
        dict: Boolean tensors for ``"evid"``, ``"query"`` and, depending on
        ``task``, ``"unobs"``.
    """

    def _as_mask(row):
        return torch.from_numpy(row).bool().to(device)

    chosen_row = random.randint(0, num_unique_buckets - 1)
    selected = {
        name: _as_mask(unique_buckets[name][chosen_row])
        for name in ("evid", "query")
    }

    if task == "mpe":
        selected["unobs"] = _as_mask(unique_buckets["unobs"][0])
    elif task == "mmap":
        selected["unobs"] = _as_mask(unique_buckets["unobs"][chosen_row])

    return selected


def create_buckets_all_queries_per(n, task="mpe", device="cuda"):
    """
    Draw random bucket sizes for ``n`` variables and build the bucket masks.

    The query bucket receives ``2 * floor`` variables plus a random extra
    amount, where ``floor = max(1, n // 10)``; the remainder is evidence. For
    any task other than ``"mpe"`` the evidence share is halved between
    evidence and unobserved, and whatever is lost to integer division is
    handed back to evidence so the three sizes add up to ``n``.

    Args:
        n (int): Total number of variables.
        task (str): ``"mpe"`` produces an empty unobserved bucket.
        device: Device for the resulting masks.

    Returns:
        dict: The masks produced by ``create_buckets_one_query_per`` for the
        sizes ``[evid, query, unobs]``.
    """
    floor = max(1, n // 10)

    # The query bucket grows by a random share of what is left once three
    # "floor"-sized portions have been set aside.
    extra = random.randint(0, n - 3 * floor)
    query_count = 2 * floor + extra
    evid_count = n - query_count

    if task == "mpe":
        unobs_count = 0
    else:
        evid_count = unobs_count = evid_count // 2

    # Return any variable dropped by the halving above to the evidence bucket.
    evid_count += n - (query_count + evid_count + unobs_count)

    return create_buckets_one_query_per(
        n, [evid_count, query_count, unobs_count], device=device
    )


def create_buckets_mnist(n, evidence_prob, horizontal=True):
    """
    Split a flattened square image into an evidence part and a query part.

    The ``n`` pixels are treated as a ``width x width`` grid with
    ``width = int(n ** 0.5)``. The first ``int(width * evidence_prob)`` rows
    (``horizontal=True``) or columns (``horizontal=False``) are evidence and
    the rest of the grid is query.

    Roles are assigned directly from the position of the dividing line, so
    degenerate splits are labelled correctly: ``evidence_prob=0`` yields an
    all-query image and ``evidence_prob=1`` an all-evidence image, and all
    three keys are always present in the result.

    :param n: Total number of pixels. Pixels beyond ``width * width`` (when
        ``n`` is not a perfect square) are placed in no bucket.
    :param evidence_prob: Fraction of rows/columns used as evidence. The
        resulting line count is clamped to ``[0, width]``.
    :param horizontal: If True the dividing line is horizontal (evidence is the
        top rows); otherwise it is vertical (evidence is the left columns).
    :return: Dictionary with the keys "evid", "query" and "unobs", each a
        boolean tensor of length ``n``. "unobs" is always all False.
    """
    width = int(n**0.5)
    grid_size = width * width
    num_evid_lines = max(0, min(width, int(width * evidence_prob)))

    # True marks evidence pixels on the 2D grid.
    evid_grid = torch.zeros((width, width), dtype=torch.bool)
    if horizontal:
        evid_grid[:num_evid_lines, :] = True
    else:
        evid_grid[:, :num_evid_lines] = True
    evid_flat = evid_grid.flatten()

    evid = torch.zeros(n, dtype=torch.bool)
    query = torch.zeros(n, dtype=torch.bool)
    evid[:grid_size] = evid_flat
    query[:grid_size] = ~evid_flat

    return {
        "evid": evid,
        "query": query,
        "unobs": torch.zeros(n, dtype=torch.bool),
    }


def process_bucket_for_nn_discrete(sample, buckets, embedding_dict):
    """
    Encode a binary sample as interleaved pairs for a neural-network input.

    Every variable ``v`` becomes two consecutive entries ``(v, 1 - v)``.
    Variables in the ``"query"`` bucket are then overwritten with ``(0, 0)``
    and variables in the ``"unobs"`` bucket with ``(1, 1)``.

    Handles both a single example (1D tensor) and a batch (2D tensor). A
    bucket that is absent from ``buckets`` is treated as empty, so callers
    that only provide ``"evid"``/``"query"`` masks are supported.

    Args:
        sample (torch.Tensor): 1D ``(n_vars,)`` or 2D ``(num_samples, n_vars)``.
        buckets (dict): Bucket masks. For a 1D sample each mask indexes the
            variables directly; for a 2D sample only row 0 of each mask is
            used and applied to every row of the batch.
        embedding_dict: Unused; kept so all processing functions share one
            signature.

    Returns:
        torch.Tensor: Tensor with the last dimension doubled, same dtype and
        device as ``sample``.

    Raises:
        ValueError: If ``sample`` is neither 1D nor 2D.
    """
    # (bucket name, value written to both entries of the pair)
    fill_rules = (("query", 0), ("unobs", 1))

    if sample.dim() == 1:
        n_vars = sample.size(0)
        final_sample = torch.zeros(n_vars * 2, dtype=sample.dtype, device=sample.device)

        final_sample[0::2] = sample.double()  # even slots hold v
        final_sample[1::2] = 1 - sample.double()  # odd slots hold 1 - v

        for key, value in fill_rules:
            if key not in buckets:
                continue
            mask = buckets[key]
            # The strided slices are views, so masked assignment writes
            # through to final_sample.
            final_sample[::2][mask] = final_sample[1::2][mask] = value

    elif sample.dim() == 2:
        num_samples, n_vars = sample.size()
        final_sample = torch.zeros(
            num_samples, n_vars * 2, dtype=sample.dtype, device=sample.device
        )

        final_sample[:, 0::2] = sample.double()
        final_sample[:, 1::2] = 1 - sample.double()

        for key, value in fill_rules:
            if key not in buckets:
                continue
            # Same mask (row 0) for all the samples in the batch.
            mask = buckets[key][0]
            final_sample[:, ::2][:, mask] = final_sample[:, 1::2][:, mask] = value

    else:
        raise ValueError(f"Invalid sample dimension: {sample.dim()}")
    return final_sample


def process_bucket_for_nn_continuousConst(sample, buckets, embedding_dict):
    """
    Replace every variable by a constant that encodes its bucket and value.

    The output starts filled with the ``"query"`` constant. Evidence variables
    are then set to ``"evid_0"`` or ``"evid_1"`` according to their value in
    ``sample``, and unobserved variables to ``"unobs"``. A missing ``"unobs"``
    bucket is treated as empty.

    Handles both a single example (1D tensor) and a batch (2D tensor); the
    masks in ``buckets`` are expected to have the same shape as ``sample``.

    Args:
        sample (torch.Tensor): 1D or 2D tensor of binary values.
        buckets (dict): Boolean masks; ``"evid"`` is required, ``"unobs"`` is
            optional.
        embedding_dict (dict): Single-element tensors under the keys
            ``"query"``, ``"evid_0"``, ``"evid_1"`` and ``"unobs"``.

    Returns:
        torch.Tensor: Tensor shaped like ``sample`` holding the constants.

    Raises:
        ValueError: If ``sample`` is neither 1D nor 2D.
    """
    if sample.dim() not in [1, 2]:
        raise ValueError("Sample tensor must be either 1D or 2D.")

    # Everything not overwritten below keeps the 'query' constant.
    final_sample = torch.full_like(
        sample, embedding_dict["query"].item(), device=sample.device
    )

    evid_mask = buckets["evid"]
    final_sample[evid_mask & (sample == 0)] = embedding_dict["evid_0"]
    final_sample[evid_mask & (sample == 1)] = embedding_dict["evid_1"]

    # 'unobs' may legitimately be absent when there are no unobserved variables.
    unobs_mask = buckets.get("unobs")
    if unobs_mask is not None:
        final_sample[unobs_mask] = embedding_dict["unobs"]

    return final_sample


def process_bucket_for_nn_continuous_embed(sample, buckets, embedding_dict):
    """
    Return ``sample`` unchanged.

    No preprocessing happens here: per the original author's note, the
    embedding is handled by the embedding layer inside the model. ``buckets``
    and ``embedding_dict`` are accepted only so this function shares the
    signature of the other processing functions.
    """
    return sample


def process_bucket_for_transformer(sample, buckets):
    """
    Hide the values of query and unobserved variables in a copy of ``sample``.

    Args:
        sample (torch.Tensor): Input tensor; it is cloned and left untouched.
        buckets (dict): Must contain ``"query"`` and ``"unobs"`` entries that
            are valid indices (e.g. boolean masks) into ``sample``.

    Returns:
        torch.Tensor: Copy of ``sample`` in which every position selected by
        ``buckets["query"]`` or ``buckets["unobs"]`` is set to -1.

    Example:
        >>> sample = torch.tensor([1., 0., 1., 0.])
        >>> buckets = {"query": torch.tensor([False, False, True, False]),
        ...            "unobs": torch.tensor([False, False, False, True])}
        >>> process_bucket_for_transformer(sample, buckets)
        tensor([ 1.,  0., -1., -1.])
    """
    hidden = sample.clone()

    # -1 marks a value as not observed. An attention mask can additionally be
    # used so the transformer ignores the unobserved positions.
    for bucket_name in ("query", "unobs"):
        hidden[buckets[bucket_name]] = -1

    return hidden


def create_attention_mask_for_transformer(buckets):
    """
    Build an additive attention mask that blocks unobserved variables.

    Args:
        buckets (dict): Must contain ``"evid"`` (used only for its shape and
            device). ``"unobs"`` is optional; when it is missing no position
            is blocked.

    Returns:
        torch.Tensor: Float tensor shaped like ``buckets["evid"]`` that is 0
        everywhere except ``-inf`` at the positions selected by
        ``buckets["unobs"]``.
    """
    attention_mask = torch.zeros_like(buckets["evid"], dtype=torch.float)

    # The value of an unobserved variable is unknown, so it must not be
    # attended to. A missing bucket simply means nothing is unobserved.
    unobs_mask = buckets.get("unobs")
    if unobs_mask is not None:
        attention_mask[unobs_mask] = float("-inf")

    return attention_mask


def create_continous_embedding_dict(device):
    """
    Build the constant embedding values used by the "continuousConst" encoding.

    Args:
        device: Device the single-element tensors are moved to.

    Returns:
        dict: Single-element tensors under the keys ``"unobs"`` (0.33),
        ``"query"`` (0.66), ``"evid_0"`` (0.0) and ``"evid_1"`` (1.0).
    """
    constants = (
        ("unobs", 0.33),
        ("query", 0.66),
        ("evid_0", 0.0),
        ("evid_1", 1.0),
    )
    return {name: torch.tensor([value]).to(device) for name, value in constants}


class MMAPBaseDataset(Dataset):
    """
    Shared machinery for the MMAP datasets.

    Subclasses must implement ``create_buckets`` and set the attribute
    ``use_sampled_buckets`` (plus ``unique_buckets`` / ``num_unique_buckets``
    when it is True); ``__getitem__`` reads all of them.
    """

    def __init__(
        self,
        data,
        spn=None,
        model_type="nn",
        data_device="cuda",
        input_type="data",
        embedding_type="discrete",
        extracted_features=None,
        num_layers=None,
    ):
        """
        Store the data as a double tensor on ``data_device`` and select the
        bucket-processing function for ``model_type`` / ``embedding_type``.

        ``extracted_features`` and ``num_layers`` are accepted but not used by
        this class.

        Raises:
            ValueError: If an SPN is required by ``input_type`` but missing, or
                if the model/embedding type is not supported.
        """
        self.data = data.double().to(data_device)
        self.validate_input_type(input_type, spn)

        self.spn = spn
        self.model_type = model_type
        self.input_type = input_type
        self.embedding_type = embedding_type

        # Only the constant-embedding encoding needs a lookup table.
        self.embedding_dict = (
            create_continous_embedding_dict(data_device)
            if embedding_type == "continuousConst"
            else None
        )
        self.process_bucket_function = self.determine_process_function(model_type)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    @staticmethod
    def validate_input_type(input_type, spn):
        """
        Raise ``ValueError`` when ``input_type`` needs an SPN but none is given.
        """
        needs_spn = input_type in ("spn", "dataSpn")
        if needs_spn and spn is None:
            raise ValueError("SPN must be provided if input_type is spn or data_spn")

    def determine_process_function(self, model_type):
        """
        Return the bucket-processing function for ``model_type``.

        For "nn" and "vae" the choice depends on ``self.embedding_type``; the
        transformer always uses the pass-through function.

        Raises:
            ValueError: For an unknown model type or embedding type.
        """
        if model_type == "transformer":
            # The transformer relies on its own learnable embedding layer.
            return process_bucket_for_nn_continuous_embed

        if model_type not in ("nn", "vae"):
            raise ValueError(f"Invalid model type: {model_type}")

        embedding = self.embedding_type
        if embedding in ("discrete", "continuousSin"):
            return process_bucket_for_nn_discrete
        if embedding == "continuousConst":
            return process_bucket_for_nn_continuousConst
        if embedding == "continuousEmbed":
            return process_bucket_for_nn_continuous_embed
        raise ValueError(f"Invalid embedding type: {self.embedding_type}")

    def create_buckets(self, sample, index):
        """
        Return the bucket masks for ``sample``; must be provided by subclasses.
        """
        raise NotImplementedError("This method should be implemented by subclasses")

    def __getitem__(self, index):
        """
        Return a dict with the raw sample, its processed form, the attention
        mask and the bucket masks.

        With ``use_sampled_buckets`` the sample is repeated once per entry of
        ``unique_buckets`` and processed as a batch; otherwise buckets are
        created for this single sample via ``create_buckets``.
        """
        sample = self.data[index]

        if self.use_sampled_buckets:
            shared_buckets = self.unique_buckets
            copies = self.num_unique_buckets

            # One row per precomputed bucket assignment.
            repeated_sample = sample.repeat(copies, 1)
            attention_mask = self.create_attention_mask(shared_buckets, sample)
            attention_mask = attention_mask.repeat(copies, 1)
            final_samples = self.process_bucket_function(
                repeated_sample, shared_buckets, self.embedding_dict
            )
            return {
                "index": index,
                "initial": repeated_sample,
                "attention_mask": attention_mask,
                "data": final_samples,
                **shared_buckets,
            }

        sample_buckets = self.create_buckets(sample, index)
        processed = self.process_bucket_function(
            sample, sample_buckets, self.embedding_dict
        )
        attention_mask = self.create_attention_mask(sample_buckets, sample)
        return {
            "index": index,
            "initial": sample,
            "attention_mask": attention_mask,
            "data": processed,
            **sample_buckets,
        }

    def create_attention_mask(self, buckets, sample):
        """
        Return the transformer attention mask, or an all-zero float tensor
        shaped like ``sample`` for every other model type.
        """
        if self.model_type != "transformer":
            return torch.zeros_like(sample, dtype=torch.float)
        return create_attention_mask_for_transformer(buckets)

    def get_data_spn(self, sample, buckets):
        """
        Return the SPN input features for one sample.

        The sample is copied and given a leading batch dimension before it is
        passed to ``self.spn.get_input_features``; the result is squeezed.
        """
        batched = sample.clone().unsqueeze(0)
        features = self.spn.get_input_features(batched, buckets)
        return features.squeeze()


class MMAPTrainDataset(MMAPBaseDataset):
    """
    Dataset class for MMAP training data.

    Bucket masks are generated on the fly for every sample; which generator is
    used depends on the flags passed to the constructor (see
    ``create_buckets``).
    """

    def __init__(
        self,
        data,
        num_var_in_buckets: list = None,
        same_bucket_for_iter: bool = False,
        use_single_model: bool = False,
        task: str = "mpe",
        use_sampled_buckets: bool = False,
        unique_buckets=None,
        **kwargs,
    ):
        """
        Args:
            data: Training samples; forwarded to ``MMAPBaseDataset``.
            num_var_in_buckets: Bucket sizes ``[evid, query, unobs]`` used when
                neither sampled buckets nor the single-model mode is active.
            same_bucket_for_iter: If True, ``create_buckets`` returns all-False
                placeholder masks.
            use_single_model: If True, bucket sizes are drawn at random for
                every sample.
            task: Task name passed to the bucket generators.
            use_sampled_buckets: If True, buckets come from ``unique_buckets``.
            unique_buckets: Precomputed bucket collection; its ``"query"``
                entry determines ``num_unique_buckets``.
            **kwargs: Forwarded to ``MMAPBaseDataset.__init__``.
        """
        super().__init__(data, **kwargs)
        self.num_var_in_buckets = num_var_in_buckets
        self.same_bucket_for_iter = same_bucket_for_iter
        self.use_single_model = use_single_model
        self.use_sampled_buckets = use_sampled_buckets
        self.unique_buckets = unique_buckets
        self.task = task
        if unique_buckets is not None:
            self.num_unique_buckets = len(unique_buckets["query"])
        else:
            # -1 marks "no precomputed buckets available".
            self.num_unique_buckets = -1

    def create_buckets(self, sample, index):
        """
        Create the bucket masks for one training sample.

        Exactly one strategy is applied, in this order of precedence:

        1. ``same_bucket_for_iter``: all-False placeholder masks.
        2. ``use_sampled_buckets``: a random entry of ``unique_buckets``.
        3. ``use_single_model``: randomly sized buckets.
        4. Otherwise: buckets with the fixed sizes ``num_var_in_buckets``.

        Returns:
            dict: Boolean masks under the keys "evid", "query" and "unobs".
        """
        n_vars = sample.shape[0]

        if self.same_bucket_for_iter:
            return {
                "evid": torch.zeros(n_vars, dtype=torch.bool),
                "query": torch.zeros(n_vars, dtype=torch.bool),
                "unobs": torch.zeros(n_vars, dtype=torch.bool),
            }

        # The strategies are mutually exclusive: sampled buckets must not be
        # overwritten by one of the random generators below.
        if self.use_sampled_buckets:
            buckets = get_bucket_from_unique_set(
                self.unique_buckets, self.num_unique_buckets, self.task, sample.device
            )
        elif self.use_single_model:
            buckets = create_buckets_all_queries_per(n_vars, self.task, sample.device)
        else:
            buckets = create_buckets_one_query_per(
                n_vars, self.num_var_in_buckets, sample.device
            )

        # Sampled buckets may lack "unobs" for some tasks; fill the gap.
        self.ensure_all_buckets(buckets, n_vars)
        return buckets

    @staticmethod
    def ensure_all_buckets(buckets, sample_size):
        """
        Add an all-False mask of length ``sample_size`` for every bucket name
        ("evid", "query", "unobs") missing from ``buckets`` (modified in place).
        """
        for bucket_name in ["evid", "query", "unobs"]:
            buckets.setdefault(bucket_name, torch.zeros(sample_size, dtype=torch.bool))


class MMAPTestDataset(MMAPBaseDataset):
    """
    Dataset class for MMAP test data.

    Unlike the training dataset, the bucket masks are supplied up front, one
    row per test sample, and are only looked up in ``create_buckets``.
    """

    def __init__(self, data, buckets, num_var_in_buckets, data_device, **kwargs):
        """
        Args:
            data: Test samples; forwarded to ``MMAPBaseDataset``.
            buckets (dict): Bucket name -> NumPy array (converted with
                ``torch.from_numpy``), indexed by sample index.
            num_var_in_buckets: Stored as an attribute; not used by this class.
            data_device: Device for both the samples and the bucket masks.
            **kwargs: Forwarded to ``MMAPBaseDataset.__init__``.
        """
        # Forward data_device explicitly: it is captured by this signature and
        # would otherwise never reach the base class, which would then fall
        # back to its own default device for the samples.
        super().__init__(data, data_device=data_device, **kwargs)
        self.buckets = {
            key: torch.from_numpy(value).bool().to(data_device)
            for key, value in buckets.items()
        }
        # "unobs" is only served when it was provided and is non-empty.
        self.bucket_names = (
            ["evid", "query", "unobs"]
            if "unobs" in buckets and len(buckets["unobs"]) > 0
            else ["evid", "query"]
        )
        self.num_var_in_buckets = num_var_in_buckets
        self.use_sampled_buckets = False

    def create_buckets(self, sample, index):
        """
        Return the stored bucket masks of the test sample at ``index``.

        Only the names in ``self.bucket_names`` are included, so "unobs" may
        be absent from the result.
        """
        return {
            bucket_name: self.buckets[bucket_name][index]
            for bucket_name in self.bucket_names
        }


def collate_sampled_buckets(batch):
    """
    Custom collate function for batches whose items are dictionaries.

    Every value is first converted with ``torch.as_tensor`` so that plain
    Python numbers (such as the ``"index"`` entry returned by the datasets)
    and NumPy arrays are accepted alongside tensors. The dimensionality of the
    first item decides how a key is combined:

    * scalars and 1D tensors are stacked along a new first dimension
      (scalars give a 1D result, 1D tensors a 2D result);
    * 2D tensors are concatenated along the first dimension.

    Args:
        batch (list of dicts): Items sharing the same keys.

    Returns:
        dict: One combined tensor per key.

    Raises:
        ValueError: If a value has more than two dimensions.
    """
    collated_batch = {}

    for key in batch[0]:
        values = [torch.as_tensor(item[key]) for item in batch]
        n_dims = values[0].dim()

        if n_dims <= 1:
            # NOTE(review): a scalar such as "index" yields one entry per
            # item, not one per row of the item's 2D tensors - confirm this
            # is what downstream code expects.
            collated_batch[key] = torch.stack(values, dim=0)
        elif n_dims == 2:
            collated_batch[key] = torch.cat(values, dim=0)
        else:
            raise ValueError("Tensors must be either 1D or 2D.")

    return collated_batch
