from functools import partial

import einops
import torch
from custommodules.init import init_xavier_uniform_zero_bias
from custommodules.layers import ContinuousSincosEmbed
from custommodules.transformer import PerceiverBlock, PrenormBlock
from custommodules.vit import DitBlock
from torch import nn

from models.base.single_model_base import SingleModelBase
from optimizers.param_group_modifiers.exclude_from_wd_by_name_modifier import ExcludeFromWdByNameModifier
from torch_geometric.utils import to_dense_batch


class RansTransformer(SingleModelBase):
    """Point encoder: sin-cos position embedding followed by a stack of transformer blocks.

    Args:
        dim: Width of the position embedding and of every transformer block.
        depth: Number of transformer blocks to stack.
        num_attn_heads: Number of attention heads, forwarded as ``num_heads`` to each block.
        **kwargs: Forwarded unchanged to ``SingleModelBase``.

    Side effects:
        Writes the coordinate dimensionality into ``self.static_ctx["ndim"]`` and
        sets ``self.output_shape`` to ``(None, dim)``.
    """

    def __init__(self, dim, depth, num_attn_heads, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.depth = depth
        self.num_attn_heads = num_attn_heads

        # input_shape must be a 2-tuple; its second entry is the number of
        # coordinate dimensions, which is shared with other models via static_ctx
        _num_points, ndim = self.input_shape
        self.static_ctx["ndim"] = ndim

        # NOTE: pos_embed is created (and registered) before the blocks on purpose,
        # so parameter/state_dict ordering matches earlier checkpoints
        self.pos_embed = ContinuousSincosEmbed(dim=dim, ndim=ndim)

        # conditioned blocks (DitBlock) when a condition width is published in
        # static_ctx, otherwise plain pre-norm transformer blocks
        block_ctor = (
            partial(DitBlock, cond_dim=self.static_ctx["condition_dim"])
            if "condition_dim" in self.static_ctx
            else PrenormBlock
        )
        blocks = [block_ctor(dim=dim, num_heads=num_attn_heads) for _ in range(depth)]
        self.transformer_blocks = nn.ModuleList(blocks)

        # first axis (number of points) is variable
        self.output_shape = (None, dim)

    def forward(self, mesh_pos, mesh_edges, batch_idx):
        """Embed point positions and run them through the transformer stack.

        Args:
            mesh_pos: Point coordinates, fed to ``ContinuousSincosEmbed``
                (presumably flat ``(total_points, ndim)`` — TODO confirm).
            mesh_edges: Unused; kept for interface compatibility with other encoders.
            batch_idx: Per-point sample index used by ``to_dense_batch`` to regroup
                points into a dense batch (torch_geometric expects it sorted).

        Returns:
            Dense tensor from ``to_dense_batch`` after all blocks; assumes the
            blocks preserve the ``(batch, max_points, dim)`` shape — TODO confirm.

        Raises:
            NotImplementedError: If samples have different point counts, i.e.
                the dense batch would contain padding that needs masking.
        """
        embedded = self.pos_embed(mesh_pos)
        x, mask = to_dense_batch(embedded, batch_idx)
        if not torch.all(mask):
            raise NotImplementedError("mask needs to be returned somehow and passed to latent model")

        # no padding present -> attention needs no mask
        # NOTE(review): when DitBlock is used (condition_dim set) no condition is
        # passed here; verify DitBlock accepts this call without a condition
        for block in self.transformer_blocks:
            x = block(x, attn_mask=None)
        return x
