import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.models.sequence.mrconv_filters import DenseFilter, DilatedFilter, FourierFilter, DilatedFourierFilter
from src.ops.fftconv import fftconv_ref

# Flash Attention
try:
    from flash_attn.ops.fused_dense import FusedDense
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    FusedDense = None
    

class MRConvLayer(nn.Module):
    """Multi-resolution convolution layer.

    A bank of convolution filters covering increasing lengths (or dilations) is
    applied to the input. The filter outputs are combined with learnable
    aggregation weights ``self.w`` and a skip term ``self.D`` is added. The result
    then goes through GELU, ``nn.Dropout1d``, a linear layer to ``2 * d_model``
    and a GLU.

    Args:
        d_model: feature dimension of the input/output.
        l_max: maximum sequence length; the longest filter spans up to ``l_max``.
        kernel_size: kernel size of the shortest filter.
        channels: number of filter channels per feature.
        bidirectional: forwarded to every filter. For the ``*_cat`` filter types it
            also doubles the number of aggregation-weight channels.
        filter_type: one of 'dilated', 'fourier', 'dense', 'dilated_fourier',
            'fourier_cat', 'dilated_fourier_cat'.
        w_init: initialisation of the aggregation weights, 'rand' or 'decay'.
        depth_init: exponent of the first resolution level to build.
        dropout: probability for the ``nn.Dropout1d`` applied after the activation.
        use_reparam: if True, prepend a full-length ``DenseFilter`` branch.
        transposed: if True, ``forward`` takes/returns (B, D, L); else (B, L, D).
        return_state: if True, ``forward`` returns ``(output, None)``.
        adaptive_sr: if True, the filters receive ``sr_factor = L / l_max``.
        **filter_args: forwarded to every filter constructor.

    Raises:
        NotImplementedError: for an unsupported ``filter_type`` or ``w_init``.
    """

    def __init__(
        self,
        d_model=256,
        l_max=1024,
        kernel_size=8,
        channels=1,
        bidirectional=False,
        filter_type='fourier',
        w_init='rand',
        depth_init=0,
        dropout=0.0,
        use_reparam=False,
        transposed=True,    # axis ordering (B, L, D) or (B, D, L)
        return_state=True,
        adaptive_sr=False,
        **filter_args
    ):
        super().__init__()

        # Hyperparams
        self.d_model = d_model
        self.l_max = l_max
        self.channels = channels
        self.bidirectional = bidirectional
        self.filter_type = filter_type
        self.return_state = return_state
        self.transposed = transposed
        self.adaptive_sr = adaptive_sr

        # FlashFFTConv (set externally; when present it replaces the per-filter forward)
        self.flashfftconv = None

        # Number of filters and their lengths / dilations
        if filter_type == 'dilated':
            # Largest dilation whose dilated kernel still fits in l_max
            max_dilation = (l_max - kernel_size) / (kernel_size - 1) + 1
            lengths = [2**i for i in range(depth_init, math.floor(math.log2(max_dilation))+1)]
            lengths.append(int(max_dilation))
            max_kernel_size = kernel_size + (kernel_size - 1) * (lengths[-1] - 1)
            self.n_filters = len(lengths)
        elif filter_type in ['fourier', 'dense', 'dilated_fourier']:
            # Kernel lengths double from kernel_size; the last filter spans l_max
            n_filters = math.log2(l_max) - math.log2(kernel_size)
            n_filters = math.floor(n_filters) if 2**math.ceil(n_filters) > l_max else math.ceil(n_filters)
            lengths = [kernel_size * 2**i for i in range(depth_init, n_filters)]
            lengths.append(l_max)
            max_kernel_size = lengths[-1]
            self.n_filters = len(lengths)
        elif filter_type in ['fourier_cat', 'dilated_fourier_cat']:
            # Segment lengths whose concatenation sums to l_max
            n_filters = math.log2(l_max) - math.log2(kernel_size)
            n_filters = math.floor(n_filters) if 2**math.ceil(n_filters) > l_max else math.ceil(n_filters)
            lengths = [kernel_size * 2**(max(0, i-(depth_init+1)) + depth_init) for i in range(depth_init, n_filters)]
            max_kernel_size = lengths[-1]
            lengths.append(l_max - sum(lengths))
            self.n_filters = len(lengths)
        else:
            raise NotImplementedError

        self.filters = nn.ModuleList()

        # Dense filter (extra full-length branch, placed first)
        if use_reparam:
            self.n_filters = self.n_filters + 1
            self.filters.append(DenseFilter(d_model, kernel_size=l_max, channels=channels, bidirectional=bidirectional, **filter_args))

        # Low-rank filters
        print(f'==> MRConvLayer: depth={self.n_filters}, filter={self.filter_type}, n_filters={self.n_filters} ({lengths[0]}, {max_kernel_size}), reparam={use_reparam}, bidirectional={self.bidirectional}, w_init={w_init}')
        for l in lengths:
            if filter_type == 'dilated':
                self.filters.append(DilatedFilter(d_model, kernel_size, channels=channels, bidirectional=bidirectional, dilation=l, **filter_args))
            elif filter_type == 'fourier' or filter_type == 'fourier_cat':
                self.filters.append(FourierFilter(d_model, kernel_size, channels=channels, bidirectional=bidirectional, kernel_length=l, **filter_args))
            elif filter_type == 'dilated_fourier' or filter_type == 'dilated_fourier_cat':
                self.filters.append(DilatedFourierFilter(d_model, kernel_size, channels=channels, bidirectional=bidirectional, kernel_length=l, **filter_args))
            elif filter_type == 'dense':
                self.filters.append(DenseFilter(d_model, kernel_size=l, channels=channels, bidirectional=bidirectional, **filter_args))
            elif filter_type == 'mlp':
                # NOTE(review): unreachable - 'mlp' is rejected by the length
                # computation above, and MLPFilter is not imported in this file.
                if l < filter_args['mlp_cutoff']:
                    self.filters.append(DenseFilter(d_model, kernel_size=l, channels=channels, bidirectional=bidirectional, **filter_args))
                else:
                    self.filters.append(MLPFilter(d_model, kernel_size=l, channels=channels, bidirectional=bidirectional, **filter_args))
            else:
                raise NotImplementedError

        # Linear aggregation weights, shape (channels', d_model, n_filters)
        channels = 2 * channels if self.bidirectional and filter_type in ['fourier_cat', 'dilated_fourier_cat'] else channels
        if w_init == 'rand':
            w = torch.empty(channels, d_model, self.n_filters).uniform_(-1., 1.) * math.sqrt(2.0 / (2*(self.n_filters)))
            self.w = nn.Parameter(w)
        elif w_init == 'decay':
            # Filter i gets multiplier ** (n_filters - 1 - i), multiplier in [1, 3] over d_model
            multiplier = torch.linspace(1, 3, d_model).view(1, -1, 1).expand(channels, d_model, 1)
            weight_list = [multiplier ** (self.n_filters - 1 - i) for i in range(self.n_filters)]
            w = torch.cat(weight_list, dim=-1)
            self.w = nn.Parameter(w)
        else:
            # Without this, self.w would be missing and the layer would fail later
            raise NotImplementedError(f"w_init={w_init!r} is not supported (expected 'rand' or 'decay')")

        # Skip connection (D matrix from SSM equation)
        self.D = nn.Parameter(torch.randn(self.channels, self.d_model))

        # Activation
        self.activation = nn.GELU()

        # Dropout
        self.dropout = nn.Dropout1d(dropout)

        # Mixing layer
        self.linear = nn.Linear(d_model*self.channels, d_model*2)

        # Gated linear unit
        self.glu = nn.GLU(dim=-1)

        # Reparameterized params (filled by `reparameterize`)
        self.reparam_kernel = None
        self.reparam_bias = None

    def reparameterize(self):
        """Fuse the multi-resolution branches into a single kernel and bias.

        Each filter builds its own ``reparam_kernel``/``reparam_bias`` for length
        ``l_max``. These are summed, weighted by the aggregation weights, into
        ``self.reparam_kernel``/``self.reparam_bias``, which ``forward`` uses in
        eval mode. Only supported for ``filter_type == 'fourier'``.
        """
        assert self.filter_type == 'fourier'

        # Flatten the aggregation weights to (channels * d_model, n_filters), the
        # same layout used by `multi_resolution_convolution_linear`. Indexing the
        # raw (channels, d_model, n_filters) tensor with `[:, i:i+1]` would slice
        # the d_model axis instead of selecting filter i.
        # NOTE(review): assumes each filter's reparam_kernel is
        # (channels * d_model, L) and reparam_bias is (channels * d_model, 1) -
        # confirm against the filter classes.
        weight = self.w.reshape(-1, self.n_filters)

        reparam_kernel = 0.
        reparam_bias = 0.
        for i in range(self.n_filters):
            self.filters[i].reparameterize(self.l_max)
            w_i = weight[:, i:i+1]
            reparam_kernel = reparam_kernel + self.filters[i].reparam_kernel * w_i
            reparam_bias = reparam_bias + self.filters[i].reparam_bias * w_i

        self.reparam_kernel = reparam_kernel
        self.reparam_bias = reparam_bias

    def multi_resolution_convolution_concat(self, x):
        """Compute the multi-resolution convolution with concatenated kernels (no BN).

        The weighted kernel segments of all filters are concatenated along the
        length axis into one long kernel. The result is applied with
        ``fftconv_ref``, with ``self.D`` as the skip term.

        Args:
            x: input tensor, expected (B, d_model, L).

        Returns:
            Tensor of shape (B, L, H), where H is the output feature axis of
            ``fftconv_ref``.
        """
        assert self.filter_type in ['fourier_cat', 'dilated_fourier_cat']

        # Weight, flattened to (channels' * d_model, n_filters)
        w = self.w.reshape(-1, self.n_filters)

        # Weighted kernel segments, concatenated along the length axis
        reparam_kernel = torch.concat(
            [w[:, i].unsqueeze(-1) * self.filters[i].get_kernel() for i in range(self.n_filters)],
            dim=-1,
        )

        # Convolution (the first half of the rows is the forward kernel, the
        # second half the reverse kernel when bidirectional)
        if self.bidirectional:
            k_bi = torch.split(reparam_kernel, reparam_kernel.shape[0]//2, dim=0)
            y = fftconv_ref(x, k_bi[0], self.D, dropout_mask=None, gelu=False, k_rev=k_bi[1])
        else:
            y = fftconv_ref(x, reparam_kernel, self.D, dropout_mask=None, gelu=False)

        return rearrange(y, 'b d l -> b l d')

    def multi_resolution_convolution_linear(self, x):
        """Compute the multi-resolution convolution by summing zero-padded kernels (no BN).

        Every filter kernel is right-padded to ``l_max``, scaled by its
        aggregation weight and summed into one kernel. That kernel is applied
        with ``fftconv_ref`` and a zero skip term.

        Args:
            x: input tensor, expected (B, d_model, L).

        Returns:
            Tensor of shape (B, channels, d_model, L).
        """
        assert self.filter_type != 'dilated'

        # Weight, flattened to (channels * d_model, n_filters); duplicated for the
        # reverse kernels when bidirectional
        weight = self.w.reshape(-1, self.n_filters)
        if self.bidirectional:
            weight = weight.repeat(2, 1)

        # Reparameterised kernel
        reparam_kernel = 0.0
        for i in range(self.n_filters):
            kernel = self.filters[i].get_kernel()
            reparam_kernel = reparam_kernel + weight[..., i:i+1] * F.pad(kernel, (0, self.l_max-kernel.shape[-1]))

        # Bias (zero skip term)
        bias = torch.zeros(self.d_model * self.channels, device=x.device)

        # Convolution
        if self.bidirectional:
            k_bi = torch.split(reparam_kernel, reparam_kernel.shape[0]//2, dim=0)
            y = fftconv_ref(x, k_bi[0], bias, dropout_mask=None, gelu=False, k_rev=k_bi[1])
        else:
            y = fftconv_ref(x, reparam_kernel, bias, dropout_mask=None, gelu=False)

        return rearrange(y, 'b (c d) l -> b c d l', c=self.channels)

    def multi_resolution_convolution_bn(self, x):
        """Compute the multi-resolution convolution branch by branch.

        Each filter is applied separately. Either the filter's own forward is
        used, or ``self.flashfftconv`` on its kernel when set. The outputs are
        summed with the aggregation weights ``self.w[:, :, i]``, then the
        ``D`` skip term is added.

        Args:
            x: input tensor, expected (B, d_model, L).

        Returns:
            Tensor of shape (B, L, channels * d_model).
        """
        output = 0.0
        sr_factor = x.shape[-1] / self.l_max if self.adaptive_sr else 1

        # Compute multi-resolution convolution
        for i in range(self.n_filters):

            # Convolution
            if self.flashfftconv is not None:
                k = self.filters[i].get_kernel(sr_factor=sr_factor)
                y = self.flashfftconv(x.to(torch.bfloat16).contiguous(), k)
                y = y.to(torch.float32)
            else:
                y = self.filters[i](x, sr_factor=sr_factor)

            # Linear summation of convolutions
            output += self.w[:, :, i:i+1] * y

        # Skip connection (D term in state space equation)
        output = output + torch.einsum('bdl,cd->bcdl', x, self.D)

        # Rearrange to flatten shapes
        return rearrange(output, 'b c d l -> b l (c d)')

    def forward(self, x, *args, **kwargs):
        """Multi-resolution block forward.

        Args:
            x: shape (B, D, L) if ``self.transposed`` else (B, L, D).
            *args, **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Tensor with the same axis ordering as ``x``. It is wrapped as
            ``(output, None)`` when ``self.return_state`` is True.
        """
        if not self.transposed:
            x = x.transpose(-1, -2)

        # Multi-resolution convolution; every branch yields (B, L, features)
        if not self.training and self.reparam_kernel is not None:
            # ==> Use the fused kernel/bias built by `reparameterize`
            u = x
            y = fftconv_ref(u, self.reparam_kernel, torch.zeros(self.d_model, device=u.device), dropout_mask=None, gelu=False)
            y = y + self.reparam_bias.reshape(1, -1, 1)
            y = rearrange(y, 'b (c d) l -> b c d l', c=self.channels)
            # The fused kernel does not contain the D skip term; add it exactly as
            # in `multi_resolution_convolution_bn` so eval matches training.
            y = y + torch.einsum('bdl,cd->bcdl', u, self.D)
            x = rearrange(y, 'b c d l -> b l (c d)')
        elif self.filter_type in ['fourier_cat', 'dilated_fourier_cat']:
            # ==> Concatenate filters
            x = self.multi_resolution_convolution_concat(x)
        else:
            # ==> Use branches
            x = self.multi_resolution_convolution_bn(x)

        # Activation
        x = self.dropout(self.activation(x))

        # Channel mixing
        x = self.linear(x)

        # Gated linear unit
        x = self.glu(x)

        if self.transposed:
            x = x.transpose(-1, -2)

        if self.return_state:
            return x, None
        return x

    @property
    def d_state(self):
        return self.d_model

    @property
    def d_output(self):
        return self.d_model
    
    