from itertools import product

import torch
from loguru import logger
from torch.distributions import Categorical
from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    """Dataset whose items are rows of ``data`` encoded by ``process_bucket_for_nn``.

    Two modes are supported:

    * ``buckets`` is given: every row is encoded once in ``__init__`` using the
      row-specific ``"evid"``/``"query"``/``"unobs"`` entries of ``buckets``,
      and the stacked result is kept in ``self.final_sample``.
    * ``buckets`` is ``None``: nothing is precomputed; each ``__getitem__``
      call draws a fresh random assignment with ``create_buckets`` and encodes
      the requested row with it.
    """

    def __init__(self, data, buckets):
        self.initial_data = data
        self.buckets = buckets
        n_rows = self.initial_data.shape[0]
        if buckets is not None:

            def masks_for(row):
                # buckets[name][row] is handed to torch.from_numpy, so each
                # per-row entry must be a numpy array.
                return {
                    name: torch.from_numpy(self.buckets[name][row])
                    for name in ("evid", "query", "unobs")
                }

            # Encode every row up front so __getitem__ is a plain lookup.
            encoded_rows = [
                process_bucket_for_nn(self.initial_data[row], masks_for(row))
                for row in range(n_rows)
            ]
            self.final_sample = torch.stack(encoded_rows)
        self.length = n_rows

    def __len__(self):
        """Return the number of rows in the underlying data."""
        return self.length

    def __getitem__(self, idx):
        """Return the encoded sample for row ``idx``."""
        if self.buckets is not None:
            # Fixed buckets: the encoding was precomputed in __init__.
            return self.final_sample[idx]
        # No fixed buckets: sample a new random assignment on every access.
        random_masks = create_buckets(self.initial_data.shape[1], [0.4, 0.3, 0.3])
        row = self.initial_data[idx]
        return process_bucket_for_nn(row, random_masks)


def get_custom_dataloader(data, buckets=None, batch_size=32, shuffle=True):
    """
    Wrap ``data`` in a ``CustomDataset`` and return a ``DataLoader`` over it.

    Args:
        data: Object exposing ``shape``; ``shape[1]`` is logged as the number
            of variables. It is forwarded unchanged to ``CustomDataset``.
        buckets: Optional bucket assignment forwarded to ``CustomDataset``.
        batch_size (int): Batch size passed to the ``DataLoader``.
        shuffle (bool): Shuffle flag passed to the ``DataLoader``.

    Returns:
        DataLoader: Loader iterating over the constructed ``CustomDataset``.
    """
    num_variables = data.shape[1]
    logger.info(f"Generating custom dataloader for {num_variables} variables")
    loader = DataLoader(
        CustomDataset(data, buckets=buckets),
        batch_size=batch_size,
        shuffle=shuffle,
    )
    logger.info(f"Generated custom dataloader for {num_variables} variables")
    return loader


def create_buckets(n, probabilities):
    """
    Randomly assign each of the variables 0..n-1 to a bucket.

    One bucket index is drawn per variable from a categorical distribution
    over ``probabilities``. Bucket 0 is reported as ``"evid"``, bucket 1 as
    ``"query"`` and bucket 2 as ``"unobs"``. The name is tied to the sampled
    bucket index itself, so a bucket that happens to receive no variables
    yields an all-False mask instead of shifting the names of the others.

    Args:
        n (int): Number of variables (indices 0 to n-1) to distribute.
        probabilities (list): Probability of each bucket, in the order
            evid, query, unobs. If more than three entries are supplied,
            variables drawn into the extra buckets appear in no mask.

    Returns:
        dict: Maps ``"evid"``, ``"query"`` and ``"unobs"`` to boolean numpy
              arrays of shape ``(n,)``; entry ``j`` is True when variable
              ``j`` was assigned to that bucket.

    Example:
        >>> buckets = create_buckets(3, [0.0, 1.0, 0.0])
        >>> buckets["query"]
        array([ True,  True,  True])
        >>> buckets["evid"]
        array([False, False, False])
    """
    bucket_names = ("evid", "query", "unobs")
    distribution = Categorical(torch.tensor(probabilities))
    # samples[j] is the bucket index drawn for variable j.
    samples = distribution.sample(torch.Size([n]))
    # Compare against the bucket index directly rather than the position of
    # the bucket among the values that were actually drawn.
    return {
        name: (samples == bucket_idx).numpy()
        for bucket_idx, name in enumerate(bucket_names)
    }


def process_bucket_for_nn(sample, buckets):
    """
    Encode a single row into a two-slots-per-variable tensor.

    Variable ``j`` occupies positions ``2*j`` and ``2*j + 1`` of the output:

    * evidence variable with value 0  -> ``(1, 0)``
    * evidence variable with non-zero value -> ``(0, 1)``
    * query variable                  -> ``(0, 0)``
    * unobserved variable             -> ``(1, 1)``

    Masks are applied in the order evid, query, unobs, so if masks overlap
    the later one wins. Variables covered by no mask stay ``(0, 0)``.

    Args:
        sample (torch.Tensor): 1-D tensor of shape (n_vars,) with the
            variable values.
        buckets (dict): Maps ``"evid"``, ``"query"`` and ``"unobs"`` to
            boolean masks of shape (n_vars,), given as torch tensors or
            numpy arrays. The dict is not modified.

    Returns:
        torch.Tensor: Tensor of shape (2 * n_vars,) with the same dtype as
                      ``sample``.

    Example:
        >>> sample = torch.tensor([1, 0, 1, 0])
        >>> buckets = {
        ...     "evid": torch.tensor([True, True, False, False]),
        ...     "query": torch.tensor([False, False, True, False]),
        ...     "unobs": torch.tensor([False, False, False, True]),
        ... }
        >>> process_bucket_for_nn(sample, buckets)
        tensor([0, 1, 1, 0, 0, 0, 1, 1])
    """
    # Convert each mask independently into a local tensor so that mixed
    # numpy/tensor inputs work and the caller's dict is left untouched.
    evid_mask, query_mask, unobs_mask = (
        torch.as_tensor(buckets[name]) for name in ("evid", "query", "unobs")
    )
    final_sample = torch.zeros_like(sample).repeat(2)

    # Evidence: one-hot encode the observed value across the two slots.
    evid_idx = torch.nonzero(evid_mask).flatten()
    is_zero = sample[evid_idx] == 0
    # Cast to the output dtype; indexed assignment needs matching dtypes.
    final_sample[2 * evid_idx] = is_zero.to(final_sample.dtype)
    final_sample[2 * evid_idx + 1] = (~is_zero).to(final_sample.dtype)

    # Query: both slots cleared.
    query_idx = torch.nonzero(query_mask).flatten()
    final_sample[2 * query_idx] = 0
    final_sample[2 * query_idx + 1] = 0

    # Unobserved: both slots set.
    unobs_idx = torch.nonzero(unobs_mask).flatten()
    final_sample[2 * unobs_idx] = 1
    final_sample[2 * unobs_idx + 1] = 1
    return final_sample


if __name__ == "__main__":
    # Example usage: build a small random binary dataset and iterate over it.
    num_variables = 3  # Change this as needed
    num_samples = 8  # Change this as needed
    # The loader needs a 2-D (samples x variables) tensor rather than a bare
    # variable count; use random binary values stored as float64.
    data = torch.randint(0, 2, (num_samples, num_variables)).double()
    dataloader = get_custom_dataloader(data)
    for batch in dataloader:
        print(batch)
    print(len(dataloader.dataset))
