import json
import math
import random
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

# Experiment for associative recall
# 1. Generate random patterns
# 2. Train a network to store them
# 3. Test the network by giving it a cue and see if it can recall the correct pattern


def gen_single_example(associative_letters, associative_numbers):
    """Build one associative-recall example.

    The keys and the values are shuffled independently, paired up position by
    position and concatenated into a context string such as ``"A1B2C5D3E4"``.
    One key that appears in the context is then appended as the query.

    Args:
        associative_letters: Sequence of key strings.
        associative_numbers: Sequence of value strings.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the context followed by the query
        key and ``y`` is the value paired with that key inside the context.

    Raises:
        IndexError: If no key/value pair can be formed (an input is empty).
    """
    # Shuffle private copies so the caller's sequences keep their order.
    letters = list(associative_letters)
    numbers = list(associative_numbers)
    random.shuffle(letters)
    random.shuffle(numbers)

    # zip() stops at the shorter input; querying only among the formed pairs
    # guarantees that the query key always has a matching value.
    pairs = list(zip(letters, numbers))
    context = ''.join(letter + number for letter, number in pairs)
    query_letter, y = random.choice(pairs)
    return context + query_letter, y
# 1. Generate random patterns
def data_gen(num_examples, gen_fn, *args):
    """Call ``gen_fn(*args)`` repeatedly and collect its outputs.

    Args:
        num_examples: How many times ``gen_fn`` is invoked.
        gen_fn: Callable returning a two-element ``(input, target)`` result.
        *args: Positional arguments forwarded unchanged to every call.

    Returns:
        Two lists ``(inputs, targets)`` of length ``num_examples``, in the
        order the examples were generated.
    """
    # The inner generator produces one sample at a time, so each result is
    # unpacked right after it is generated.
    samples = [
        (sample_in, sample_out)
        for sample_in, sample_out in (gen_fn(*args) for _ in range(num_examples))
    ]
    inputs = [pair[0] for pair in samples]
    targets = [pair[1] for pair in samples]
    return inputs, targets

