import re

import torch


class UAIParser:
    """Parse a whitespace-delimited graphical-model file into torch tensors.

    The file is consumed as a flat stream of whitespace-separated tokens in
    this order:

    1. network type (kept as a string) and number of variables,
    2. one domain size per variable,
    3. number of cliques, then for each clique its size followed by its
       variable indices,
    4. for each clique (same order) a table size followed by that many
       floating point values.

    Table values are clamped from below by ``eps`` and stored as natural
    logarithms.

    After construction:

    * ``univariate_tables`` / ``bivariate_tables`` are the stacked log-tables
      of all cliques of size 1 / size 2 (one row per clique).  Stacking
      requires all tables of the same arity to have equal length.
    * ``univariate_vars`` has shape ``(num_unary,)`` and ``bivariate_vars``
      has shape ``(num_pairwise, 2)``.
    * When the file contains no clique of a given arity the corresponding
      attributes are empty tensors (``(0, 0)`` for tables, ``(0,)`` /
      ``(0, 2)`` for variable indices) instead of raising.
    * ``pairwise_only`` is ``False`` if any clique has a size other than
      1 or 2.

    Args:
        file_path: Path of the file to parse.
        device: Torch device on which all tensors are created.

    Raises:
        TypeError: If the file ends before all expected tokens were read
            (``read_next_token`` yields ``None`` which cannot be converted).
        ValueError: If a token cannot be parsed as the expected number.
    """

    def __init__(self, file_path, device):
        self.file_path = file_path
        self.device = device
        self.eps = 1e-10
        self.regex_pattern = re.compile(r"\s+")
        self.pairwise_only = True
        self.parse_file()

        # Guard against models that have no unary or no pairwise factors:
        # torch.stack raises on an empty list.
        self.bivariate_tables = self._stack_or_empty(
            self.bivariate_tables, (0, 0), None
        )
        self.univariate_tables = self._stack_or_empty(
            self.univariate_tables, (0, 0), None
        )
        # squeeze(-1) drops only the clique-size axis, so a single unary
        # factor still yields a 1-D tensor of shape (1,) rather than a scalar.
        self.univariate_vars = self._stack_or_empty(
            self.univariate_vars, (0, 1), torch.long
        ).squeeze(-1)
        self.bivariate_vars = self._stack_or_empty(
            self.bivariate_vars, (0, 2), torch.long
        )

    def _stack_or_empty(self, tensors, empty_shape, dtype):
        """Stack ``tensors`` along dim 0, or return an empty placeholder.

        ``dtype=None`` uses torch's default floating point dtype, which is
        what ``torch.tensor`` picks for the parsed float tables.
        """
        if tensors:
            return torch.stack(tensors, dim=0)
        return torch.empty(empty_shape, dtype=dtype, device=self.device)

    def read_next_token(self, file_iter):
        """Return the next token from ``file_iter`` or ``None`` if exhausted."""
        return next(file_iter, None)

    def parse_file(self):
        """Read ``self.file_path`` and populate the parsed attributes.

        Sets ``network_type``, ``num_vars``, ``domain_sizes``,
        ``num_cliques``, ``cliques``, ``prob_tables`` and the (unstacked)
        lists ``univariate_tables``, ``univariate_vars``,
        ``bivariate_tables`` and ``bivariate_vars``.  May set
        ``pairwise_only`` to ``False``.
        """
        with open(self.file_path, "r") as file:
            # Lazily flatten the file into non-empty whitespace-separated
            # tokens; line breaks carry no meaning for the parser.
            file_iter = (
                token
                for line in file
                for token in self.regex_pattern.split(line)
                if token
            )

            # Header: network type followed by the number of variables.
            self.network_type = self.read_next_token(file_iter)
            self.num_vars = int(self.read_next_token(file_iter))

            # One domain size per variable.
            self.domain_sizes = torch.tensor(
                [int(self.read_next_token(file_iter)) for _ in range(self.num_vars)],
                device=self.device,
            )

            # Clique scopes: each entry is "<size> <var_0> ... <var_{size-1}>".
            self.num_cliques = int(self.read_next_token(file_iter))
            self.cliques = []
            for _ in range(self.num_cliques):
                clique_size = int(self.read_next_token(file_iter))
                members = [
                    int(self.read_next_token(file_iter)) for _ in range(clique_size)
                ]
                self.cliques.append(
                    torch.tensor(members, dtype=torch.long, device=self.device)
                )

            # Tables follow in clique order; split them by arity as we go.
            self.prob_tables = []
            self.univariate_tables = []
            self.univariate_vars = []
            self.bivariate_tables = []
            self.bivariate_vars = []
            for clique in self.cliques:
                table_size = int(self.read_next_token(file_iter))
                # Clamp to eps so that the logarithm below stays finite.
                table_values = [
                    max(float(self.read_next_token(file_iter)), self.eps)
                    for _ in range(table_size)
                ]
                table = torch.log(torch.tensor(table_values, device=self.device))
                domains_this_table = self.domain_sizes[clique]
                self.prob_tables.append((clique, domains_this_table, table))

                if len(clique) == 1:
                    self.univariate_tables.append(table)
                    self.univariate_vars.append(clique)
                elif len(clique) == 2:
                    self.bivariate_tables.append(table)
                    self.bivariate_vars.append(clique)
                else:
                    self.pairwise_only = False

    def get_parsed_data(self):
        """Return the parsed attributes collected in a dictionary."""
        return {
            "type": self.network_type,
            "num_vars": self.num_vars,
            "domain_sizes": self.domain_sizes,
            "num_cliques": self.num_cliques,
            "cliques": self.cliques,
            "prob_tables": self.prob_tables,
            "univariate_tables": self.univariate_tables,
            "univariate_vars": self.univariate_vars,
            "bivariate_tables": self.bivariate_tables,
            "bivariate_vars": self.bivariate_vars,
        }
