import torch
import torch.nn as nn
import math

class RotaryPositionEncoder(nn.Module):
    """Rotary position embedding (RoPE).

    Rotates adjacent channel pairs ``(2i, 2i + 1)`` of the last dimension by the
    angle ``pos * base ** (-2i / dim)``, where ``pos`` is the index along the
    second-to-last (sequence) dimension. Because the transform is a pure
    rotation, it preserves vector norms, and the dot product of two encoded
    vectors depends only on their relative offset.

    Args:
        dim: size of the last dimension of the inputs; must be even.
        max_len: number of positions whose sin/cos tables are precomputed.
            Longer inputs are still supported; their tables are computed on the fly.
        base: frequency base (10000 in the original RoPE formulation).
    """

    def __init__(self, dim, max_len=5000, base=10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryPositionEncoder requires an even dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # (max_len, dim // 2)
        # Non-persistent: all three are derived from the constructor arguments,
        # so they are not part of the state dict.
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.register_buffer('sin', freqs.sin(), persistent=False)
        self.register_buffer('cos', freqs.cos(), persistent=False)

    @staticmethod
    def _rotate_pairs(x):
        """Map every channel pair (a, b) to (-b, a), i.e. a 90-degree rotation."""
        return torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)

    def forward(self, x):
        """Apply RoPE to ``x`` of shape ``(..., seq_len, dim)``; returns the same shape and dtype."""
        seq_len = x.size(-2)
        if seq_len <= self.sin.size(0):
            sin, cos = self.sin[:seq_len], self.cos[:seq_len]
        else:
            # Longer than the precomputed table: build the angles for this length.
            t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            sin, cos = freqs.sin(), freqs.cos()
        # One frequency drives each channel pair, so duplicate it for both channels.
        sin = sin.repeat_interleave(2, dim=-1).to(x.dtype)
        cos = cos.repeat_interleave(2, dim=-1).to(x.dtype)
        # (a, b) -> (a*cos - b*sin, b*cos + a*sin) for every pair.
        return x * cos + self._rotate_pairs(x) * sin


class TransformerEncoderLayerWithRoPE(nn.TransformerEncoderLayer):
    """Batch-first, post-norm Transformer encoder layer with rotary position embeddings.

    The rotary embedding is applied per head to the *projected* queries and keys.
    Doing it after the projection is what makes the attention scores depend on
    relative position. Values are not rotated. Attention reuses the projection
    parameters of the inherited ``self_attn`` module, so parameter names and
    state dicts match ``nn.TransformerEncoderLayer``.

    Args:
        d_model: model width.
        dropout: dropout probability (attention weights, residual branches and
            feed-forward hidden layer, as in ``nn.TransformerEncoderLayer``).
        nhead: number of attention heads; defaults to ``max(1, d_model // 64)``.
    """

    def __init__(self, d_model=1024, dropout=0.1, nhead=None):
        if nhead is None:
            nhead = max(1, d_model // 64)
        super().__init__(d_model, nhead, d_model * 4, dropout, batch_first=True)
        head_dim = d_model // nhead
        if head_dim % 2 != 0:
            raise ValueError(
                f"RoPE needs an even head dimension, got d_model={d_model}, nhead={nhead}")
        self.rope = RotaryPositionEncoder(head_dim)

    @staticmethod
    def _additive_mask(mask, dtype):
        """Convert a bool mask (True = ignore) to an additive float mask; float masks pass through."""
        if not mask.is_floating_point():
            return torch.zeros(mask.shape, dtype=dtype, device=mask.device).masked_fill(
                mask.bool(), float('-inf'))
        return mask.to(dtype)

    def _rope_self_attention(self, x, attn_mask, key_padding_mask):
        """Multi-head self-attention over ``x`` (batch, seq, d_model) with RoPE on q and k.

        Reproduces ``self.self_attn`` (packed in-projection, scaled dot product,
        attention dropout, out-projection). The only difference is that the
        rotary embedding is inserted after the projection.
        """
        attn = self.self_attn
        batch_size, seq_len, d_model = x.shape
        num_heads = attn.num_heads
        head_dim = d_model // num_heads

        # Packed q/k/v projection (self_attn uses a single in_proj_weight when
        # query, key and value share the embedding size, as here).
        q, k, v = nn.functional.linear(x, attn.in_proj_weight, attn.in_proj_bias).chunk(3, dim=-1)

        def split_heads(t):
            # (B, L, E) -> (B * H, L, head_dim), the head layout nn.MultiheadAttention uses.
            return (t.reshape(batch_size, seq_len, num_heads, head_dim)
                     .transpose(1, 2)
                     .reshape(batch_size * num_heads, seq_len, head_dim))

        q = self.rope(split_heads(q))
        k = self.rope(split_heads(k))
        v = split_heads(v)

        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
        if attn_mask is not None:
            # (L, L) broadcasts over batch*heads; (B * H, L, L) is used as-is.
            scores = scores + self._additive_mask(attn_mask, scores.dtype)
        if key_padding_mask is not None:
            pad = self._additive_mask(key_padding_mask, scores.dtype)
            pad = (pad.view(batch_size, 1, 1, seq_len)
                      .expand(-1, num_heads, -1, -1)
                      .reshape(batch_size * num_heads, 1, seq_len))
            scores = scores + pad

        weights = torch.softmax(scores, dim=-1)
        weights = nn.functional.dropout(weights, p=attn.dropout, training=self.training)
        out = torch.bmm(weights, v)
        out = (out.reshape(batch_size, num_heads, seq_len, head_dim)
                  .transpose(1, 2)
                  .reshape(batch_size, seq_len, d_model))
        return attn.out_proj(out)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs):
        """Run the layer on ``src`` of shape (batch, seq, d_model).

        ``src_mask`` ((seq, seq) or (batch * nhead, seq, seq)) and
        ``src_key_padding_mask`` ((batch, seq)) follow the ``nn.MultiheadAttention``
        conventions. Either may be a bool mask (True = ignore) or an additive float mask.
        Extra keyword arguments passed by ``nn.TransformerEncoder`` (e.g.
        ``is_causal``) are accepted and ignored; causality must come from ``src_mask``.
        """
        attn_out = self._rope_self_attention(src, src_mask, src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attn_out))
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))
        return src