# self attention block
class Block(nn.Module):
    """Single-head causal self-attention without output projection or residual.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        max_len: Side length of the pre-computed lower-triangular mask.
    """

    def __init__(self, embed_dim, max_len=11):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        # One fused linear layer yields queries, keys and values together.
        self.c_attn = nn.Linear(embed_dim, 3 * embed_dim)
        lower_triangle = torch.ones(max_len, max_len).tril()
        self.register_buffer('mask', lower_triangle)

    def forward(self, x):
        """Attend over dimension 1 of ``x``; each position sees itself and earlier ones."""
        seq_len = x.size(1)
        queries, keys, values = self.c_attn(x).chunk(3, dim=-1)

        # Scaled dot-product scores.
        scale = 1.0 / np.sqrt(keys.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Entries above the diagonal (future positions) get zero weight.
        is_future = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        return weights @ values
class Conv_Block(nn.Module):
    """Causal self-attention whose post-softmax weights are smoothed by a 1x5 convolution.

    The convolution slides along the key axis of every attention row. Its
    output is not re-normalised; entries above the diagonal are reset to zero
    afterwards so that no weight lands on future positions.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        max_len: Side length of the pre-computed lower-triangular mask.
    """

    def __init__(self, embed_dim, max_len=11):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.c_attn = nn.Linear(embed_dim, 3 * embed_dim)
        self.register_buffer('mask', torch.ones(max_len, max_len).tril())

        # Convolution applied to the (T, T) attention matrix after softmax.
        # Disabled experiment variants that used to live next to it:
        #   k_conv = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)  # pre-softmax scores
        #   q_conv = nn.Conv2d(1, 1, (3, 1), padding=(2, 0), bias=False)  # causal conv on Q
        self.v_conv = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)

        # Start from a very small constant kernel.
        with torch.no_grad():
            self.v_conv.weight.fill_(1e-6)

    def forward(self, x):
        """Return the convolved-attention mix of the value vectors."""
        seq_len, _embed = x.size(1), x.size(2)
        queries, keys, values = self.c_attn(x).chunk(3, dim=-1)

        scale = 1.0 / np.sqrt(keys.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale

        is_future = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Treat each attention matrix as a one-channel image for Conv2d.
        as_image = weights.view(-1, 1, seq_len, seq_len)
        weights = self.v_conv(as_image).view(-1, seq_len, seq_len)

        # The centred kernel can spill weight above the diagonal; zero it again.
        weights = weights.masked_fill_(is_future, float('0.0'))
        return weights @ values

# an attention network with a token embedding, a configurable number of attention layers (default 2) and a linear output head
class Network(nn.Module):
    """Token embedding, a stack of attention layers, and a linear vocabulary head.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        embed_dim: Width of the embeddings passed through the attention layers.
        max_len: Forwarded to every attention layer as its mask size.
        attn_layers: How many attention layers are stacked.
        block: Layer class; it is instantiated as ``block(embed_dim, max_len)``.
    """

    def __init__(self, vocab_size, embed_dim, max_len=11, attn_layers=2, block=Block):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, embed_dim)

        stacked = []
        for _ in range(attn_layers):
            stacked.append(block(embed_dim, max_len))
        self.att = nn.ModuleList(stacked)

        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to per-position logits over the vocabulary."""
        hidden = self.embed(x)
        for attention_layer in self.att:
            hidden = attention_layer(hidden)
        return self.head(hidden)
    
def encode_string(s, tokenizer):
    """Look up every element of ``s`` in ``tokenizer`` and return the ids as a list.

    Raises whatever ``tokenizer[...]`` raises for an unknown symbol
    (``KeyError`` for a plain dict).
    """
    token_ids = []
    for symbol in s:
        token_ids.append(tokenizer[symbol])
    return token_ids
def evaluate(model, dl, criterion):
    """Evaluate ``model`` on the last sequence position of every batch in ``dl``.

    Prints the loss and accuracy averaged over batches.

    Args:
        model: Module whose output is indexed as ``model(x)[:, -1, :]``.
        dl: Iterable of ``(x, y)`` tensor batches.
        criterion: Callable ``criterion(y_pred, y)`` returning a scalar tensor.

    Returns:
        ``(loss, acc)``: the per-batch loss and per-batch accuracy *summed*
        over all batches. Only the printed values are averaged; the sums are
        returned so that existing callers keep seeing the same numbers.
    """
    # Move batches to wherever the model lives instead of relying on a
    # module-level ``device`` that only exists when the file runs as a script.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    loss, acc, num_batches = 0.0, 0.0, 0
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for x, y in dl:
                if target_device is not None:
                    x, y = x.to(target_device), y.to(target_device)
                # Only the prediction at the final position is scored.
                y_pred = model(x)[:, -1, :]
                loss += criterion(y_pred, y).item()
                acc += (y_pred.argmax(dim=1) == y).float().mean().item()
                num_batches += 1
    finally:
        # Restore the mode the model was in when we were called.
        model.train(was_training)

    if num_batches:
        print(f'Loss: {loss/num_batches} Acc: {acc/num_batches}')
    else:
        # Nothing was evaluated; avoid dividing by zero.
        print('Loss: nan Acc: nan')
    return loss, acc
def get_lr(epoch, max_lr = 0.001, min_lr = 0.0001, max_epoch = 1000, warm_up = 100, cosine = True):
    """Learning rate for ``epoch``: linear warm-up followed by a decay.

    Args:
        epoch: Current epoch index.
        max_lr: Rate reached at the end of warm-up.
        min_lr: Rate at epoch 0 and at the end of the decay.
        max_epoch: Epoch at which the decay reaches ``min_lr``.
        warm_up: Number of warm-up epochs; the rate rises linearly from
            ``min_lr`` to ``max_lr`` over them.
        cosine: Use a half-cosine decay when true, a linear decay otherwise.

    Returns:
        The learning rate as a Python ``float``. For ``epoch >= max_epoch``
        (or when there is no room for a decay phase) the rate stays at
        ``min_lr``.
    """
    if epoch < warm_up:
        return float((max_lr - min_lr) / warm_up * epoch + min_lr)

    decay_span = max_epoch - warm_up
    if decay_span <= 0:
        # No decay phase exists; also avoids dividing by zero below.
        return float(min_lr)

    # Clamp so that running past max_epoch neither makes the cosine rise
    # again nor pushes the linear schedule below min_lr.
    elapsed = min(epoch - warm_up, decay_span)
    if cosine:
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(np.pi * elapsed / decay_span))
    else:
        lr = min_lr + (max_lr - min_lr) * (1 - elapsed / decay_span)
    return float(lr)

@dataclass
class ModelArgs:
    """Hyper-parameters read by ``Mamba``, ``ResidualBlock`` and ``MambaBlock``.

    Attributes derived in ``__post_init__``:
        d_inner: ``expand * d_model``, the width used inside each Mamba block.
        dt_rank: ``ceil(d_model / 16)`` when left at ``'auto'``.
        vocab_size: rounded up to a multiple of ``pad_vocab_size_multiple``.

    NOTE(review): the defaults below are the commonly used Mamba settings;
    confirm them against the checkpoints you intend to load.
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            self.vocab_size += self.pad_vocab_size_multiple - remainder


