Skip to content

Commit 4a3b583

Browse files
Successful model launch
1 parent 01261f7 commit 4a3b583

File tree

7 files changed

+667
-46
lines changed

7 files changed

+667
-46
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
# @package _global_

# Experiment: CLM Transformer with Alternating Global/Local Sliding Window Attention
#
# This experiment adds alternating global and local attention to the CLM transformer.
# Pattern: Global → Local (32) → Global → Local (32) → ...
# Starting with global attention in layer 0, then alternating with local sliding window.
#
# To execute this experiment run:
# python glm_experiments/train.py experiment=clm_transformer_base_sliding_window

defaults:
  - override /data: gpn_animal_promoter
  - override /model: clm_transformer_base
  - override /trainer: gpn_animal_promoter

logger:
  wandb:
    name: experiment-clm-transformer-base-sliding-window
    tags: ["experiment", "clm", "transformer", "base", "sliding-window"]

data:
  _target_: glm_experiments.data.lm_datamodule.CLMDataModule
  per_device_batch_size: 256

model:
  net:
    encoder:
      # Add sliding window attention with alternating global/local pattern
      sliding_window:
        _target_: glm_experiments.models.utils.attention_patterns.alternating_global_local
        n_layers: ${..n_layers} # Reference n_layers from encoder (12)
        window_size: 32 # Local attention window size
        start_with_global: true # First layer is global

scheduler:
  _target_: transformers.get_cosine_with_min_lr_schedule_with_warmup
  _partial_: true
  num_warmup_steps: 2000
  num_training_steps: ${trainer.max_steps}
  min_lr_rate: 0.1 # Decay to 10% of max lr

trainer:
  max_steps: 20000
  log_every_n_steps: 1000
  val_check_interval: 1000

configs/experiment/clm_transformer_small.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Short training run with small Transformer encoder for quick testing
44

55
defaults:
6-
- override /data: plants
76
- override /model: clm_transformer_small
87

98
logger:
@@ -24,6 +23,13 @@ model:
2423
d_model: 32
2524
encoder:
2625
n_layers: 2
26+
num_heads: 2
27+
sliding_window:
28+
_target_: glm_experiments.models.utils.attention_patterns.alternating_global_local
29+
n_layers: ${..n_layers}
30+
window_size: 32
31+
start_with_global: true
32+
2733
scheduler:
2834
_target_: transformers.get_cosine_schedule_with_warmup
2935
_partial_: true