class TransformerRoPESemanticModel(nn.Module):
    """Stack of RoPE encoder layers trained with masked-span prediction.

    Args:
        d_model: feature size of the inputs and of the encoder.
        n_layers: number of encoder layers.
        dropout: dropout probability used inside each layer.
        mask_prob: probability that a position starts a masked span (training only).
        mask_span: length of each masked span (training only).
    """

    def __init__(self, d_model, n_layers, dropout=0.1, mask_prob=0.065, mask_span=5):
        super().__init__()
        encoder_layer = TransformerEncoderLayerWithRoPE(d_model, dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # Learned embedding substituted for masked input frames.
        self.mask_vector = nn.Parameter(torch.randn(d_model))
        # ignore_index defaults to -100, which forward() uses to exclude unmasked frames.
        self.loss_fn = nn.CrossEntropyLoss()
        self.mask_prob = mask_prob
        self.mask_span = mask_span

    @staticmethod
    def _sample_span_mask(batch_size, seq_len, mask_prob, span, device=None):
        """Sample a (batch_size, seq_len) bool mask made of possibly-overlapping spans.

        Every position starts a span with probability ``mask_prob``. A span covers
        its start and the following ``span - 1`` positions, clipped at the end of
        the sequence. A row with no masked position gets one span at a uniformly
        random valid start, so every row has at least one masked token.
        """
        starts = torch.rand(batch_size, seq_len, device=device) < mask_prob
        mask = starts.clone()
        # Extend each start forward over the rest of its span.
        for offset in range(1, min(span, seq_len)):
            mask[:, offset:] = mask[:, offset:] | starts[:, :seq_len - offset]

        empty_rows = ~mask.any(dim=1)
        if empty_rows.any():
            # Valid starts are 0 .. seq_len - span inclusive; rows shorter than
            # the span start at 0 and are masked entirely.
            first = torch.randint(0, max(seq_len - span, 0) + 1, (batch_size, 1), device=device)
            positions = torch.arange(seq_len, device=device)
            fallback = (positions >= first) & (positions < first + span)
            mask = mask | (fallback & empty_rows.unsqueeze(1))
        return mask

    def forward(self, x, label=None, head_weight=None):
        """Encode ``x``; when ``label`` is given, also compute the masked-prediction loss.

        Args:
            x: (batch_size, seq_length, feature_size) input features.
            label: optional (batch_size, seq_length) target ids ranging from 0 to
                codebook_size - 1. When given, random spans of ``x`` are replaced by
                ``mask_vector`` and the cross-entropy loss is computed on the masked
                positions only. Neither ``x`` nor ``label`` is modified.
            head_weight: (feature_size, codebook_size) projection from encoder output
                to codebook logits; required when ``label`` is given.

        Returns:
            ``(output, mlm_loss)``: the encoder output and the scalar loss, or
            ``(output, None)`` when ``label`` is None.
        """
        if label is None:
            return self.encoder(x), None
        if head_weight is None:
            raise ValueError("head_weight is required when label is given")

        mask = self._sample_span_mask(x.size(0), x.size(1), self.mask_prob, self.mask_span,
                                      device=x.device)
        # Out-of-place replacement: the caller's tensors must not be overwritten.
        x = torch.where(mask.unsqueeze(-1), self.mask_vector.type_as(x), x)
        label = label.masked_fill(~mask, -100)

        output = self.encoder(x)
        logits = torch.matmul(output, head_weight)
        mlm_loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
        return output, mlm_loss
        

if __name__ == '__main__':
    # Example usage
    d_model = 512
    num_layers = 6
    encoder = TransformerRoPESemanticModel(d_model, num_layers)
    x = torch.rand(32, 10, d_model)  # batch-first shape: (batch_size, seq_length, feature_size)
    # forward() returns (output, loss); the loss is None when no labels are given.
    output, _ = encoder(x)
    print(output.shape)  # torch.Size([32, 10, 512])



