import math
import torch 
from torch import nn
from mamba_ssm import Mamba


class MambaClf(nn.Module):
    """Sequence classifier: token embedding -> stacked Mamba blocks -> sigmoid head.

    The prediction for each sequence is read from the encoder output at its
    last valid position (``lengths[i] - 1``).

    Args:
        n_tokens: vocabulary size of the input embedding.
        d_channels: model width (embedding size and Mamba ``d_model``).
        d_state: Mamba ``d_state``.
        layers: number of stacked Mamba blocks; each has its own parameters.
        d_conv: Mamba ``d_conv``.
        d_expand: Mamba ``expand``.
        noutputs: number of sigmoid outputs per sequence.
    """

    def __init__(self, n_tokens, d_channels, d_state, layers=1, d_conv=4, d_expand=2, noutputs=1):
        super().__init__()
        self.d_channels = d_channels
        self.embed_layer = nn.Embedding(n_tokens, d_channels)
        # One independent Mamba module per layer: reusing a single instance
        # would make every "layer" share the same weights.  At least one block
        # is always built so that ``mamba_block`` exists even when layers == 0.
        blocks = [
            Mamba(d_model=d_channels, d_state=d_state, d_conv=d_conv, expand=d_expand)
            for _ in range(max(layers, 1))
        ]
        # Alias of the first layer, kept so code and checkpoints that refer to
        # ``mamba_block`` keep working.
        self.mamba_block = blocks[0]
        # ModuleList (not ParameterList) is the container meant for submodules.
        self.encoder = nn.ModuleList(blocks[:layers])
        self.decoder = nn.Linear(d_channels, noutputs, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _reinit_encoder(self, weight_fn, matrices_only=False):
        """Re-initialise every encoder layer's state-dict entries in place.

        Entries whose key contains 'weight' are passed to ``weight_fn``
        (skipped when ``matrices_only`` is set and the tensor has fewer than
        two dims, since Xavier init needs fan-in and fan-out).  Entries whose
        key contains 'bias' are zeroed.  All other entries are left untouched.
        """
        for mamba_layer in self.encoder:
            # state_dict() tensors are detached but share storage with the
            # live parameters, so in-place ops update the model.  Fetch the
            # dict once per layer rather than once per key.
            for key, tensor in mamba_layer.state_dict().items():
                if 'weight' in key:
                    if not matrices_only or tensor.dim() > 1:
                        weight_fn(tensor)
                elif 'bias' in key:
                    # NOTE(review): this also overrides any bias that Mamba
                    # initialises specially — confirm that is intended.
                    tensor.zero_()

    def init_weights(self, weight_init=10):
        """Draw all weights from U(-weight_init, weight_init); zero encoder biases."""
        self.embed_layer.weight.data.uniform_(-weight_init, weight_init)
        self.decoder.weight.data.uniform_(-weight_init, weight_init)
        self._reinit_encoder(lambda w: w.uniform_(-weight_init, weight_init))

    def init_xavuni(self):
        """N(0, 1) embedding, Xavier-uniform decoder and encoder matrices, zero encoder biases."""
        self.embed_layer.weight.data.normal_(mean=0, std=1)
        torch.nn.init.xavier_uniform_(self.decoder.weight.data)
        self._reinit_encoder(torch.nn.init.xavier_uniform_, matrices_only=True)

    def init_gauss_weights(self, std_init=10, inp_init=1, dec_init=1):
        """Zero-mean Gaussian init with separate stds for embedding, decoder and encoder.

        Args:
            std_init: std of every encoder weight.
            inp_init: std of the embedding weights.
            dec_init: std of the decoder weights.
        """
        self.embed_layer.weight.data.normal_(mean=0, std=inp_init)
        self.decoder.weight.data.normal_(mean=0, std=dec_init)
        self._reinit_encoder(lambda w: w.normal_(mean=0, std=std_init))

    def init_xavnormal(self):
        """N(0, 1) embedding and decoder, Xavier-normal encoder matrices, zero encoder biases."""
        self.embed_layer.weight.data.normal_(mean=0, std=1)
        self.decoder.weight.data.normal_(mean=0, std=1)
        self._reinit_encoder(torch.nn.init.xavier_normal_, matrices_only=True)

    def forward(self, src, lengths):
        """Classify each sequence from the encoder output at its last valid step.

        Args:
            src: token indices laid out as (seq_len, batch); the transpose
                below moves batch to dim 0, which the length-based gather
                relies on.
            lengths: per-sequence valid lengths (tensor, array or sequence of
                scalars), each in ``[1, seq_len]``.

        Returns:
            Tensor of shape (len(lengths), noutputs) with sigmoid outputs.

        Raises:
            ValueError: if any length is outside ``[1, seq_len]``.
        """
        src = self.embed_layer(src) * math.sqrt(self.d_channels)
        # (seq_len, batch, d) -> (batch, seq_len, d); presumably the layout
        # Mamba expects — confirm against mamba_ssm.
        src = src.transpose(0, 1)

        for mamba_layer in self.encoder:
            src = mamba_layer(src)

        slots = src.size(1)
        last = torch.as_tensor(lengths, dtype=torch.long, device=src.device).reshape(-1) - 1
        # Out-of-range lengths would otherwise silently read another position.
        if bool(((last < 0) | (last >= slots)).any()):
            raise ValueError(f"every length must be in [1, {slots}]")

        # Vectorised gather of src[i, lengths[i] - 1]; works on non-contiguous
        # tensors and needs no per-element host sync.
        rows = torch.arange(last.numel(), device=src.device)
        output = src[rows, last]
        decoded = self.decoder(output)
        return self.sigmoid(decoded)
