import sys
import time

import torch
from torch import nn

sys.path.append(
    # Add the path of the anympe directory here
)  # Adds the parent directory to the system path
from uai_reader import UAIParser


class BinaryMNModel(nn.Module):
    """Score assignments of binary variables under a Markov network from a UAI file.

    The network is parsed by ``UAIParser`` and stored in ``self.pgm``. Every
    ``evaluate_*`` method sums, over all factors, the multilinear
    interpolation of the factor's table at the given variable values: for
    hard 0/1 inputs this selects exactly one table entry per factor, and for
    values in between it blends the entries. Table entries are indexed with
    the factor's FIRST variable as the most significant bit (for a pair,
    entry 2 corresponds to ``(1, 0)``).

    ``self.evaluate`` is bound in ``__init__`` to ``evaluate_grids`` when the
    parser reports ``pairwise_only`` and to ``evaluate_loop`` otherwise.

    NOTE(review): whether the tables hold log-potentials or raw potentials
    depends on ``UAIParser`` and cannot be confirmed from this file.
    """

    def __init__(
        self,
        uai_file,
        device,
    ) -> None:
        """Parse ``uai_file`` and select the evaluation routine.

        Args:
            uai_file: Path to the model; must end with ``.uai``.
            device: Forwarded unchanged to ``UAIParser``.

        Raises:
            AssertionError: If the file name does not end with ``.uai`` or the
                parsed network type is not ``"MARKOV"``.
        """
        super().__init__()
        assert uai_file.endswith(".uai"), "Only support UAI format"
        self.pgm = UAIParser(uai_file, device)
        assert self.pgm.network_type == "MARKOV", "Only support Markov Network"
        if self.pgm.pairwise_only:
            self.evaluate = self.evaluate_grids
        else:
            self.evaluate = self.evaluate_loop

    @staticmethod
    def _assignment_mask(num_vars, device):
        """Enumerate all binary assignments of ``num_vars`` variables.

        Returns a ``(2**num_vars, num_vars)`` bool tensor whose row ``j`` is
        the binary expansion of ``j`` with the first column as the most
        significant bit, matching the table-entry order used by the factors.
        """
        indices = torch.arange(2**num_vars, device=device).unsqueeze(1)
        shifts = torch.arange(num_vars - 1, -1, -1, device=device)
        return ((indices >> shifts) & 1).bool()

    def evaluate_seq(self, x):
        """Score ``x`` one factor at a time with explicit per-arity formulas.

        Unary and pairwise factors use hand-expanded expressions; factors of
        any other arity fall back to the general enumeration also used by
        ``evaluate_loop`` so that no factor is silently dropped.

        Args:
            x: Tensor indexed as ``x[:, variable]``; one row per sample
                (assumed to hold values in [0, 1] -- not checked here).

        Returns:
            A tensor with one score per row of ``x``, or the int ``0`` when
            ``self.pgm.prob_tables`` is empty.
        """
        ll_score = 0
        masks = {}  # arity -> assignment mask, shared by factors of equal arity
        for func_vars, _, factor in self.pgm.prob_tables:
            values = x[:, func_vars]
            num_vars = len(func_vars)
            if num_vars == 1:
                first = values[:, 0]
                ll_score += factor[0] * (1 - first) + factor[1] * first
            elif num_vars == 2:
                first = values[:, 0]
                second = values[:, 1]
                term1 = factor[0] * (1 - first) * (1 - second)
                term2 = factor[1] * (1 - first) * second
                term3 = factor[2] * first * (1 - second)
                term4 = factor[3] * first * second
                ll_score += term1 + term2 + term3 + term4
            else:
                mask = masks.get(num_vars)
                if mask is None:
                    mask = self._assignment_mask(num_vars, values.device)
                    masks[num_vars] = mask
                # (2**k, batch, k): pick v where the assignment bit is 1, else 1 - v.
                selected = torch.where(mask.unsqueeze(1), values, 1 - values)
                weights = selected.prod(dim=2)
                ll_score += (factor.unsqueeze(1) * weights).sum(dim=0)
        return ll_score

    def evaluate_grids(self, x):
        """Score ``x`` for a model made only of unary and pairwise factors.

        All unary factors and all pairwise factors are processed in two
        batched expressions using the parser's stacked ``univariate_*`` and
        ``bivariate_*`` tensors.

        Args:
            x: Tensor indexed as ``x[:, variable]``; one row per sample.

        Returns:
            A tensor with one score per row of ``x``.
        """
        pgm = self.pgm

        # Gather each variable column once instead of once per term.
        unary_x = x[:, pgm.univariate_vars]
        first_x = x[:, pgm.bivariate_vars[:, 0]]
        second_x = x[:, pgm.bivariate_vars[:, 1]]
        not_first = 1 - first_x
        not_second = 1 - second_x

        unary_tables = pgm.univariate_tables
        pair_tables = pgm.bivariate_tables

        univariate_contributions = (1 - unary_x) * unary_tables[
            :, 0
        ] + unary_x * unary_tables[:, 1]
        # Columns 0..3 of the pairwise table are the entries for
        # (0, 0), (0, 1), (1, 0), (1, 1) of (first, second).
        bivariate_contributions = (
            not_first * not_second * pair_tables[:, 0].unsqueeze(0)
            + not_first * second_x * pair_tables[:, 1].unsqueeze(0)
            + first_x * not_second * pair_tables[:, 2].unsqueeze(0)
            + first_x * second_x * pair_tables[:, 3].unsqueeze(0)
        )

        return torch.sum(univariate_contributions, dim=1) + torch.sum(
            bivariate_contributions, dim=1
        )

    def evaluate_loop(self, x):
        """Score ``x`` for factors of arbitrary arity, one factor at a time.

        For each factor, every binary assignment of its variables is
        enumerated, the per-sample weight of that assignment is the product
        over variables of ``v`` (bit 1) or ``1 - v`` (bit 0), and the factor
        table is contracted against those weights.

        Args:
            x: Floating-point tensor indexed as ``x[:, variable]``; one row
                per sample. Its dtype must match the factor tables' dtype for
                the ``matmul`` below.

        Returns:
            A 1-D tensor with one score per row of ``x``. The accumulator is
            at least float32 and follows ``x`` for wider dtypes so float64
            inputs are not truncated.
        """
        ll_scores = torch.zeros(
            x.shape[0],
            device=x.device,
            dtype=torch.promote_types(torch.float32, x.dtype),
        )
        masks = {}  # arity -> assignment mask, shared by factors of equal arity
        for func_vars, _, factor in self.pgm.prob_tables:
            num_vars = len(func_vars)
            mask = masks.get(num_vars)
            if mask is None:
                mask = self._assignment_mask(num_vars, x.device)
                masks[num_vars] = mask

            values = x[:, func_vars]
            # Broadcast to (2**k, batch, k): v where the bit is 1, else 1 - v.
            selected_values = torch.where(mask.unsqueeze(1), values, 1 - values)
            product_terms = torch.prod(selected_values, dim=2)

            # Contract the table with the assignment weights -> (batch,).
            ll_scores += torch.matmul(factor, product_terms)

        return ll_scores
