2 changes: 2 additions & 0 deletions configuration_bibo.py
@@ -104,6 +104,8 @@ def __init__(
self.kernel_size = kernel_size
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.conv_router = conv_router
self.use_ssmax = use_ssmax
if mlp_only_layers is None:
Contributor Author (review comment):
The above two lines have to be added to make sure that the file compiles

self.mlp_only_layers = [0, num_hidden_layers - 1]
else:
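For reference, a minimal sketch of how the two new flags might be passed when constructing the config. The class name BiboConfig and the flag semantics in the comments are assumptions; only the conv_router and use_ssmax attribute names are confirmed by this diff.

# Hypothetical usage; only the argument names are taken from the diff above.
config = BiboConfig(
    conv_router=True,   # assumed: enable the convolution-based router
    use_ssmax=True,     # assumed: enable scalable softmax (SSMax) in attention
)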
157 changes: 144 additions & 13 deletions modeling_bibo.py
@@ -16,7 +16,7 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from einops import rearrange
from einops import rearrange, einsum

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
@@ -701,9 +701,146 @@ def forward(
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value

def eager_sliding_window_attention(self):
pass

def eager_sliding_window_attention(
self,
query_states: torch.Tensor, # (b, h, q_len, d)
key_states: torch.Tensor, # (b, h, kv_len, d)
value_states: torch.Tensor, # (b, h, kv_len, d)
attention_mask: Optional[torch.Tensor], # (b, 1, q_len, kv_len) additive
window_size: int,
stride: Optional[int] = None
) -> torch.Tensor:
"""
Efficient sliding window attention using tensor unfolding and einops,
matching the logic of the provided loop-based implementation.

Args:
query_states: Query tensor [batch_size, num_heads, q_len, head_dim]
key_states: Key tensor [batch_size, num_heads, kv_len, head_dim]
value_states: Value tensor [batch_size, num_heads, kv_len, head_dim]
attention_mask: Optional additive mask tensor (-inf for masked)
Shape [batch_size, 1, q_len, kv_len].
window_size: Size of the attention window.
stride: Currently unused; windows always advance by one key position.

Returns:
Attention output tensor [batch_size, num_heads, q_len, head_dim]
"""
batch_size, num_heads, q_len, head_dim = query_states.shape
kv_len = key_states.shape[-2]
# Ensure kv_len matches query_states if expected by causal window logic
# assert q_len == kv_len, "This specific unfold logic assumes q_len == kv_len for simplicity"

# --- 1. Pad Key and Value tensors (Left Padding) ---
# Pad by window_size - 1 on the left of the sequence dimension (dim 2)
kv_padding_size = max(0, window_size - 1)
# Pad format: (pad_left, pad_right) for last dim, then second-to-last, etc.
# We pad dim 2 (sequence length): (pad_seq_left, pad_seq_right)
kv_padding = (0, 0, kv_padding_size, 0) # (pad_dim3_l, pad_d3_r, pad_dim2_l, pad_d2_r)

padded_key_states = F.pad(key_states, kv_padding)
padded_value_states = F.pad(value_states, kv_padding)
# Padded shape: [b, h, kv_len + window_size - 1, d]

# --- 2. Unfold Padded Key/Value tensors ---
# Create sliding windows of size `window_size` along dim 2 with step 1
unfolded_key = padded_key_states.unfold(dimension=2, size=window_size, step=1)
unfolded_value = padded_value_states.unfold(dimension=2, size=window_size, step=1)
# Shape after unfold: [b, h, num_windows, d, w]
# num_windows = (kv_len + window_size - 1) - window_size + 1 = kv_len
# If q_len == kv_len, then num_windows == q_len
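# Illustrative example (assumed numbers): with kv_len = 5 and window_size = 3, the left padding
# adds 2 positions, the padded length is 7, and unfold produces 7 - 3 + 1 = 5 windows of size 3,
# i.e. shape [b, h, 5, d, 3] -- one window ending at each key position.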

# Handle potential mismatch if q_len != kv_len (unlikely for standard SWA)
num_windows = unfolded_key.shape[2]
if num_windows != q_len:
print(f"Warning: q_len ({q_len}) != kv_len ({num_windows}). Check logic if this is intended. Taking first {q_len} windows.")
unfolded_key = unfolded_key[:, :, :q_len, :, :]
unfolded_value = unfolded_value[:, :, :q_len, :, :]

# --- 3. Rearrange with einops ---
# Rearrange to [b, h, q_len, window_size, head_dim] for matmul convenience
unfolded_key = rearrange(unfolded_key, 'b h q d w -> b h q w d')
unfolded_value = rearrange(unfolded_value, 'b h q d w -> b h q w d')

# --- 4. Compute Attention Scores within Windows ---
# Scale query beforehand as in the loop version
# query_scaled = query_states * (self.head_dim ** -0.5) # Option 1: Scale Q
# attn_scores_windowed = einsum(query_scaled, unfolded_key, 'b h q d, b h q w d -> b h q w')

# Option 2: Scale scores after matmul (more common)
scale_factor = self.head_dim ** -0.5
attn_scores_windowed = einsum(query_states, unfolded_key, 'b h q d, b h q w d -> b h q w') * scale_factor
# Shape: [b, h, q, w]

# --- 5. Apply Masking ---
# a) Mask attention to padded key positions introduced by F.pad
# Calculate original key indices corresponding to each window position
relative_indices = torch.arange(window_size, device=query_states.device) # (w,)
query_indices = torch.arange(q_len, device=query_states.device).unsqueeze(1) # (q, 1)
# For query i, window position k, the key index in the padded tensor is i+k.
# The original key index is (i+k) - kv_padding_size
# More directly: original key index = query_idx - (window_size - 1) + window_relative_idx
original_key_indices = query_indices - kv_padding_size + relative_indices # Shape (q, w)
# Mask if the original key index is negative (came from padding)
padding_mask_window = (original_key_indices < 0) # Shape (q, w) boolean
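# Illustrative example (assumed window_size = 3): query i = 0 maps to original key indices [-2, -1, 0],
# so its first two window slots are masked as padding; query i = 2 maps to [0, 1, 2], so nothing is masked.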

# Apply this padding mask (additive -inf)
attn_scores_windowed = attn_scores_windowed.masked_fill(
padding_mask_window.unsqueeze(0).unsqueeze(0), # Expand to (1, 1, q, w)
float('-inf')
)

# b) Apply the external attention_mask if provided
# This requires selecting the mask values corresponding to the keys in each window.
# Replicating the loop version's logic by constructing the windowed mask.
if attention_mask is not None:
# Input mask shape: (b, 1, q, k), additive (-inf where masked)
# Output needed: (b, h, q, w) mask values corresponding to windowed keys
windowed_attention_mask = torch.zeros_like(attn_scores_windowed) # Start with 0 (no mask)

# Ensure mask has correct dimensions for slicing
mask_for_loop = attention_mask
if mask_for_loop.shape[2] != q_len or mask_for_loop.shape[3] != kv_len:
mask_for_loop = mask_for_loop[:,:,:q_len, :kv_len] # Adjust slice if needed

for i in range(q_len):
# Determine the slice of keys in the *original* sequence for query i's window
k_start = max(0, i - window_size + 1)
k_end = i + 1 # exclusive end index

# Extract the relevant slice from the original attention_mask
# Mask for query i, attending to keys k_start to k_end-1
# Shape: (b, 1, 1, actual_window_len)
mask_slice = mask_for_loop[:, :, i:i+1, k_start:k_end]

# Pad the mask slice on the left if the window was truncated at the start
actual_window_len = k_end - k_start
left_padding_needed = window_size - actual_window_len
# Pad format (left, right) for the last dimension. Pad with 0 for additive mask.
padded_mask_slice = F.pad(mask_slice, (left_padding_needed, 0), value=0.0)

# Assign to the correct position in the windowed mask tensor
# Squeeze the query dim (dim 2) from the slice before assigning
windowed_attention_mask[:, :, i, :] = padded_mask_slice.squeeze(2)

# Add the constructed windowed mask to the scores
attn_scores_windowed = attn_scores_windowed + windowed_attention_mask

# --- 6. Compute Attention Probabilities ---
# Softmax is applied over the window dimension (-1)
attn_weights = F.softmax(attn_scores_windowed, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Shape: [b, h, q, w]

# Apply dropout (as in the loop version)
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

# --- 7. Compute Output using einsum ---
# Weighted sum of the unfolded values based on the attention weights
# weights[b,h,q,w] * values[b,h,q,w,d] -> output[b,h,q,d]
attn_output = einsum(attn_weights, unfolded_value, 'b h q w, b h q w d -> b h q d')
# Final shape: [batch_size, num_heads, q_len, head_dim]

return attn_output
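
A minimal shape-check sketch for the windowed path above (assumptions: torch.nn.functional is available as F in the module, the attention class is called BiboAttention here purely for illustration, and the SimpleNamespace stand-in only provides the attributes the method reads):

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the attention module attributes used by the method
# (head_dim for scaling, attention_dropout/training for dropout).
dummy_self = SimpleNamespace(head_dim=8, attention_dropout=0.0, training=False)

b, h, q_len, d, w = 2, 4, 16, 8, 5
q = torch.randn(b, h, q_len, d)
k = torch.randn(b, h, q_len, d)
v = torch.randn(b, h, q_len, d)

# BiboAttention is an assumed class name; call the method unbound on the dummy object.
out = BiboAttention.eager_sliding_window_attention(dummy_self, q, k, v, attention_mask=None, window_size=w)
assert out.shape == (b, h, q_len, d)  # (batch, heads, q_len, head_dim)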


def eager_standard_attention(
@@ -729,6 +866,7 @@ def eager_standard_attention(


kv_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_ssmax:
log_n = torch.log(torch.clamp(torch.tensor(kv_len, device=query_states.device, dtype=self.ssmax_scale.dtype), min=2.0))
# min=2.0 since log(1) = 0 and negative for <1
@@ -738,11 +876,9 @@
# SSMax Ratio: exp(C * z_i) / exp(C * z_k) = exp(C * z_i - C * z_k) = exp(C * (z_i - z_k)) = (exp(z_i - z_k))^C
# C is the scaling factor, i.e. s * log(seq_len);
# in short: a learnable, sequence-length-adaptive temperature applied per head to control attention sharpness, preventing attention from fading in long contexts.
s_scaled = self.s.view(1, self.num_heads, 1, 1) * log_n
attn_weights = attn_weights * s_scaled

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -752,8 +888,3 @@
attn_output = torch.matmul(attn_weights, value_states)

return attn_output
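
A condensed sketch of the SSMax scaling applied above (assumptions: s is a learnable per-head scalar, the function and variable names are illustrative, and the clamp mirrors the min=2.0 guard in the code):

import torch

def ssmax_scale(attn_logits: torch.Tensor, s: torch.Tensor, kv_len: int) -> torch.Tensor:
    # attn_logits: (b, h, q, k) scaled dot-product logits
    # s: (h,) learnable per-head scale
    # Multiplying the logits by s * log(kv_len) sharpens the softmax as the
    # context grows, counteracting attention fading over long sequences.
    log_n = torch.log(torch.clamp(torch.tensor(float(kv_len)), min=2.0))
    return attn_logits * (s.view(1, -1, 1, 1) * log_n)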




