import math
from typing import List

import torch
import torch.nn as nn
from loguru import logger
from project_utils.losses import *
from project_utils.model_utils import get_spn_mpe_output
from project_utils.profiling import pytorch_profile


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The encoding table is precomputed once and stored as the buffer ``pe``
    with shape ``(max_len, 1, d_model)``. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature (last) dimension. Odd values are
            supported; the final sine column then has no cosine partner.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angles to the number of odd columns actually present.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Reshape (max_len, d_model) -> (max_len, 1, d_model) so the table
        # broadcasts over dimension 1 of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        Positions are taken along dimension 0 of ``x``; the encoding is
        broadcast over dimension 1. ``x.size(0)`` must not exceed ``max_len``.
        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    """Transformer encoder producing one scalar per sequence position.

    The network adds positional encodings, runs a stack of batch-first
    ``nn.TransformerEncoderLayer`` blocks and projects every position to a
    single logit. ``train_iter`` / ``validate_iter`` turn those logits into
    values in [0, 1], merge them with evidence and score them with an SPN.

    Args:
        cfg: Configuration object. Attributes read by this class:
            ``input_type``, ``activation_function``, ``no_log_loss``,
            ``add_supervised_loss``, ``supervised_loss_lambda``,
            ``add_evid_loss``, ``evid_lambda``, ``add_entropy_loss``,
            ``entropy_lambda`` and ``device``.
        feature_size: Accepted for interface compatibility; not used.
        d_model: Feature dimension of the encoder input.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        library_spn: SPN handle forwarded to ``get_spn_mpe_output`` when the
            supervised loss is enabled.
        dropout: Dropout probability of the encoder layers. The positional
            encoding keeps its own default dropout.
    """

    def __init__(
        self, cfg, feature_size, d_model, nhead, num_layers, library_spn, dropout=0.1
    ):
        super().__init__()
        self.pos_embedding = PositionalEncoding(d_model)
        # Transformer encoder layers (batch-first: batch x seq_len x d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, dim_feedforward=128, batch_first=True, nhead=nhead, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.library_spn = library_spn
        # Linear layer for prediction: d_model -> 1 logit per position
        self.fc = nn.Linear(d_model, 1)
        self.cfg = cfg

    def forward(self, src, attention_mask):
        """Compute one logit per sequence position.

        Args:
            src: Tensor of shape ``(batch_size, seq_len, d_model)``.
            attention_mask: Accepted for interface compatibility; currently
                not passed to the encoder.

        Returns:
            Tensor of shape ``(batch_size, seq_len)``.
        """
        # The positional encoding indexes positions along dim 0, whereas the
        # encoder is batch-first. Swap to sequence-first around the call so
        # the encoding varies with the sequence position rather than with the
        # batch index.
        src = self.pos_embedding(src.transpose(0, 1)).transpose(0, 1)
        output = self.transformer_encoder(src)
        # batch_size x seq_len x d_model -> batch_size x seq_len x 1
        output = self.fc(output)
        # Remove the last dimension to get [batch_size, seq_len]
        return output.squeeze(-1)

    def process_buckets_single_row_for_spn(self, nn_output, true, buckets):
        """Merge network outputs with evidence to build the SPN input.

        Args:
            nn_output (torch.Tensor): Network outputs (after the activation).
            true (torch.Tensor): Ground-truth values, indexable like
                ``nn_output``.
            buckets (dict): Maps ``"evid"`` and ``"unobs"`` (and ``"query"``,
                which is not read here) to index tensors usable on
                ``nn_output`` -- presumably boolean masks of the same shape;
                confirm against the caller.

        Returns:
            torch.Tensor: A copy of ``nn_output`` in which evidence entries
            are replaced by their true values and unobserved entries are set
            to NaN. All remaining (query) entries keep the network output.
        """
        # Work on a copy so the caller's tensor is not modified in place.
        final_sample = nn_output.clone().requires_grad_(True)

        # Evidence variables take their observed values.
        indices = buckets["evid"]
        final_sample[indices] = true[indices]

        # Query entries already hold the network output; unobserved
        # variables are marked with NaN.
        indices = buckets["unobs"]
        final_sample[indices] = float("nan")
        return final_sample

    def _select_input(self, data):
        """Return the model input for the configured ``cfg.input_type``."""
        if self.cfg.input_type == "data":
            return data
        raise ValueError(
            f"Input type {self.cfg.input_type} not supported for Transformer model"
        )

    def train_iter(
        self,
        spn,
        data,
        data_spn,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """Run one training forward pass and return the scalar loss.

        Args:
            spn: Object whose ``evaluate`` scores the merged SPN input.
            data: Model input used when ``cfg.input_type == "data"``.
            data_spn: Unused; kept for interface compatibility.
            initial_data: Ground-truth values used for evidence entries.
            evid_bucket: Index tensor selecting evidence variables.
            query_bucket: Index tensor selecting query variables.
            unobs_bucket: Index tensor selecting unobserved variables.
            attention_mask: Forwarded to ``forward`` (currently unused there).

        Returns:
            torch.Tensor: Mean of the SPN loss plus any enabled auxiliary
            losses (supervised, evidence, entropy).

        Raises:
            ValueError: If ``cfg.input_type`` or ``cfg.activation_function``
                is unsupported, or if the activated output contains NaN.
        """
        input_to_model = self._select_input(data)

        model_output = self(input_to_model, attention_mask)
        l2_loss = torch.nn.MSELoss()
        if self.cfg.activation_function == "sigmoid":
            output = torch.sigmoid(model_output)
        elif self.cfg.activation_function == "hard_sigmoid":
            output = nn.Hardsigmoid()(model_output)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"Activation function {self.cfg.activation_function} "
                "not supported for Transformer model"
            )
        if torch.isnan(output).any():
            logger.info(output)
            raise ValueError("Nan in output")

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_spn = self.process_buckets_single_row_for_spn(
            nn_output=output, true=initial_data, buckets=buckets
        )

        final_func_value = spn.evaluate(output_for_spn)
        if not self.cfg.no_log_loss:
            loss_from_spn = -final_func_value
        else:
            loss_from_spn = -torch.exp(final_func_value)

        loss = loss_from_spn

        if self.cfg.add_supervised_loss:
            # Pull the query outputs towards the values returned by the
            # project's SPN MPE helper.
            outputs_for_spn_np = output_for_spn.detach().cpu().numpy()
            query_bucket_np = query_bucket.detach().cpu().numpy()
            mpe_outputs = get_spn_mpe_output(
                self.library_spn, outputs_for_spn_np, query_bucket_np
            )
            mpe_outputs = torch.tensor(mpe_outputs, device=self.cfg.device)
            query_spn_outputs = mpe_outputs[query_bucket]
            query_nn_outputs = output[query_bucket]
            supervised_loss = l2_loss(query_nn_outputs, query_spn_outputs)
            # Out-of-place add so `loss_from_spn` is not mutated.
            loss = loss + self.cfg.supervised_loss_lambda * supervised_loss

        if self.cfg.add_evid_loss:
            # Penalise network outputs that disagree with observed evidence.
            evidence_output = output[evid_bucket]
            evidence_true = initial_data[evid_bucket]
            evid_loss = l2_loss(evidence_output, evidence_true)
            loss = loss + self.cfg.evid_lambda * evid_loss

        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss
        loss = loss.mean()
        return loss

    def validate_iter(
        self,
        spn,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_spn,
        all_buckets,
        data,
        data_spn,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """Run one validation forward pass, record outputs, return the loss.

        Args:
            spn: Object whose ``evaluate`` scores the merged SPN input.
            all_unprocessed_data: List extended in place with ``initial_data``
                rows.
            all_nn_outputs: List extended in place with the activated network
                outputs.
            all_outputs_for_spn: List extended in place with the SPN inputs.
            all_buckets: Dict of lists keyed ``"evid"``, ``"query"`` and
                ``"unobs"``; each list is extended in place with the matching
                bucket rows.
            data: Model input used when ``cfg.input_type == "data"``.
            data_spn: Unused; kept for interface compatibility.
            initial_data: Ground-truth values used for evidence entries.
            evid_bucket: Index tensor selecting evidence variables.
            query_bucket: Index tensor selecting query variables.
            unobs_bucket: Index tensor selecting unobserved variables.
            attention_mask: Forwarded to ``forward`` (currently unused there).

        Returns:
            torch.Tensor: Mean of the SPN loss plus the entropy loss when
            enabled.

        Raises:
            ValueError: If ``cfg.input_type`` is unsupported or the network
                output contains NaN.
        """
        input_to_model = self._select_input(data)
        model_output = self(input_to_model, attention_mask)
        # NOTE(review): unlike train_iter, any other activation name leaves
        # the raw logits unchanged here; kept for backward compatibility.
        if self.cfg.activation_function == "sigmoid":
            model_output = torch.sigmoid(model_output)
        elif self.cfg.activation_function == "hard_sigmoid":
            model_output = nn.Hardsigmoid()(model_output)

        if torch.isnan(model_output).any():
            logger.info("Nan in output")
            logger.info(model_output)
            # Raise instead of calling exit(): terminating the interpreter
            # from library code reports success and bypasses error handling.
            raise ValueError("Nan in output")

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_spn = self.process_buckets_single_row_for_spn(
            nn_output=model_output, true=initial_data, buckets=buckets
        )

        final_func_value = spn.evaluate(output_for_spn)

        # Record per-row results into the caller-owned accumulators.
        all_nn_outputs.extend(model_output.detach().cpu().tolist())
        all_unprocessed_data.extend(initial_data.detach().cpu().tolist())
        all_outputs_for_spn.extend(output_for_spn.detach().cpu().tolist())
        for bucket_name, bucket in buckets.items():
            all_buckets[bucket_name].extend(bucket.detach().cpu().tolist())

        if not self.cfg.no_log_loss:
            loss = -final_func_value
        else:
            loss = -torch.exp(final_func_value)

        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(model_output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss

        loss = loss.mean()
        return loss
