import math

import torch
import torch.nn as nn

from workbench.gpt_model.modules.embedding import PosEmbedding
from workbench.gpt_model.modules.transformer_block import TransformerBlock
from workbench.gpt_model.modules.norm import RMSNorm

class Transformer(nn.Module):
    """Decoder-only transformer: embedding, a stack of blocks, final norm, output layer.

    Every block is built with ``cross_attn=False, causal_attn=True``.  All
    hyper-parameters are read from attributes of ``config`` (see ``__init__``).
    """

    def __init__(self, config):
        """Build the sub-modules from ``config`` and initialise their parameters.

        Attributes read from ``config``: ``vocab_size``, ``model_dim``,
        ``max_len``, ``pos_embedding``, ``rel_pos_enc``, ``initializer_range``,
        ``tie_embed``, ``n_layers``, ``rms_norm``, ``ln_eps``,
        ``add_unit_offset``, ``use_bias``, ``learn_ln`` and ``last_zero``.
        """
        super().__init__()

        self.embed = PosEmbedding(
            config.vocab_size,
            config.model_dim,
            config.max_len,
            config.pos_embedding,
            config.rel_pos_enc,
            config.initializer_range,
        )

        if config.tie_embed:
            # NOTE(review): this reuses the embedding sub-module itself as the
            # output layer, so ``embed_seq`` must map hidden states to vocab
            # logits when called from ``forward`` -- confirm against PosEmbedding.
            self.output = self.embed.embed_seq
        else:
            self.output = nn.Linear(config.model_dim, config.vocab_size)

        self.decoder = nn.ModuleList(
            [
                TransformerBlock(config, cross_attn=False, causal_attn=True)
                for _ in range(config.n_layers)
            ]
        )

        if config.rms_norm:
            self.norm = RMSNorm(
                config.model_dim,
                eps=config.ln_eps,
                add_unit_offset=config.add_unit_offset,
            )
        else:
            self.norm = nn.LayerNorm(
                config.model_dim,
                eps=config.ln_eps,
                bias=config.use_bias,
                elementwise_affine=config.learn_ln,
            )

        self.initialize(config.initializer_range, config.vocab_size, config.last_zero)

    def initialize(self, initializer_range, vocab_size, last_zero):
        """Re-initialise the parameters in place, dispatching on parameter name.

        Rules are tried in this order for each named parameter:

        1. name contains ``'bias'``: zeros.
        2. name contains ``'norm'`` or ``'ln'``: left untouched.
        3. shape is exactly ``[1]``: left untouched.
        4. name contains ``'out'`` and ``last_zero`` is truthy: zeros.
        5. ``initializer_range`` is truthy: normal with that std.
        6. name contains ``'embed_seq'``: normal with std ``1/sqrt(vocab_size)``.
        7. otherwise: Xavier-uniform.

        Args:
            initializer_range: std of the normal init; a falsy value selects
                the fallbacks in rules 6 and 7.
            vocab_size: used only for the embedding fallback std.
            last_zero: when truthy, zero every parameter whose name contains
                ``'out'``.
        """
        for name, param in self.named_parameters():
            if 'bias' in name:
                nn.init.zeros_(param)
            elif 'norm' in name or 'ln' in name:
                continue
            elif param.shape == torch.Size([1]):
                continue
            elif 'out' in name and last_zero:
                nn.init.zeros_(param)
            elif initializer_range:
                nn.init.normal_(param, mean=0.0, std=initializer_range)
            elif 'embed_seq' in name:
                nn.init.normal_(param, mean=0.0, std=1.0 / math.sqrt(vocab_size))
            else:
                nn.init.xavier_uniform_(param)

    def make_decoder_mask(self, trg_embed, trg_len):
        """Build a boolean mask combining sequence length and causality.

        Only the first two dimensions of ``trg_embed`` (and its device) are
        used; ``forward`` passes the token tensor here.

        Args:
            trg_embed: tensor whose leading dimensions are (batch, time).
            trg_len: 1-D tensor with one length per batch element.

        Returns:
            Bool tensor of shape (batch, time, time).  Entry ``[b, i, j]`` is
            True when ``i >= trg_len[b]`` or ``j > i`` and False otherwise.
            NOTE(review): presumably True means "do not attend" inside
            TransformerBlock -- confirm there.
        """
        device = trg_embed.device
        steps = trg_embed.size(1)

        # True for row positions that lie inside each sequence's length.
        positions = torch.arange(steps, device=device).expand(trg_embed.shape[:2])
        in_length = (positions < trg_len.unsqueeze(1)).unsqueeze(-1)

        # True on and below the diagonal: column j is visible from row i when j <= i.
        visible = torch.tril(
            torch.ones((1, steps, steps), dtype=torch.bool, device=device)
        )

        keep = in_length & visible

        # A dtype check works on every backend; the legacy torch.BoolTensor /
        # torch.cuda.BoolTensor classes only match CPU and CUDA tensors.
        assert keep.dtype == torch.bool

        return ~keep

    def forward(self, trg_seq, trg_len=None):
        """Run the stack on ``trg_seq`` and return the output layer's result.

        Args:
            trg_seq: input passed to the embedding module.
            trg_len: optional per-sequence lengths; when ``None`` no attention
                mask is built and ``attn_mask=None`` is passed to every block.

        Returns:
            ``self.output`` applied to the normalised final hidden states.
        """
        hidden_states = self.embed(trg_seq)

        attn_mask = None
        if trg_len is not None:
            attn_mask = self.make_decoder_mask(trg_seq, trg_len)

        for layer in self.decoder:
            hidden_states = layer(hidden_states=hidden_states, attn_mask=attn_mask)

        return self.output(self.norm(hidden_states))


