import os
import json
from transformers.models.llama.modeling_llama import *
from .streaming_kernel import streaming_forward, streaming_forward2
from .pit_sparse_flash_attention_v2 import pit_sparse_flash_attention_forward
from .block_sparse_flash_attention import block_sparse_flash_attention_forward
from .snap_kv import *

# Causal (lower-triangular, diagonal included) boolean mask for the trailing
# `last_q` x `last_q` block of query/key scores, shaped (1, 1, last_q, last_q)
# so it broadcasts over batch and head. It lives on CUDA when available (the
# kernels consume it there); otherwise it falls back to CPU so that merely
# importing this module does not crash on CUDA-less machines.
last_q = 64
_LAST_Q_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
arange = torch.arange(last_q, device=_LAST_Q_DEVICE)
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]

def init_minference_parameters(self):
    """Copy MInference settings from ``self.config`` onto the attention module.

    Reads ``self.config.to_dict()`` and sets ``topk``, ``topk_from_layer``,
    the ``kv_cache_compressed*`` flags, ``topk_ratio``, ``block_size`` and
    ``topk_dims_file_path`` (with the defaults shown below). It resets
    ``ne_inf``, ``vertical`` and ``slash``.

    If ``topk_dims_file_path`` names an existing file, that JSON file is read.
    Its element ``self.layer_idx`` is stored as ``self.best_pattern``, with the
    keys converted to ``int`` (head ids). When the path is empty, ``None`` or
    missing, ``best_pattern`` is left untouched.

    NOTE(review): the forward functions in this file call this on every
    forward pass, so the JSON file is re-read each time.
    """
    config = self.config.to_dict()

    self.topk = config.get("topk", -1)
    self.topk_from_layer = config.get("topk_from_layer", 0)
    self.kv_cache_compressed = config.get("kv_cache_compressed", False)
    self.kv_cache_compressed_slash = config.get("kv_cache_compressed_slash", False)
    self.kv_cache_compressed_h2o = config.get("kv_cache_compressed_h2o", False)
    self.kv_cache_compressed_v4 = config.get("kv_cache_compressed_v4", False)

    self.topk_ratio = config.get("topk_ratio", 0.1)
    self.block_size = config.get("block_size", 32)

    self.ne_inf = None
    self.topk_dims_file_path = config.get("topk_dims_file_path", "")
    # Guard against None/"" so os.path.exists never sees a non-path value.
    if self.topk_dims_file_path and os.path.exists(self.topk_dims_file_path):
        # Context manager closes the file (the original leaked the handle).
        with open(self.topk_dims_file_path, encoding="utf-8") as fh:
            layer_patterns = json.load(fh)[self.layer_idx]
        self.best_pattern = {int(ii): jj for ii, jj in layer_patterns.items()}
    self.vertical, self.slash = None, None

def sum_all_diagonal_matrix(mat: torch.Tensor):
    """Sum every diagonal of each (n, m) matrix in a (b, h, n, m) tensor.

    The input is zero-padded by ``n`` columns on both sides. A strided view
    that advances one row and one column per step turns each diagonal into a
    column, and those columns are summed.

    Args:
        mat: tensor of shape (b, h, n, m). Any b and h are supported.

    Returns:
        Tensor of shape (b, h, n + m - 1). Entry ``i`` is the sum of
        ``mat[..., r, c]`` over all ``c - r == i - (n - 1)``: the first entry is
        the bottom-left corner, the last entry the top-right corner.
    """
    b, h, n, m = mat.shape
    # Zero padding (default dtype, as before) keeps every strided read inside
    # the padded buffer.
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # contiguous (b, h, n, 2n+m)
    width = 2 * n + m
    # Row stride width + 1 shifts each successive row one column right, so
    # diagonals line up as columns. Batch/head strides follow the contiguous
    # padded layout; the original hard-coded (1, 1) here and silently dropped
    # every matrix but the first when b or h > 1.
    mat_strided = mat_padded.as_strided(
        (b, h, n, n + m), (h * n * width, n * width, width + 1, 1)
    )
    sum_diags = torch.sum(mat_strided, 2)  # sum each column -> one value per diagonal
    return sum_diags[:, :, 1:]