glm_experiments/models/components/attention.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ def scaled_dot_product_attention(
129129
# FlexAttention path: use sliding window
130130
batch_size, num_heads, seq_len, head_dim = query.shape
131131

132+
# FlexAttention requires all tensors to have the same dtype.
133+
# Unlike F.scaled_dot_product_attention, flex_attention is not in the autocast
134+
# whitelist, so we need to manually ensure dtype consistency.
135+
# We match to value.dtype since autocast has already converted it to the target precision.
136+
target_dtype = value.dtype
137+
if query.dtype != target_dtype:
138+
query = query.to(target_dtype)
139+
if key.dtype != target_dtype:
140+
key = key.to(target_dtype)
141+
132142
# Determine mask type based on is_causal flag
133143
mask_type = "causal_sliding_window" if is_causal else "sliding_window"
134144

glm_experiments/models/components/transformer.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from jaxtyping import Float, Int
2323
from torch import Tensor
2424

25+
from glm_experiments.models.components.attention import scaled_dot_product_attention
26+
2527

2628
class Linear(nn.Module):
2729
def __init__(self, d_in: int, d_out: int):
@@ -124,7 +126,7 @@ def forward(self, x):
124126

125127

126128
class MultiHeadSelfAttention(nn.Module):
127-
"""Multi-Head Self-Attention with configurable causal masking.
129+
"""Multi-Head Self-Attention with configurable causal masking and sliding window.
128130
129131
This function implements section 3.2.2 of the Transformer paper. In particular,
130132
given an input tensor of shape `(batch_size, sequence_length, d_model)`, we project
@@ -141,6 +143,9 @@ class MultiHeadSelfAttention(nn.Module):
141143
The RoPE module to use.
142144
is_causal: bool
143145
Whether to use causal masking (default: False for bidirectional attention).
146+
sliding_window: int | None
147+
Window size for sliding window attention. If None, uses standard attention
148+
(default: None).
144149
145150
Returns:
146151
Tensor of shape `(batch_size, sequence_length, d_model)`.
@@ -152,12 +157,14 @@ def __init__(
152157
num_heads: int,
153158
positional_encoder: RotaryEmbedding,
154159
is_causal: bool = False,
160+
sliding_window: int | None = None,
155161
):
156162
super().__init__()
157163
assert d_model % num_heads == 0
158164
self.d_model = d_model
159165
self.num_heads = num_heads
160166
self.is_causal = is_causal
167+
self.sliding_window = sliding_window
161168

162169
self.d_k = d_model // num_heads
163170
self.d_v = self.d_k
@@ -207,8 +214,13 @@ def forward(
207214
K = self.positional_encoder(K, token_positions)
208215

209216
# Shape: (..., num_heads, sequence_length, d_k)
210-
attn_output = F.scaled_dot_product_attention(
211-
query=Q, key=K, value=V, is_causal=self.is_causal, enable_gqa=False
217+
attn_output = scaled_dot_product_attention(
218+
query=Q,
219+
key=K,
220+
value=V,
221+
is_causal=self.is_causal,
222+
sliding_window=self.sliding_window,
223+
enable_gqa=False,
212224
)
213225

214226
# Concatenate the attention output from all heads.
@@ -240,6 +252,8 @@ class TransformerBlock(nn.Module):
240252
The RoPE module to use.
241253
is_causal: bool
242254
Whether to use causal masking (default: False).
255+
sliding_window: int | None
256+
Window size for sliding window attention (default: None).
243257
244258
Returns:
245259
FloatTensor of shape `(batch_size, sequence_length, d_model)`.
@@ -252,13 +266,15 @@ def __init__(
252266
d_ff: int,
253267
positional_encoder: RotaryEmbedding,
254268
is_causal: bool = False,
269+
sliding_window: int | None = None,
255270
):
256271
super().__init__()
257272
self.attn = MultiHeadSelfAttention(
258273
d_model=d_model,
259274
num_heads=num_heads,
260275
positional_encoder=positional_encoder,
261276
is_causal=is_causal,
277+
sliding_window=sliding_window,
262278
)
263279
self.ffn = SwiGLU(d_model=d_model, d_ff=d_ff)
264280
self.ln1 = nn.RMSNorm(d_model)
@@ -307,6 +323,12 @@ class Transformer(nn.Module):
307323
RoPE frequency base (default: 10000.0).
308324
is_causal: bool
309325
Enable causal masking (default: False for MLM).
326+
sliding_window: list[int | None] | None
327+
Per-layer window sizes for sliding window attention. Can be:
328+
- None: No sliding window (standard attention for all layers)
329+
- List of length n_layers: Specific window size per layer (None = standard attention)
330+
Example: [None, 256, 256, 128] for 4 layers
331+
(default: None).
310332
context_length: int
311333
Maximum sequence length for RoPE cache (default: 512).
312334
"""
@@ -319,6 +341,7 @@ def __init__(
319341
d_ff: int | None = None,
320342
rope_theta: float = 10000.0,
321343
is_causal: bool = False,
344+
sliding_window: list[int | None] | None = None,
322345
context_length: int = 512,
323346
):
324347
super().__init__()
@@ -327,6 +350,18 @@ def __init__(
327350
self.num_heads = num_heads
328351
self.is_causal = is_causal
329352

353+
# Process sliding_window parameter
354+
if sliding_window is None:
355+
# No sliding window for any layer
356+
self.sliding_window = [None] * n_layers
357+
else:
358+
# Validate list length
359+
if len(sliding_window) != n_layers:
360+
raise ValueError(
361+
f"sliding_window list must have length {n_layers}, got {len(sliding_window)}"
362+
)
363+
self.sliding_window = sliding_window
364+
330365
# Auto-compute d_ff using CS336 formula: floor(d_model * 8/3 / 64) * 64
331366
if d_ff is None:
332367
d_ff = int(hidden_size * 8 / 3 / 64) * 64
@@ -339,7 +374,7 @@ def __init__(
339374
context_length=context_length, dim=d_head, theta=rope_theta
340375
)
341376

342-
# Stack of transformer blocks
377+
# Stack of transformer blocks with per-layer sliding windows
343378
self.layers = nn.ModuleList(
344379
[
345380
TransformerBlock(
@@ -348,8 +383,9 @@ def __init__(
348383
d_ff=d_ff,
349384
positional_encoder=self.positional_encoder,
350385
is_causal=is_causal,
386+
sliding_window=self.sliding_window[i],
351387
)
352-
for _ in range(n_layers)
388+
for i in range(n_layers)
353389
]
354390
)
355391

0 commit comments

Comments
 (0)