[Grug] Add TE packed segment attention for dynamic segments #5749
Conversation
Add the Sonic-style gather/sum and gather-ragged-dot Pallas implementations plus the Grug local-MoE integration and benchmarks used to compare the generated PTX against upstream SonicMoE.
🤖 Spec for the packed segment attention change.

Problem: Grug packed segment attention needs segment-id semantics at 8192 sequence length without materializing a full [8192, 8192] mask. The exact split-reference route avoids the full mask but measured about 6.42 ms fwd+bwd on GH200, much slower than Transformer Engine fused attention at about 1.17 ms.

Approach: AttentionMask now carries packed_segment_length metadata for equal-length packed segments. Static segment ids are validated; dynamic ids may use this metadata as an explicit promise. The gpu_cudnn implementation detects uniform packed segment ids, reshapes [B, S, H, D] into [B*segments, L, H, D], drops redundant sliding-window state when it covers the whole segment, and calls jax.nn.dot_product_attention(..., implementation="cudnn") on each segment. gpu_xla remains the exact split-reference path for strict validation.

Key code: the important dispatch is in lib/levanter/src/levanter/grug/attention.py: _uniform_static_segment_length, _split_uniform_segments, _mask_for_split_segment, _reference_attention_uniform_segments, and _dpa_attention_uniform_segments. Tokamax flex/flash helpers reuse the same packed-segment metadata in lib/levanter/src/levanter/grug/flex_attention.py.

Tests: lib/levanter/tests/grug/test_attention.py covers uniform segment splitting without full-size mask materialization, gradient parity for the exact path, Tokamax packed-segment behavior, and empty-mask handling for packed metadata. Local validation passed ./infra/pre-commit.py --all-files --fix and uv run --with pytest --with pytest-xdist --package marin-levanter --extra kernels pytest lib/levanter/tests/grug/test_attention.py -q.

Benchmarks: On GH200, Grug gpu_cudnn at B=1 S=8192 segments=4 L=2048 Hq=32 Hkv=8 D=128 bf16 measured 0.382 ms fwd and 1.357 ms fwd+bwd. Target-shape deltas versus the exact split reference were value max 0.015625, value mean 0.000221, grad max 0.0625, grad mean 0.000383. Pinned TE fused/default measured about 1.166 ms fwd+bwd; TE FlashAttention v2 measured about 1.887 ms fwd+bwd. The current path is about 16% slower than TE fused, but faster than TE FlashAttention v2 and much faster than the exact split-reference path.
Force-pushed f2bb2d4 to f96378f.
🤖 Clarification after checking dynamic segment IDs: the gpu_cudnn reshape fast path is only valid for dynamic segment-id arrays when packed_segment_length is a static promise that q and kv segment ids are the same object and consist of fixed-size, block-aligned runs. Variable-length dynamic segment IDs cannot use the reshape; they need the existing Tokamax range-mask path or a future custom kernel. I added tests for both dynamic uniform packed IDs and dynamic variable-length contiguous IDs through the range-mask path.
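The "fixed-size, block-aligned runs" condition can be checked cheaply without building any mask. A minimal sketch; the helper name is hypothetical, not the PR's actual validation code:

```python
import jax.numpy as jnp

def is_uniform_packed(segment_ids, packed_segment_length):
    """Check that dynamic segment ids are exactly fixed-length,
    block-aligned runs (0,...,0, 1,...,1, ...), with no [S, S] mask.

    segment_ids: int array [S]; packed_segment_length: static Python int.
    Returns a scalar bool array usable inside jit. Illustrative only.
    """
    s = segment_ids.shape[0]
    if s % packed_segment_length != 0:  # static shape check at trace time
        return jnp.asarray(False)
    n = s // packed_segment_length
    # the one-and-only valid pattern for uniform packing
    expected = jnp.repeat(jnp.arange(n), packed_segment_length)
    return jnp.all(segment_ids == expected)
```

The comparison is O(S), versus O(S^2) for materializing a pairwise segment mask.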
Force-pushed f96378f to e8c7300.
🤖 Update after checking the data-loader path: the real packed case is dynamic contiguous document runs, not fixed block-aligned chunks.
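For concreteness, loader-style packing produces segment ids like the sketch below: each packed document gets the next contiguous id and the tail is padding. This is an illustration, not the actual loader code.

```python
import numpy as np

def segment_ids_from_lengths(lengths, seq_len, pad_id=-1):
    """Dynamic contiguous document runs, as a data loader would pack
    them into one row. Illustrative sketch only."""
    ids = np.full(seq_len, pad_id, dtype=np.int32)
    pos = 0
    for seg, n in enumerate(lengths):
        ids[pos:pos + n] = seg
        pos += n
    return ids

# two documents of lengths 3 and 2 packed into a row of 7:
# segment_ids_from_lengths([3, 2], 7) -> [0, 0, 0, 1, 1, -1, -1]
```

Runs here are contiguous but variable-length, which is exactly why the fixed-block reshape cannot apply.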
Force-pushed e8c7300 to da8b129.
🤖 Narrowed this after the static-segment correction. The runtime path no longer supports or validates static/block-aligned segment IDs as a fast path:
Force-pushed da8b129 to 6f35a23.
🤖 Updated this PR to use Transformer Engine JAX THD attention as the default for dynamic packed Grug segment IDs, instead of the Tokamax range-mask path. Segment IDs stay dynamic; static/fixed-block IDs are not a supported target. Public JAX SDPA backends still reject segment IDs, so routing through them would mean materializing a dense 8192x8192 mask, which this path avoids. GH200 results, NGC JAX 26.04, B=1 S=8192 Hq=32 Hkv=8 D=128 bf16, loader-like lengths [815, 1609, 322, 1115, 1909, 622, 1415, 129], 256 pad-tail tokens:
Validation: ./infra/pre-commit.py --all-files --fix; pytest lib/levanter/tests/grug/test_attention.py -q; py_compile for the benchmark scripts. The benchmark timing loop now synchronizes once after the measured JAX loop so it is comparable to the CUDA-event PyTorch TE script.
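The single-sync timing pattern looks roughly like this sketch, where `step` stands in for the jitted fwd+bwd function (names are illustrative, not the benchmark script's actual code):

```python
import time
import jax

def time_steps(step, args, iters=100):
    """Time a jitted step with one device sync after the measured loop,
    so the number is comparable to CUDA-event timing on the PyTorch TE
    side. Sketch only."""
    out = step(*args)
    jax.block_until_ready(out)   # compile + warm up outside the timing
    t0 = time.perf_counter()
    for _ in range(iters):
        out = step(*args)
    # one sync after the whole loop; per-device execution is ordered,
    # so waiting on the last output waits for all queued iterations
    jax.block_until_ready(out)
    return (time.perf_counter() - t0) / iters
```

Syncing inside the loop instead would add a host round-trip per iteration and overstate the JAX numbers relative to the CUDA-event measurement.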
Force-pushed 6f35a23 to 086d7e6.
🤖 B>1 THD follow-up: I measured the dynamic loader-like B=2, S=8192 case on GH200 with 8 segments per row and 256 padded tail tokens per row. The checked-in JAX TE path stays correct and does not materialize a full mask, but it currently falls back to one verified THD call per batch row because JAX TE's natural batched THD descriptor is wrong for this case. Numbers:
Rejected faster JAX alternatives:
Conclusion: THD itself is not too expensive, but the JAX TE wrapper cannot currently use the fast Megatron/PyTorch-style batched THD packing correctly. The PR is good numerically and avoids the full mask, but for B=2 it is about 1.37x slower than PyTorch TE fused on fwd+bwd, so it is not within the ~5% Megatron gap target for B>1 yet.
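The per-row fallback amounts to something like the sketch below, where `thd_attention_row` is a hypothetical stand-in for the verified single-row TE THD call (not TE's actual API), and `row_offsets` carries the row-local segment boundaries:

```python
import jax.numpy as jnp

def batched_thd_fallback(q, k, v, row_offsets, thd_attention_row):
    """Run one verified single-row THD attention call per batch row,
    instead of a single batched THD call whose descriptor this case
    currently mis-handles.

    q, k, v: [B, S, H, D]; row_offsets: [B, max_segments + 1].
    Sketch only; `thd_attention_row` is a hypothetical wrapper.
    """
    rows = [
        thd_attention_row(q[i], k[i], v[i], row_offsets[i])
        for i in range(q.shape[0])
    ]
    return jnp.stack(rows)
```

The B separate kernel launches are what make this path correct but slower than a true batched THD call.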
Force-pushed 086d7e6 to 1f1232d.
🤖 Correction on the B>1 packing question: the bad batched result was our packing, not a TE bug. The controlled probe showed that TE's row-local THD offsets are segment starts padded with -1. I changed the implementation accordingly:
Final GH200 B=2 S=8192 Hq=32 Hkv=8 D=128 bf16, 8 dynamic loader-like segments, 256 tail padding per row:
So the corrected JAX path is about 9% slower than PyTorch TE fused on B=2 fwd+bwd, still not within the original ~5% target, but much better than the row fallback, and no dense full mask is materialized.
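The offsets layout the probe uncovered can be written down concretely. A sketch under assumptions: the helper name is illustrative, and whether the final segment boundary is included in the cumulative offsets is my reading of the convention, not confirmed TE source.

```python
import numpy as np

def thd_row_offsets(lengths, max_segments):
    """Row-local THD offsets in the layout the probe suggested TE
    expects: cumulative segment start offsets (assumed here to include
    the final boundary), padded to a fixed width with -1."""
    starts = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)
    pad = max_segments + 1 - starts.size
    return np.concatenate([starts, np.full(pad, -1, dtype=np.int32)])

# four segments in a row with room for eight:
# thd_row_offsets([3, 2, 4, 1], 8) -> [0, 3, 5, 9, 10, -1, -1, -1, -1]
```

Feeding plain per-row lengths instead of this -1-padded layout is the kind of packing mismatch the correction above describes.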
🤖 Follow-up on the OSS-kernel fallback path and TE descriptor variant:
I pushed commit 7ea011a documenting these comparisons in docs/debug-log-grug-attention-correctness.md. The checked-in implementation remains the faster TE segment-position path for full-causal dynamic segment IDs. |
Route dynamic packed Grug segment masks on GPU through Transformer Engine THD attention instead of dense JAX masks. Thread the packed-segment bound from the loaders, check valid-token numerics against the reference, and add GH200 benchmarks against Megatron's TE fused path.