2222from jaxtyping import Float , Int
2323from torch import Tensor
2424
25+ from glm_experiments .models .components .attention import scaled_dot_product_attention
26+
2527
2628class Linear (nn .Module ):
2729 def __init__ (self , d_in : int , d_out : int ):
@@ -124,7 +126,7 @@ def forward(self, x):
124126
125127
126128class MultiHeadSelfAttention (nn .Module ):
127- """Multi-Head Self-Attention with configurable causal masking.
129+ """Multi-Head Self-Attention with configurable causal masking and sliding window .
128130
129131 This function implements section 3.2.2 of the Transformer paper. In particular,
130132 given an input tensor of shape `(batch_size, sequence_length, d_model)`, we project
@@ -141,6 +143,9 @@ class MultiHeadSelfAttention(nn.Module):
141143 The RoPE module to use.
142144 is_causal: bool
143145 Whether to use causal masking (default: False for bidirectional attention).
146+ sliding_window: int | None
147+ Window size for sliding window attention. If None, uses standard attention
148+ (default: None).
144149
145150 Returns:
146151 Tensor of shape `(batch_size, sequence_length, d_model)`.
@@ -152,12 +157,14 @@ def __init__(
152157 num_heads : int ,
153158 positional_encoder : RotaryEmbedding ,
154159 is_causal : bool = False ,
160+ sliding_window : int | None = None ,
155161 ):
156162 super ().__init__ ()
157163 assert d_model % num_heads == 0
158164 self .d_model = d_model
159165 self .num_heads = num_heads
160166 self .is_causal = is_causal
167+ self .sliding_window = sliding_window
161168
162169 self .d_k = d_model // num_heads
163170 self .d_v = self .d_k
@@ -207,8 +214,13 @@ def forward(
207214 K = self .positional_encoder (K , token_positions )
208215
209216 # Shape: (..., num_heads, sequence_length, d_k)
210- attn_output = F .scaled_dot_product_attention (
211- query = Q , key = K , value = V , is_causal = self .is_causal , enable_gqa = False
217+ attn_output = scaled_dot_product_attention (
218+ query = Q ,
219+ key = K ,
220+ value = V ,
221+ is_causal = self .is_causal ,
222+ sliding_window = self .sliding_window ,
223+ enable_gqa = False ,
212224 )
213225
214226 # Concatenate the attention output from all heads.
@@ -240,6 +252,8 @@ class TransformerBlock(nn.Module):
240252 The RoPE module to use.
241253 is_causal: bool
242254 Whether to use causal masking (default: False).
255+ sliding_window: int | None
256+ Window size for sliding window attention (default: None).
243257
244258 Returns:
245259 FloatTensor of shape `(batch_size, sequence_length, d_model)`.
@@ -252,13 +266,15 @@ def __init__(
252266 d_ff : int ,
253267 positional_encoder : RotaryEmbedding ,
254268 is_causal : bool = False ,
269+ sliding_window : int | None = None ,
255270 ):
256271 super ().__init__ ()
257272 self .attn = MultiHeadSelfAttention (
258273 d_model = d_model ,
259274 num_heads = num_heads ,
260275 positional_encoder = positional_encoder ,
261276 is_causal = is_causal ,
277+ sliding_window = sliding_window ,
262278 )
263279 self .ffn = SwiGLU (d_model = d_model , d_ff = d_ff )
264280 self .ln1 = nn .RMSNorm (d_model )
@@ -307,6 +323,12 @@ class Transformer(nn.Module):
307323 RoPE frequency base (default: 10000.0).
308324 is_causal: bool
309325 Enable causal masking (default: False for MLM).
326+ sliding_window: list[int | None] | None
327+ Per-layer window sizes for sliding window attention. Can be:
328+ - None: No sliding window (standard attention for all layers)
329+ - List of length n_layers: Specific window size per layer (None = standard attention)
330+ Example: [None, 256, 256, 128] for 4 layers
331+ (default: None).
310332 context_length: int
311333 Maximum sequence length for RoPE cache (default: 512).
312334 """
@@ -319,6 +341,7 @@ def __init__(
319341 d_ff : int | None = None ,
320342 rope_theta : float = 10000.0 ,
321343 is_causal : bool = False ,
344+ sliding_window : list [int | None ] | None = None ,
322345 context_length : int = 512 ,
323346 ):
324347 super ().__init__ ()
@@ -327,6 +350,18 @@ def __init__(
327350 self .num_heads = num_heads
328351 self .is_causal = is_causal
329352
353+ # Process sliding_window parameter
354+ if sliding_window is None :
355+ # No sliding window for any layer
356+ self .sliding_window = [None ] * n_layers
357+ else :
358+ # Validate list length
359+ if len (sliding_window ) != n_layers :
360+ raise ValueError (
361+ f"sliding_window list must have length { n_layers } , got { len (sliding_window )} "
362+ )
363+ self .sliding_window = sliding_window
364+
330365 # Auto-compute d_ff using CS336 formula: floor(d_model * 8/3 / 64) * 64
331366 if d_ff is None :
332367 d_ff = int (hidden_size * 8 / 3 / 64 ) * 64
@@ -339,7 +374,7 @@ def __init__(
339374 context_length = context_length , dim = d_head , theta = rope_theta
340375 )
341376
342- # Stack of transformer blocks
377+ # Stack of transformer blocks with per-layer sliding windows
343378 self .layers = nn .ModuleList (
344379 [
345380 TransformerBlock (
@@ -348,8 +383,9 @@ def __init__(
348383 d_ff = d_ff ,
349384 positional_encoder = self .positional_encoder ,
350385 is_causal = is_causal ,
386+ sliding_window = self .sliding_window [i ],
351387 )
352- for _ in range (n_layers )
388+ for i in range (n_layers )
353389 ]
354390 )
355391
0 commit comments