class Mamba(nn.Module):
    def __init__(self, args: "ModelArgs"):
        """Full Mamba model: embedding, a stack of residual blocks, final norm and LM head.

        Args:
            args: Object providing ``vocab_size``, ``d_model`` and ``n_layer``
                (plus whatever ``ResidualBlock`` reads from it).
        """
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # Tie the output projection to the embedding matrix (weight tying).
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: Long tensor of token ids; the layers below treat it as
                ``(batch, length)``.

        Returns:
            logits: ``input_ids.shape + (vocab_size,)``.

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits

    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into a new model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded.

        Raises:
            FileNotFoundError: If the config or weight file cannot be resolved.
        """
        # Imported lazily so the rest of the file works without transformers.
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def resolve(model_name, filename):
            # With _raise_exceptions_for_missing_entries=False a missing file
            # comes back as None; turn that into an explicit error.
            path = cached_file(model_name, filename,
                               _raise_exceptions_for_missing_entries=False)
            if path is None:
                raise FileNotFoundError(f'{filename} not found for {model_name}')
            return path

        def load_config_hf(model_name):
            # Context manager so the config file handle is always closed.
            with open(resolve(model_name, CONFIG_NAME)) as config_file:
                return json.load(config_file)

        def load_state_dict_hf(model_name):
            return torch.load(resolve(model_name, WEIGHTS_NAME),
                              weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        # Checkpoint keys carry a 'backbone.' prefix that this module layout lacks.
        new_state_dict = {
            key.replace('backbone.', ''): value for key, value in state_dict.items()
        }
        model.load_state_dict(new_state_dict)

        return model


class ResidualBlock(nn.Module):
    # The annotation is quoted so that defining this class does not require
    # the name ModelArgs to exist at import time.
    def __init__(self, args: "ModelArgs"):
        """Simple block wrapping a Mamba block with normalization and a residual connection.

        Args:
            args: Object providing ``d_model`` (plus whatever ``MambaBlock`` reads).
        """
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """Apply ``mixer(norm(x)) + x``.

        Args:
            x: Input tensor; the output has the same shape (the result is
                added back onto ``x``).

        Returns:
            output: ``self.mixer(self.norm(x)) + x``.

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
        """
        output = self.mixer(self.norm(x)) + x

        return output
            

class MambaBlock(nn.Module):
    # Shape letters used in the comments below:
    #   b = batch size, l = sequence length, d = args.d_model,
    #   d_in = args.d_inner, n = args.d_state.
    #
    # The annotation is quoted so that defining this class does not require
    # the name ModelArgs to exist at import time.
    def __init__(self, args: "ModelArgs"):
        """A single Mamba block (Figure 3, Section 3.4 of the Mamba paper).

        Args:
            args: Object providing ``d_model``, ``d_inner``, ``d_state``,
                ``dt_rank``, ``d_conv``, ``conv_bias`` and ``bias``.
        """
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # Depthwise convolution; the left/right padding of d_conv - 1 is
        # trimmed back to the first l outputs in forward(), which makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # Every one of the d_in rows of A starts as [1, 2, ..., n]; created as
        # float32 so the logarithm is taken on a floating-point tensor.
        A = torch.arange(1, args.d_state + 1, dtype=torch.float32).repeat(args.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward (same data flow as Figure 3, Section 3.4 of the Mamba paper).

        Args:
            x: shape (b, l, d)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        _, seq_len, _ = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        # Conv1d expects channels before length: (b, l, d_in) -> (b, d_in, l).
        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :seq_len]  # drop the outputs produced by the padding
        x = x.transpose(1, 2)  # back to (b, l, d_in)

        x = F.silu(x)

        y = self.ssm(x)

        # Gate the SSM output with the residual branch.
        y = y * F.silu(res)

        output = self.out_proj(y)

        return output

    def ssm(self, x):
        """Compute the state-space parameters and run the selective scan.

        See Algorithm 2 in Section 3.2 of the Mamba paper.

        Args:
            x: shape (b, l, d_in)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # A and D are input independent; ∆, B and C are computed from the
        # input, which is what makes the state space model "selective".
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """Run the selective scan sequentially over the length dimension.

        This is the classic discrete state space recurrence
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except that B, C and the step size delta depend on the input.

        Args:
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: some parts were refactored out of `selective_scan_ref`, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize the continuous parameters:
        # - A uses exp(delta * A) (zero-order hold form),
        # - B uses the simpler product delta * B (Euler-style), multiplied by u.
        # Broadcasting: (b, l, d_in, 1) * (d_in, n) -> (b, l, d_in, n)
        deltaA = torch.exp(delta.unsqueeze(-1) * A)
        # (b, l, d_in, 1) * (b, l, 1, n) -> (b, l, d_in, n)
        deltaB_u = (delta * u).unsqueeze(-1) * B.unsqueeze(2)

        # Sequential scan over time steps; the state follows the dtype and
        # device of the discretized parameters.
        x = deltaA.new_zeros((b, d_in, n))
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            # Contract the state dimension n with C at this time step.
            y = torch.einsum('bdn,bn->bd', x, C[:, i, :])
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        # Skip connection scaled by D.
        y = y + u * D

        return y


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        d_model: Length of the learned scale vector ``weight`` (initialised to ones).
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x`` divided by its RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
        


if __name__ == '__main__':
    # NOTE: `device` must keep this exact name; evaluate() reads it as a global.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Vocabulary: five keys followed by five values; token id = position in vocab.
    associative_letters = list('ABCDE')
    associative_numbers = list('01234')
    vocab = associative_letters + associative_numbers
    tokenizer = {symbol: index for index, symbol in enumerate(vocab)}

    # 1. Generate random patterns and encode them as integer arrays.
    raw_inputs, raw_targets = data_gen(10000, gen_single_example, associative_letters, associative_numbers)
    X = np.array([encode_string(sequence, tokenizer) for sequence in raw_inputs])
    Y = np.array([encode_string(target, tokenizer) for target in raw_targets]).squeeze()
    print(X.shape, Y.shape)

    # 80/20 train/test split in generation order.
    split = int(len(X) * 0.8)
    train_ds = TensorDataset(
        torch.tensor(X[:split]),
        torch.tensor(Y[:split], dtype=torch.int64),
    )
    test_ds = TensorDataset(
        torch.tensor(X[split:]),
        torch.tensor(Y[split:], dtype=torch.int64),
    )
    train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=64)

    # Single convolutional-attention layer, moved to the selected device.
    model = Network(len(vocab), 32, max_len=11, attn_layers=1, block=Conv_Block)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Baseline score of the untrained model.
    evaluate(model, test_dl, criterion)

    # 2. Train a network to store them
    max_epoch = 300
    log_epoch = 20
    warm_up_epoch = 10

    # Show every parameter's shape, and the initial convolution kernels in full.
    for name, param in model.named_parameters():
        print(name, param.shape)
        if 'conv' in name:
            print(param)

    for epoch in range(max_epoch):
        # The schedule is applied by hand to every parameter group.
        lr = get_lr(epoch, max_epoch=max_epoch, warm_up=warm_up_epoch, max_lr=0.001, min_lr=0.00001)
        for group in optimizer.param_groups:
            group['lr'] = lr

        for batch_x, batch_y in train_dl:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            # Only the logits at the final (query) position are trained on.
            logits = model(batch_x)[:, -1, :]
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

        if epoch % log_epoch == 0:
            # Reports the loss of the last batch of this epoch.
            print(f'Epoch {epoch} loss: {loss.item()}')
            for name, param in model.named_parameters():
                if 'conv' in name:
                    print(name, param.detach().cpu().numpy().flatten())

    # 3. Test the network by giving it a cue and see if it can recall the correct pattern
    evaluate(model, test_dl, criterion)