def gather(t, dim, i):
    """Gather along ``dim`` after broadcasting the index ``i`` to ``t``'s shape.

    ``i`` must have the same number of dimensions as ``t``. Every dimension
    except ``dim`` is expanded to match ``t``; ``dim`` keeps ``i``'s own size.
    Negative ``dim`` counts from the end.
    """
    if dim < 0:
        dim = dim + t.ndim
    target_shape = list(t.shape)
    target_shape[dim] = i.shape[dim]
    return t.gather(dim, i.expand(*target_shape))

def gather_qkv(q, k, v, attention_mask):
    """Dense scaled-dot-product attention with an additive mask.

    The scores are ``q @ k^T / sqrt(head_dim)`` plus ``attention_mask``. The
    softmax runs in float32 and is cast back to ``q.dtype`` before weighting
    ``v``.
    """
    head_dim = q.size(-1)
    scores = torch.matmul(q, k.transpose(2, 3))
    scores = scores / math.sqrt(head_dim) + attention_mask
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    return torch.matmul(probs.to(q.dtype), v)

def search_pattern(q, k, attention_mask, head, idx=None):
    """Offline search for the sparse attention mask that best fits one head.

    Computes the dense attention probabilities for the head (``attn_weights``)
    and then scores a fixed grid of candidate masks from three families:
    "stream_llm", "vertical_and_slash" and "retrieval_head". A candidate's
    score is the attention probability mass its mask keeps. The mass is summed
    over keys and averaged over the queries from position 2500 onward; higher
    is better.

    NOTE(review): several constants assume a specific setup, TODO confirm:
    - the 1/sqrt(128) scale implies head_dim == 128;
    - the ``torch.zeros(1, 1, ...)`` buffers in ``retrieval_head`` and the
      ``.item()`` on the score imply batch == 1 and a single head;
    - the ``[:, :, 2500:]`` slice implies q_len > 2500.

    Args:
        q, k: query/key tensors, indexed below as (batch, heads, seq, head_dim).
        attention_mask: additive mask added to the (seq, seq) scores.
        head: only used in the printed summary line.
        idx: unused.

    Returns:
        list of [pattern_type, vertical_size, slash_size, score] for every
        candidate in the fixed grid. The overall best (including the extra
        retrieval_head retries below) is printed, not returned.
    """
    q_len = q.shape[2]    

    def vertical_and_slash(vertical_size, slash_size):
        # Estimate important key columns ("vertical") and diagonals ("slash")
        # from the last 64 queries, build the matching 0/1 mask, and return
        # the attention mass it keeps.
        last_q = 64
        q_len = q.shape[2]
        qk_idxs = [ii + q_len for ii in list(range(-last_q, 0, 1))]
        qk = torch.matmul(q[:,:,qk_idxs,:], k.transpose(2, 3))/ math.sqrt(128) + attention_mask[:,:,qk_idxs]
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Force the first 30 key columns to be kept.
        vertical[...,:30] = 10000
        # topk of -vertical selects the (q_len - vertical_size) *least*
        # important columns; they are zeroed in est_attn below.
        vertical_topk = torch.topk(-vertical, q_len - vertical_size, -1).indices

        # After slicing, entry i is the summed mass of the diagonal at distance
        # (q_len - 1 - i) below the main diagonal, so the last 30 entries are
        # the 30 diagonals nearest the main diagonal (forced to be kept).
        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        slash[...,-30:] = 10000
        slash_topk = slash  # unused
        # Top diagonals -> non-positive offsets for spdiags.
        slash = torch.topk(slash, slash_size, -1).indices - (q_len - 1)
        # Materialise the chosen diagonals as a dense (1, q_len, q_len) 0/1
        # matrix (built on CPU, then moved to q's device).
        slash = torch.stack([torch.sparse.spdiags(torch.ones(slash_size, q_len), slash.cpu()[0][_], (q_len, q_len)).to_dense() for _ in range(1)]).to(q.device)
        
        # Mask = kept columns OR chosen diagonals, restricted to the causal part.
        est_attn = torch.ones_like(attn_weights)
        dim = 3
        est_attn = est_attn.scatter(3, vertical_topk.expand(*est_attn.shape[:dim], vertical_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        est_attn = est_attn + slash
        
        est_attn = (est_attn > 0).float()
        est_attn = torch.tril(est_attn)
        attn_weights_x = attn_weights * est_attn
        # Kept mass per query, averaged over queries from position 2500 on.
        res3 = attn_weights_x[:,:,2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res3
    
    def stream_llm(vertical_size, slash_size):
        # StreamingLLM-style mask: a causal local band of slash_size + 1
        # diagonals plus the first vertical_size key columns.
        q_len = q.shape[2]

        mask = torch.triu(torch.tril(torch.ones(q_len, q_len), 0), -slash_size).to(q)
        mask[:,:vertical_size] = 1
        mask = mask.unsqueeze(0).unsqueeze(1)
        
        est_attn = torch.tril(mask)
        attn_weights_x = attn_weights * est_attn
        res3 = attn_weights_x[:,:,2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res3

    def retrieval_head(topk_ratio, slash_size=None):
        # Block-sparse mask with 32x32 blocks. q and k are mean-pooled per
        # block (zero-padded to a multiple of 32), block scores are computed,
        # and each block row keeps its block_num // topk_ratio highest blocks.
        # slash_size is unused.
        block_num = (q_len -1) // 32 + 1
        block_q = torch.zeros(1,1,block_num * 32,128).to(q)
        block_q[:,:,:q_len] = q
        block_q = block_q.reshape(1,1,block_num,32,-1).mean(-2)
        block_k = torch.zeros(1,1,block_num * 32,128).to(k)
        block_k[:,:,:q_len] = k
        block_k = block_k.reshape(1,1,block_num,32,-1).mean(-2)

        # NOTE(review): block scores are unscaled, and the token-level mask's
        # top-left (block_num x block_num) corner is reused as a block mask.
        qk = torch.matmul(block_q, block_k.transpose(2, 3)) + attention_mask[:,:,:block_num,:block_num]
        est_attn = torch.ones_like(qk)
        # Indices of the lowest-scoring blocks per row, zeroed below.
        block_topk = torch.topk(-qk, block_num - block_num//topk_ratio, -1).indices
        
        dim = 3
        est_attn = est_attn.scatter(3, block_topk.expand(*est_attn.shape[:dim], block_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        # Expand each block entry to a 32x32 tile and crop to (q_len, q_len).
        est_attn = est_attn.unsqueeze(3).unsqueeze(-1).repeat(1,1,1,32,1,32).reshape(1,1,block_num * 32, block_num * 32)[...,:q_len,:q_len]
        est_attn = torch.tril(est_attn)

        attn_weights_x = attn_weights * est_attn
        res2 = attn_weights_x[:,:,2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()
        return res2

    # Dense reference probabilities; the nested helpers read this via closure.
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(128) + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)   
    best_s, best_v, best_score, best_ty = 0, 0, 0, ""
    all_info = []
    for ty, fc in [("stream_llm", stream_llm), ("vertical_and_slash", vertical_and_slash), ("retrieval_head", retrieval_head)]:
        # Candidate (vertical_size, slash_size) grid per family; for
        # retrieval_head the first value is topk_ratio.
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 750), (500, 700), (3500, 100)]
        else:
            vs_list = [(8, 1), (7, 1), (6, 1)]
        for v_size, s_size in vs_list:
            score = fc(v_size, s_size)
            score = score.item()
            all_info.append([ty, v_size, s_size, score])
            if score > best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    # If nothing reaches 0.7, retry retrieval_head with topk_ratio 7 then 6.
    # These retries affect the printed best but are not added to all_info.
    v_size, s_size = 7, 1
    while best_score < 0.7 and v_size > 5:
        score = retrieval_head(v_size, s_size)
        if score > best_score:
            best_score = score
            best_s, best_v = s_size, v_size
            best_ty = "retrieval_head"
        v_size -= 1
    print(head, best_ty, best_v, best_s, best_score)
    return all_info

def search_pattern_v2(q, k, v, head):
    """Offline search that scores candidate sparse attention *kernels* against dense attention.

    Runs dense flash attention once as a reference, then every candidate
    kernel from three families ("stream_llm" via ``streaming_forward``,
    "vertical_and_slash", "retrieval_head"). Each candidate is scored by the
    number of output elements that differ from the reference by more than
    5e-3; lower is better. The best candidate is printed.

    NOTE(review): ``dense`` hard-codes a second dimension of 1 and a head_dim
    of 128 in its ``.view``, so q/k/v presumably hold a single head with
    head_dim 128, laid out (batch, 1, seq, head_dim). TODO confirm.

    Returns:
        list of [pattern_type, vertical_size, slash_size, mismatch_count] for
        every candidate tried.
    """
    q_len = q.shape[2]
    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        # Clamp the budgets to [30, q_len] columns and [50, q_len] diagonals.
        vertical_size, slash_size  = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = 64
        # Raw (unscaled) scores of the last 64 queries against all keys.
        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
        # Causal mask within the trailing 64x64 query/key block.
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK, qk[:, :, :, -last_q:], -torch.inf)    
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Always keep the first 30 key columns.
        vertical[...,:30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Always keep the 30 diagonals nearest the main diagonal.
        slash[...,-30:] = torch.inf
        slash_topk = slash  # unused
        # Convert diagonal indices to distances below the main diagonal.
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return pit_sparse_flash_attention_forward(q, k, v, vertical_topk, slash)
    def dense(q, k, v, vertical_size=None, slash_size=None):
        # Reference: flash attention takes (batch, seq, heads, dim), hence the
        # transposes; causal unless decoding a single token. `bsz` is
        # resolved from the enclosing scope at call time.
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, 128)
    def retrieval_head_kernel(q, k, v, vertical_size=None, slash_size=None):
        # Block-sparse kernel with a fixed topk of 100; the size args are ignored.
        topk = 100
        return block_sparse_flash_attention_forward(q, k, v, topk)

    best_s, best_v, best_score, best_ty = 0, 0, float("inf"), ""
    bsz = q.shape[0]
    all_info = []
    ref = dense(q, k, v)
    for ty, fc in [("stream_llm", streaming_forward), ("vertical_and_slash", vertical_and_slash_kernel), ("retrieval_head", retrieval_head_kernel)]:
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 800), (100, 750), (500, 700), (3500, 100), (1000, 4096)]
        else:
            vs_list = [(10, 1)]
        for v_size, s_size in vs_list:
            score = fc(q, k, v, v_size, s_size)
            # import ipdb;ipdb.set_trace()
            # delta = (ref - score).abs().sum()
            # Score = number of output elements off by more than 5e-3.
            delta = ((ref - score).abs() > 5e-3).sum()
            score = delta.item()
            all_info.append([ty, v_size, s_size, score])
            if score < best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    print(head, best_ty, best_v, best_s, best_score)
    return all_info
    # return [best_ty, best_v, best_s, best_score]

def shift_matrix(mat):
    """Keep only the trailing ``i + 1`` entries of row ``i``, moved to the front.

    For ``mat`` of shape (b, h, n, n) the result ``out`` has the same shape,
    with ``out[..., i, :i + 1] = mat[..., i, n - 1 - i:]`` and zeros
    elsewhere. It is built from a strided view over a zero-padded copy.

    Args:
        mat: tensor of shape (b, h, n, n). The row count must equal the last
            dimension, because the zero padding is (n, n). Any b and h are
            supported.

    Returns:
        Tensor of shape (b, h, n, n).
    """
    b, h, _, n = mat.shape
    # Zero padding (default dtype, as before) on both sides of the last dim.
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # contiguous (b, h, n, 3n)
    width = 3 * n
    # Row stride width - 1 shifts each successive row one column left. The
    # batch/head strides follow the contiguous layout; the original
    # hard-coded (1, 1) here and dropped all but the first matrix when b, h > 1.
    mat_strided = mat_padded.as_strided(
        (b, h, n, width), (h * n * width, n * width, width - 1, 1)
    )
    return mat_strided[..., 2 * n - 1:-1]

def repeat(self, q, k, v, attention_mask):
    """Approximate attention by reusing the last query's scores for every row.

    Only the final query row is scored against the keys. That row is tiled to
    all query rows, and ``shift_matrix`` keeps, for row ``i``, the last
    ``i + 1`` scores moved to the front. The additive ``attention_mask`` is
    applied and a float32 softmax weights ``v``. A single query falls back to
    exact attention via ``gather_qkv``.

    NOTE(review): ``shift_matrix`` needs a square score matrix, so this
    presumably expects q_len == kv_len (no KV cache). TODO confirm.
    """
    num_queries = q.shape[2]
    if num_queries == 1:
        return gather_qkv(q, k, v, attention_mask)
    last_row = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / math.sqrt(self.head_dim)
    tiled = last_row.repeat(1, 1, num_queries, 1)
    scores = shift_matrix(tiled) + attention_mask
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    return torch.matmul(probs.to(q.dtype), v)

def gather_last_q_vertical_slash_topk_v4(self, q, k, v, head_id):
    """Compute attention for a single head using its configured sparse pattern.

    The forward functions in this file call it once per head. In
    ``minference_forward`` q/k/v are one-head slices taken with
    ``.unsqueeze(1)``, so they are presumably (batch, 1, seq, head_dim).
    TODO confirm for the other callers.

    Dispatch order (first match wins):
      1. config ``dilated1`` / ``dilated2``: fixed dilated pattern (``dialted``).
      2. Otherwise look up ``self.best_pattern[head_id]`` as
         (type, vertical_size, slash_size, score); a slash_size of 4096 is
         remapped to 6096.
      3. config ``static_pattern``: ``vertical_and_slash_kernel_static``.
      4. config ``vs_only``: ``vertical_and_slash_kernel`` regardless of type.
      5. q_len == 1: dense flash attention.
      6. Otherwise dispatch on the type to ``streaming_forward``,
         ``vertical_and_slash_kernel`` or ``retrieval_head_kernel``.

    NOTE(review): the mask-based helpers ``vertical_and_slash``,
    ``stream_llm`` and ``retrieval_head`` defined below are never called in
    this function. ``vertical_and_slash`` and ``retrieval_head`` also read an
    ``attention_mask`` name that is not defined in this function's scope.
    ``kv_seq_len`` is unused.
    """
    kv_seq_len = k.size(2)

    def vertical_and_slash(attn_weights, vertical_size, slash_size):
        # (Unused here.) Additive-mask version of the vertical+slash pattern:
        # positions outside the pattern get self.ne_inf added.
        last_q = 64
        q_len = q.shape[2]
        # vertical_size, slash_size = int(vertical_size / 30000 * (q_len)), int(slash_size / 30000 * (q_len))
        vertical_size, slash_size  = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        qk_idxs = [ii + q_len for ii in list(range(-last_q, 0, 1))]
        qk = torch.matmul(q[:,:,qk_idxs,:], k.transpose(2, 3))/ math.sqrt(self.head_dim) + attention_mask[:,:,qk_idxs]
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        vertical[...,:30] = -self.ne_inf
        vertical_topk = torch.topk(-vertical, q_len - vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        slash[...,-30:] = -self.ne_inf
        slash_topk = slash
        slash = torch.topk(slash, slash_size, -1).indices - (q_len - 1)
        slash = torch.stack([torch.sparse.spdiags(torch.ones(slash_size, q_len), slash.cpu()[0][_], (q_len, q_len)).to_dense() for _ in range(1)]).to(q.device)
        
        est_attn = torch.ones_like(attn_weights)
        dim = 3
        est_attn = est_attn.scatter(3, vertical_topk.expand(*est_attn.shape[:dim], vertical_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        est_attn = est_attn + slash
        
        est_attn = (est_attn > 0).float()
        est_attn = torch.tril(est_attn)
        # print("vertical_and_slash", est_attn.sum() / (0.5 * q_len ** 2))
        est_attn = (est_attn == 0).int() * self.ne_inf
        attn_weights = attn_weights + est_attn
        if self.kv_cache_compressed_v4:
            # Record 4x-budget column/diagonal indices on the module.
            self.vertical = torch.topk(vertical, vertical_size * 4, -1).indices
            self.slash = (torch.topk(slash_topk, slash_size * 4, -1).indices - (q_len - 1)).unsqueeze(2)
        return attn_weights

    def stream_llm(attn_weights, vertical_size, slash_size):
        # (Unused here.) Additive-mask StreamingLLM pattern: local band plus
        # the first vertical_size columns.
        q_len = q.shape[2]
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        mask = torch.triu(torch.tril(torch.ones(q_len, q_len), 0), -slash_size).to(q)
        mask[:,:vertical_size] = 1
        mask = mask.unsqueeze(0).unsqueeze(1)
        
        est_attn = torch.tril(mask)
        # print("stream_llm", est_attn.sum() / (0.5 * q_len ** 2))
        est_attn = (est_attn == 0).int() * self.ne_inf
        attn_weights = attn_weights + est_attn
        if self.kv_cache_compressed_v4:
            self.vertical = torch.Tensor(list(range(vertical_size * 4))).long().to(q.device).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            self.slash = torch.Tensor(list(range(-slash_size * 4, 1))).long().to(q.device).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        return attn_weights

    def retrieval_head(attn_weights, topk_ratio, slash_size=None, block_size=8):
        # (Unused here.) Block-mask pattern with mean-pooled q/k blocks.
        # NOTE(review): unlike the helpers above, this adds the 0/1 mask itself
        # rather than multiplying it by self.ne_inf. It also hard-codes
        # head_dim 128 and batch 1.
        block_num = (q_len -1) // block_size + 1
        block_q = torch.zeros(1,1,block_num * block_size,128).to(q)
        block_q[:,:,:q_len] = q
        block_q = block_q.reshape(1,1,block_num,block_size,-1).mean(-2)
        block_k = torch.zeros(1,1,block_num * block_size,128).to(k)
        block_k[:,:,:q_len] = k
        block_k = block_k.reshape(1,1,block_num,block_size,-1).mean(-2)

        qk = torch.matmul(block_q, block_k.transpose(2, 3)) + attention_mask[:,:,:block_num,:block_num]
        est_attn = torch.ones_like(qk)
        block_topk = torch.topk(-qk, block_num - block_num//topk_ratio, -1).indices
        
        dim = 3
        est_attn = est_attn.scatter(3, block_topk.expand(*est_attn.shape[:dim], block_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        est_attn = est_attn.unsqueeze(3).unsqueeze(-1).repeat(1,1,1,block_size,1,block_size).reshape(1,1,block_num * block_size, block_num * block_size)[...,:q_len,:q_len]
        est_attn = torch.tril(est_attn)
        # print("retrieval_head", est_attn.sum() / (0.5 * q_len ** 2))
        est_attn = (est_attn == 0).int()
        attn_weights = attn_weights + est_attn
        return attn_weights
    
    def dialted(q,k,v, type):
        # Fixed "dilated" pattern: the first min(1024, q_len) key columns plus
        # a hand-picked set of diagonals (converted to distances below the
        # main diagonal):
        #   dilated1: every other position among the last 8192;
        #   dilated2: the last 2048 densely plus every other position in
        #             [-6144, -2048).
        q_len = q.shape[2]        
        n_init = min(1024, q_len)
        vertical_topk = torch.arange(0, n_init, device=q.device)[None, None, None, :]

        slash = torch.arange(0, q_len, device=q.device)
        if type == 'dilated1':
            # 8k local with 1 interval
            slash = slash[-8192::2][None, None, :]
        elif type == 'dilated2':
            # 2k dense local + 4k local with 1 interval
            slash = torch.cat([slash[-2048:], slash[-6144:-2048:2]], 0)[None, None, :]

        slash = (q_len - 1) - slash
        return pit_sparse_flash_attention_forward(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        # Dynamic vertical+slash sparse kernel. Budgets are clamped to
        # [30, q_len] columns and [50, q_len] diagonals and estimated from the
        # (unscaled) scores of the last 64 queries.
        vertical_size, slash_size  = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = 64
        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK.to(q.device), qk[:, :, :, -last_q:], -torch.inf)    
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Always keep the first 30 key columns.
        vertical[...,:30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Always keep the 100 diagonals nearest the main diagonal (30 in the
        # other variants in this file).
        slash[...,-100:] = torch.inf
        slash_topk = slash
        # Diagonal indices -> distances below the main diagonal.
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return pit_sparse_flash_attention_forward(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size):
        # Like vertical_and_slash_kernel, but the indices computed on the first
        # call are cached in self.vs and reused on every later call on this
        # module, with no per-head key.
        # NOTE(review): this branch uses LAST_Q_MASK without .to(q.device).
        if "vs" in self.__dict__:
            vertical_topk, slash = self.vs
        else:
            vertical_size, slash_size  = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
            last_q = 64
            qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
            qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK, qk[:, :, :, -last_q:], -torch.inf)    
            qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
            vertical = qk.sum(-2, keepdim=True)
            vertical[...,:30] = torch.inf
            vertical_topk = torch.topk(vertical, vertical_size, -1).indices

            slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
            slash[...,-30:] = torch.inf
            slash_topk = slash
            slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
            self.vs = vertical_topk, slash

        return pit_sparse_flash_attention_forward(q, k, v, vertical_topk, slash)
    def dense(q, k, v, vertical_size=None, slash_size=None):
        # Dense flash attention; inputs transposed to (batch, seq, heads, dim).
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
    def retrieval_head_kernel(q, k, v, vertical_size=None, slash_size=None):
        # Block-sparse kernel with a fixed topk of 100; the size args are ignored.
        topk = 100
        return block_sparse_flash_attention_forward(q, k, v, topk)

    q_len = q.shape[2]
    bsz = q.shape[0]

    if self.config.to_dict().get("dilated1", False):
        return dialted(q, k, v, 'dilated1')
    if self.config.to_dict().get("dilated2", False):
        return dialted(q, k, v, 'dilated2')

    # Per-head pattern chosen offline: (type, vertical_size, slash_size, score).
    ty, vertical_size, slash_size, _ = self.best_pattern[head_id]
    if slash_size == 4096:
        slash_size = 6096

    if self.config.to_dict().get("static_pattern", False):
        return vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size)
    if self.config.to_dict().get("vs_only", False):
        # return vertical_and_slash_kernel(q, k, v, 1024, 6096)
        return vertical_and_slash_kernel(q, k, v, vertical_size, slash_size)
    
    # Single-token decoding step: dense attention.
    if q_len == 1:
        return dense(q, k, v)
    
    # if ty == "retrieval_head":
    #     vertical_size, slash_size = 1000, 6096

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "retrieval_head": retrieval_head_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)

def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to one tensor (queries or keys).

    ``cos``/``sin`` are indexed by ``position_ids`` and get a singleton
    dimension inserted at ``unsqueeze_dim`` for broadcasting. The result is
    ``q * cos + rotate_half(q) * sin``.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)
    rotated = rotate_half(q)
    return q * cos_at_pos + rotated * sin_at_pos

def minference_forward():
    """Build a replacement ``forward`` for a Llama attention module that applies MInference.

    The returned function takes ``self`` explicitly, so the caller presumably
    binds it onto a ``LlamaAttention``-like module. TODO confirm.

    What it does:
    - projects Q/K/V, applies RoPE and updates the KV cache;
    - expands K/V to all query heads with ``repeat_kv``;
    - computes attention head by head. Layers ``>= self.topk_from_layer`` use
      ``gather_last_q_vertical_slash_topk_v4`` when ``self.topk != -1``;
      otherwise it uses dense flash attention, or ``gather_qkv`` when flash
      attention 2 is unavailable.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        # Refresh MInference settings from the config on every call.
        self.init_minference_parameters()
        # Most negative finite value of the activation dtype, used as a mask fill.
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Grouped-query attention: expand K/V to one copy per query head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if self.topk != -1 and self.layer_idx >= self.topk_from_layer:
            # Sparse path: each head uses its own pattern.
            assert query_states.size(1) == key_states.size(1) == value_states.size(1)
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                output[:, head:head + 1] = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
            
            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)
            # Attention weights are never returned (output_attentions is ignored).
            return attn_output, None, past_key_value
        
        else:
            # Dense path, still computed one head at a time.
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q.shape[2], self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)

            return attn_output, None, past_key_value

    return forward

def minference_wo_cache_forward():
    """Build a memory-lean ``forward`` that projects and attends one head at a time.

    Instead of full Q/K/V projections, the returned function slices the
    projection weights per head. K/V for each key/value head are computed once
    per group and copied into CPU buffers, which are handed to
    ``past_key_value.update`` at the end. The output projection is written
    into ``hidden_states`` in place.

    NOTE(review): ``past_key_value.get(...)`` and ``.update(...)`` are called
    unconditionally, so a non-None cache object exposing ``get`` is required.
    Its semantics are not visible here. TODO confirm.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, hidden_dim = hidden_states.size()
        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    
        cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
        cache_kwargs = {"sin": sin, "cos": cos}
    
        attn_out = torch.empty_like(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        # Number of key/value heads.
        act_num_heads = self.num_heads // self.num_key_value_groups
        # CPU buffers collecting the new K/V of this step for the cache update.
        k = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()
        v = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()
        part_k, part_v = None, None
        for head in range(self.num_heads):
            # Project only this head's query rows of q_proj.weight (bias not applied).
            part_q = F.linear(hidden_states, self.q_proj.weight.view(self.num_heads, self.head_dim, hidden_dim)[head]).unsqueeze(2)
            part_q = apply_rotary_pos_emb_single(part_q.transpose(1, 2), cos, sin, position_ids)

            # First query head of each group: compute K/V for the shared KV head.
            if head % self.num_key_value_groups == 0:
                part_k = F.linear(hidden_states, self.k_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)[head // self.num_key_value_groups]).unsqueeze(2)
                part_v = F.linear(hidden_states, self.v_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)[head // self.num_key_value_groups]).unsqueeze(2).transpose(1, 2)
                part_k = apply_rotary_pos_emb_single(part_k.transpose(1, 2), cos, sin, position_ids)
                k[:,head // self.num_key_value_groups] = part_k.cpu()
                v[:,head // self.num_key_value_groups] = part_v.cpu()
                # Presumably returns past + current K/V for this head. TODO confirm.
                part_k, part_v = past_key_value.get(part_k, part_v, self.layer_idx, head // self.num_key_value_groups, cache_kwargs)

            if self.topk != -1 and self.layer_idx >= self.topk_from_layer:
                part_o = self.gather_last_q_vertical_slash_topk_v4(part_q, part_k, part_v, head)
            else:
                part_o = flash_attn_func(part_q, part_k, part_v.transpose(1, 2), 0.0, softmax_scale=None, causal=True).view(bsz, part_q.shape[1], self.head_dim)
            attn_out[:, :, head, :] = part_o
        
        past_key_value.update(k, v, self.layer_idx, cache_kwargs)
        # Output projection (weight only) written into hidden_states in place;
        # the caller's input tensor is overwritten.
        torch.matmul(attn_out.view(bsz, q_len, hidden_dim), self.o_proj.weight.T, out=hidden_states)
        torch.cuda.empty_cache()
        return (hidden_states, None, past_key_value)

    return forward

def minference_with_snapkv_forward():
    """Build a ``forward`` combining MInference attention with SnapKV cache compression.

    Same attention computation as ``minference_forward``. The differences:
    - ``init_snapkv(self)`` is called;
    - K/V are expanded with ``repeat_kv`` *before* the cache update, so the
      cache receives the per-query-head copies;
    - on the prefill step (key length equals ``kv_seq_len``), the K/V handed
      to the cache are compressed via ``self.kv_cluster.update_kv``, while this
      step's attention still uses the full, uncompressed K/V;
    - on later steps ``self.kv_seq_len`` is advanced by ``q_len`` and the
      cache is updated normally.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        init_snapkv(self)

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            
            # The compressed cache is shorter than the true sequence, so the
            # uncompressed length tracked in self.kv_seq_len is used for RoPE
            # when available.
            if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
                if self.kv_seq_len != 0:
                    kv_seq_len += self.kv_seq_len
                else:
                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
                # Prefill: store compressed K/V but keep the full K/V for this step.
                self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
                key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            else:
                # Decode: append to the cache and attend over everything cached.
                self.kv_seq_len += q_len
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.topk != -1 and self.layer_idx >= self.topk_from_layer:
            assert query_states.size(1) == key_states.size(1) == value_states.size(1)
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                output[:, head:head + 1] = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
            
            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
        
        else:
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q.shape[2], self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            attn_output = output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)

            return attn_output, None, past_key_value

    return forward

def compare_list(pattern1, pattern2, layer=10):
    """Report heads whose chosen sparse pattern differs between two pattern sets.

    Each pattern set is a list indexed by layer. Each element maps a head id
    string to ``[type, vertical_size, slash_size, score]``. The trailing score
    is ignored in the comparison.

    The original implementation ignored all three arguments: it always loaded
    two hard-coded files and compared layer 1. This version honors them.

    Args:
        pattern1, pattern2: an already-loaded pattern list, or a path to a JSON
            file containing one. ``None`` falls back to the previously
            hard-coded file for that slot.
        layer: index of the layer to compare.

    Returns:
        list of ``(head, entry1, entry2)`` tuples for the differing heads, in
        ascending head order. Each difference is also printed.
    """
    default_paths = (
        "Llama_3_8B_Instruct_262k_kv_out_v32_fit_101_best_pattern.json",
        "Llama_3_8B_Instruct_262k_kv_out_v32_fit_4_best_pattern.json",
    )

    def _load(pattern, default_path):
        # Accept a loaded list, a path, or None (use the legacy default file).
        if pattern is None:
            pattern = default_path
        if isinstance(pattern, (str, os.PathLike)):
            with open(pattern, encoding="utf-8") as fh:
                return json.load(fh)
        return pattern

    pattern1 = _load(pattern1, default_paths[0])
    pattern2 = _load(pattern2, default_paths[1])

    diffs = []
    for key in sorted(pattern1[layer], key=int):
        entry1, entry2 = pattern1[layer][key], pattern2[layer][key]
        # Drop the trailing search score; compare only the pattern itself.
        if entry1[:-1] != entry2[:-1]:
            print(int(key), entry1, entry2)
            diffs.append((int(key), entry1, entry2))
    return diffs