import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ViTBlock(nn.Module):
    """Single Transformer encoder block for batch-first sequences.

    Wraps a one-layer ``TransformerEncoder``. The input is expected as
    ``(batch_size, seq_length, feature_dim)`` because the underlying layer
    is built with ``batch_first=True``.

    Args:
        feature_dim: Token embedding size, passed as ``d_model``.
        num_heads: Number of attention heads, passed as ``nhead``.
        mlp_dim: Hidden width of the feed-forward sub-layer, passed as
            ``dim_feedforward``.
        dropout: Dropout probability forwarded to the encoder layer.
    """

    def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        # The layer below is only a template: TransformerEncoder deep-copies
        # it into its own ``layers`` list. It is deliberately kept in a local
        # variable rather than assigned to ``self`` so that this module does
        # not register a second, never-used copy of the weights (which would
        # double the parameter count and leave parameters without gradients).
        # NOTE: as a consequence, ``state_dict()`` only contains
        # ``transformer_encoder.*`` keys; checkpoints saved with an extra
        # ``encoder_layer.*`` prefix need ``strict=False`` when loading.
        layer_template = TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,  # Input shape: (batch_size, seq_length, feature_dim)
        )
        self.transformer_encoder = TransformerEncoder(layer_template, num_layers=1)

    def forward(self, x):
        """Run ``x`` through the encoder and return the result."""
        return self.transformer_encoder(x)

class TwoLayerViT(nn.Module):
    """Two ``ViTBlock`` modules applied one after the other.

    Both blocks are constructed with the same ``feature_dim``,
    ``num_heads``, ``mlp_dim`` and ``dropout`` arguments.
    """

    def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        # Same configuration for both stages; each gets its own weights.
        block_args = (feature_dim, num_heads, mlp_dim, dropout)
        self.vit_block1 = ViTBlock(*block_args)
        self.vit_block2 = ViTBlock(*block_args)

    def forward(self, x):
        """Feed ``x`` through the first block, then the second."""
        first_stage = self.vit_block1(x)
        return self.vit_block2(first_stage)

# Model hyperparameters: embedding width, attention head count, and the
# hidden width of the feed-forward (MLP) sub-layer.
feature_dim = 768
num_heads = 12
mlp_dim = 3072

# Build the two-block model; dropout is left at its default.
model = TwoLayerViT(feature_dim=feature_dim, num_heads=num_heads, mlp_dim=mlp_dim)

# Random stand-in for a batch of fused feature vectors laid out as
# (batch_size, sequence_length, feature_dim).
# NOTE(review): sequence_length is 1, so each sample is a single token and
# attention has no other positions to look at — confirm this is intended.
batch_size, sequence_length = 10, 1
fused_features = torch.randn((batch_size, sequence_length, feature_dim))

# Forward pass through both blocks.
transformed_features = model(fused_features)

# Report the tensor shapes before and after the model.
print(f'Input shape: {fused_features.shape}')
print(f'Transformed shape: {transformed_features.shape}')