
def create_strict_diagonal_attention_mask(seq_length, window_size):
    """Return a boolean band mask restricting attention to a symmetric window.

    Entry ``[i, j]`` is True exactly when ``|i - j| <= window_size // 2``.
    Because of the floor division, an even ``window_size`` covers the same
    band as ``window_size + 1`` (e.g. 4 and 5 both allow offsets -2..2).

    Args:
        seq_length: number of positions; the mask is ``(seq_length, seq_length)``.
        window_size: non-negative window width.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_length, seq_length)``.

    Raises:
        ValueError: if ``window_size`` is negative.
    """
    if window_size < 0:
        # A negative width used to produce negative slice ends, which Python
        # interprets as counting from the end and silently set wrong entries.
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    half_width = window_size // 2
    positions = torch.arange(seq_length)
    # offsets[i, j] = j - i; the band is every pair within half_width of the diagonal.
    offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
    return offsets.abs() <= half_width

def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25, spatial_factor=0.0):
    """Build a 3-D Gaussian low-pass mask laid out like an ``fftshift``-ed spectrum.

    The last three entries of ``shape`` are treated as (T, H, W). Along each
    axis the coordinate is the signed distance from the zero-frequency bin
    (index ``n // 2`` after ``fftshift``), scaled by ``2 / n``. The mask value is
    ``exp(-0.5 * ((t / d_t)**2 + (h / d_s)**2 + (w / d_s)**2))``.

    Args:
        shape: target shape (at least 3 entries). Leading entries, if any, are
            filled by tiling the 3-D mask so the result has exactly ``shape``.
        d_s: cut-off for the two spatial axes (H, W).
        d_t: cut-off for the temporal axis (T).
        spatial_factor: unused; kept for call compatibility.

    Returns:
        A floating tensor of shape ``shape``. It is all zeros when ``d_s`` or
        ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    if d_s == 0 or d_t == 0:
        return torch.zeros(shape)

    def _centered_coords(n):
        # fftshift places DC at index n // 2, so measure distance from there.
        # For even n this equals the usual 2k/n - 1; for odd n it keeps DC at 0.
        return (torch.arange(n).float() - n // 2) * 2 / n

    grid_t, grid_h, grid_w = torch.meshgrid(
        _centered_coords(T), _centered_coords(H), _centered_coords(W), indexing='ij'
    )

    # Squared distance from DC, scaled per axis by the frequency cut-offs.
    d_square = (
        (grid_t * (1 / d_t)).pow(2)
        + (grid_h * (1 / d_s)).pow(2)
        + (grid_w * (1 / d_s)).pow(2)
    )
    mask = torch.exp(-0.5 * d_square)

    if len(shape) > 3:
        # Tile across every leading dim (repeat prepends the extra dims), so
        # 4-D, 5-D, ... shapes all get a mask of exactly `shape`.
        mask = mask.repeat(*shape[:-3], 1, 1, 1)

    return mask


def freq_mix_3d(global_feat, local_feat, LPF):
    """Combine the low band of ``global_feat`` with the high band of ``local_feat``.

    Both inputs are transformed with an n-D FFT over their last three axes and
    centred with ``fftshift``. The global spectrum is weighted by ``LPF`` and
    the local spectrum by ``1 - LPF``; the sum is un-shifted and inverted.

    Args:
        global_feat: tensor supplying the low-frequency content.
        local_feat: tensor supplying the high-frequency content.
        LPF: low-pass weights that must broadcast against the centred spectra.

    Returns:
        The real part of the inverse FFT of the mixed spectrum.
    """
    axes = (-3, -2, -1)

    def centered_spectrum(x):
        # FFT over the last three axes, with the zero-frequency bin moved to the centre.
        return fft.fftshift(fft.fftn(x, dim=axes), dim=axes)

    low_band = centered_spectrum(global_feat) * LPF
    high_band = centered_spectrum(local_feat) * (1 - LPF)
    mixed_spectrum = low_band + high_band

    restored = fft.ifftn(fft.ifftshift(mixed_spectrum, dim=axes), dim=axes)
    return restored.real

def blend(global_feat, local_feat, video_length, *, d_s=0.25, d_t=0.25):
    """Mix low frequencies of ``global_feat`` with high frequencies of ``local_feat``.

    A Gaussian low-pass filter over the last three axes is built with
    ``gaussian_low_pass_filter`` and applied by ``freq_mix_3d``. Both features
    are processed in float32 and the result is cast back to ``global_feat``'s
    dtype.

    Args:
        global_feat: tensor whose low-frequency content is kept.
        local_feat: tensor whose high-frequency content is kept; assumed to
            have the same shape as ``global_feat`` — TODO confirm at call sites.
        video_length: unused; kept so existing call sites keep working.
        d_s: spatial cut-off forwarded to the filter (previously fixed at 0.25).
        d_t: temporal cut-off forwarded to the filter (previously fixed at 0.25).

    Returns:
        The mixed tensor, in ``global_feat.dtype``.
    """
    out_dtype = global_feat.dtype
    # Build the filter only over the last three axes. The spectra it multiplies
    # broadcast it across any leading (batch/channel) dims. This avoids
    # materialising a full-size copy and works for any number of leading dims.
    low_pass = gaussian_low_pass_filter(
        global_feat.shape[-3:], d_s=d_s, d_t=d_t
    ).to(global_feat.device)
    mixed = freq_mix_3d(
        global_feat.to(dtype=torch.float32),
        local_feat.to(dtype=torch.float32),
        LPF=low_pass,
    )
    return mixed.to(out_dtype)