Commit beba296
fix(model): always build SWA mask to preserve causal masking at all seq lengths

The previous optimization skipped get_swa() when sq <= window_size[0] + 1,
assuming the window covered the full sequence. But SWA layers use
attn_mask_type=arbitrary, which routes through ScaledSoftmax (plain softmax,
no causal mask) when mask=None. For short sequences this dropped causal
masking entirely, not just the SWA restriction.
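
To illustrate the failure mode described above, here is a minimal pure-Python sketch. It is not Megatron's ScaledSoftmax. The helper names `softmax` and `attention_weights` are hypothetical, and the convention that `True` marks a hidden position is an assumption made for this example. It shows that a plain softmax with `mask=None` lets a query attend to future positions.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(scores, mask=None):
    # Hypothetical helper: mask[q][k] == True hides key k from query q.
    out = []
    for q, row in enumerate(scores):
        if mask is not None:
            row = [-math.inf if mask[q][k] else s for k, s in enumerate(row)]
        out.append(softmax(row))
    return out

seq_len = 3
scores = [[0.0] * seq_len for _ in range(seq_len)]

# mask=None: plain softmax, so query 0 puts weight on future keys 1 and 2.
unmasked = attention_weights(scores)

# Causal mask: key k is hidden from query q whenever k > q.
causal = [[k > q for k in range(seq_len)] for q in range(seq_len)]
masked = attention_weights(scores, causal)

print(unmasked[0])
print(masked[0])
```

With uniform scores, the unmasked first row spreads weight evenly over all three keys, while the causally masked first row puts all weight on key 0.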

get_swa() encodes the causal triangular structure via triu/tril and
degenerates to a pure causal mask when the window covers all positions,
so always calling it is both correct and sufficient.

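
The degeneration claim can be checked with a small pure-Python sketch. `swa_mask` is a hypothetical stand-in for get_swa(), not the real implementation. It assumes a `(left, right)` window tuple, equal query and key lengths, and that `True` marks a masked-out position.

```python
def swa_mask(seq_len, window):
    # Hypothetical stand-in for get_swa(): query q may see key k only when
    # q - left <= k <= q + right. True means "masked out" (assumed convention).
    left, right = window
    return [[not (q - left <= k <= q + right) for k in range(seq_len)]
            for q in range(seq_len)]

def causal_mask(seq_len):
    # Pure causal mask: hide every key that lies in the future of the query.
    return [[k > q for k in range(seq_len)] for q in range(seq_len)]

window = (8, 0)  # look back 8 tokens, no look-ahead

# Short sequence: seq_len <= window[0] + 1, so the window covers every
# past position and the SWA mask equals the causal mask.
short_len = 4
short_equal = swa_mask(short_len, window) == causal_mask(short_len)

# Longer sequence: the window now also hides keys that are too far back.
long_mask = swa_mask(12, window)
far_past_hidden = long_mask[10][0]   # key 0 is 10 steps behind query 10
in_window_hidden = long_mask[10][2]  # key 2 is exactly 8 steps behind

print(short_equal, far_past_hidden, in_window_hidden)
```

Under these assumptions, building the mask for a short sequence costs nothing in correctness: the result is the causal mask that the skipped path failed to supply.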
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
1 parent d5d0133 commit beba296
2 files changed
Lines changed: 12 additions & 14 deletions
File tree
- src/megatron/bridge/models/gemma
- tests/unit_tests/models/gemma