
[Grug] Add TE packed segment attention for dynamic segments #5749

Draft
dlwh wants to merge 15 commits into main from codex/grug-attn-8192-segments

Conversation

dlwh (Member) commented May 15, 2026

Route dynamic packed Grug segment masks on GPU through Transformer Engine THD attention instead of dense JAX masks. Thread the packed-segment bound from loaders, check valid-token numerics against reference, and add GH200 benchmarks against Megatron's TE fused path.

dlwh added the agent-generated label May 15, 2026
dlwh (Member Author) commented May 15, 2026

🤖 Spec for the packed segment attention change.

Problem: Grug packed segment attention needs segment-id semantics at 8192 sequence length without materializing a full [8192,8192] mask. The exact split-reference route avoids the full mask but measured about 6.42 ms fwd+bwd on GH200, much slower than Transformer Engine fused attention at about 1.17 ms.

Approach: AttentionMask now carries packed_segment_length metadata for equal-length packed segments. Static segment ids are validated; dynamic ids may use this metadata as an explicit promise. The gpu_cudnn implementation detects uniform packed segment ids, reshapes [B,S,H,D] into [B*segments,L,H,D], drops redundant sliding-window state when it covers the whole segment, and calls jax.nn.dot_product_attention(..., implementation="cudnn") on each segment. gpu_xla remains the exact split-reference path for strict validation.
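A minimal sketch of that uniform-segment fast path, assuming equal-length packed segments and [B, S, H, D] layouts; the helper below is illustrative, not the PR's actual `_dpa_attention_uniform_segments`:

```python
import jax


def dpa_uniform_segments(q, k, v, segment_len):
    """Illustrative sketch: q is [B, S, Hq, D]; k, v are [B, S, Hkv, D]."""
    B, S = q.shape[:2]
    assert S % segment_len == 0
    n_seg = S // segment_len

    def split(x):
        # [B, S, H, D] -> [B * n_seg, segment_len, H, D]
        return x.reshape(B * n_seg, segment_len, *x.shape[2:])

    # Each packed segment is independently causal, so per-segment causal SDPA
    # needs no [S, S] mask; implementation="cudnn" selects the fused kernel.
    out = jax.nn.dot_product_attention(
        split(q), split(k), split(v),
        is_causal=True,
        implementation="cudnn",
    )
    return out.reshape(B, S, *out.shape[2:])
```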

Key code: the important dispatch is in lib/levanter/src/levanter/grug/attention.py: _uniform_static_segment_length, _split_uniform_segments, _mask_for_split_segment, _reference_attention_uniform_segments, and _dpa_attention_uniform_segments. Tokamax flex/flash helpers reuse the same packed-segment metadata in lib/levanter/src/levanter/grug/flex_attention.py.

Tests: lib/levanter/tests/grug/test_attention.py covers uniform segment splitting without full-size mask materialization, gradient parity for the exact path, Tokamax packed segment behavior, and the empty-mask handling for packed metadata. Local validation passed `./infra/pre-commit.py --all-files --fix` and `uv run --with pytest --with pytest-xdist --package marin-levanter --extra kernels pytest lib/levanter/tests/grug/test_attention.py -q`.

Benchmarks: On GH200, Grug gpu_cudnn at B=1 S=8192 segments=4 L=2048 Hq=32 Hkv=8 D=128 bf16 measured 0.382 ms fwd and 1.357 ms fwd+bwd. Target-shape deltas versus exact split reference were value max 0.015625, value mean 0.000221, grad max 0.0625, grad mean 0.000383. Pinned TE fused/default measured about 1.166 ms fwd+bwd; TE FlashAttention v2 measured about 1.887 ms fwd+bwd. The current path is about 16% slower than TE fused, but faster than TE FlashAttention v2 and much faster than the exact split-reference path.

dlwh force-pushed the codex/grug-attn-8192-segments branch from f2bb2d4 to f96378f on May 15, 2026 04:06
dlwh (Member Author) commented May 15, 2026

🤖 Clarification after checking dynamic segment IDs: the gpu_cudnn reshape fast path is only valid for dynamic segment-id arrays when packed_segment_length is a static promise that the q and kv segment IDs are the same array and form fixed-size, block-aligned runs. Variable-length dynamic segment IDs cannot use the reshape; those need the existing Tokamax range-mask path or a future custom kernel. I added tests for both dynamic uniform packed IDs and dynamic variable-length contiguous IDs through the range-mask path.
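As a sketch, the block-aligned validity check could look like this (hypothetical helper, not the PR's exact validation code):

```python
import jax.numpy as jnp


def is_uniform_packed(segment_ids, segment_len):
    """segment_ids: [B, S] dynamic IDs. True iff every length-`segment_len`
    block holds a single segment ID, i.e. fixed-size, block-aligned runs."""
    B, S = segment_ids.shape
    if S % segment_len != 0:  # static shape check, resolved at trace time
        return jnp.asarray(False)
    blocks = segment_ids.reshape(B, S // segment_len, segment_len)
    # Each block must equal its own first element everywhere.
    return jnp.all(blocks == blocks[..., :1])
```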

dlwh force-pushed the codex/grug-attn-8192-segments branch from f96378f to e8c7300 on May 15, 2026 05:19
dlwh changed the title from "[Grug] Add packed segment GPU attention fast path" to "[Grug] Add dynamic segment GPU attention path" on May 15, 2026
dlwh (Member Author) commented May 15, 2026

🤖 Update after checking the data-loader path: the real packed case is dynamic contiguous document runs, not fixed block-aligned chunks. PackedTokenDataset / ChatDataset pass GreedyPrepackedDataset segment IDs through to GrugLmExample.causal; those IDs use global doc IDs and -1 tail padding, so this PR now treats gpu_flash_triton with Tokamax range masks as the relevant dynamic path. The GH200 loader-like 8192 timing is still pending because both GH200 nodes are currently occupied by CoreWeave cw-hpc-verification GPU pods; I did not preempt those.
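To make the loader shape concrete, a hypothetical row of such segment IDs (values invented) looks like:

```python
import jax.numpy as jnp

# Three packed documents with global doc IDs 1407, 92, 2210 and -1 tail padding:
seg_ids = jnp.array([1407, 1407, 1407, 1407, 1407,   # doc 1407, length 5
                     92, 92, 92,                     # doc 92, length 3
                     2210, 2210, 2210, 2210,         # doc 2210, length 4
                     -1, -1, -1, -1])                # padded tail
```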

dlwh force-pushed the codex/grug-attn-8192-segments branch from e8c7300 to da8b129 on May 15, 2026 05:35
dlwh (Member Author) commented May 15, 2026

🤖 Narrowed this after the static-segment correction. The runtime path no longer supports or validates static/block-aligned segment IDs as a fast path: packed_segment_length and the split-cuDNN segment implementation are removed. GPU default attention now routes structured segment IDs through the dynamic Tokamax range-mask path, and the benchmark only generates loader-like variable contiguous runs with optional -1 tail padding. Local validation is still green: ruff, pyrefly, pre-commit, and test_attention.py.

dlwh force-pushed the codex/grug-attn-8192-segments branch from da8b129 to 6f35a23 on May 15, 2026 06:45
dlwh changed the title from "[Grug] Add dynamic segment GPU attention path" to "[Grug] Add TE packed segment attention for dynamic segments" on May 15, 2026
dlwh (Member Author) commented May 15, 2026

🤖 Updated this PR to use Transformer Engine JAX THD attention for dynamic packed Grug segment IDs instead of the Tokamax range-mask path as the default. Segment IDs stay dynamic; static/fixed-block IDs are not a supported target. Public JAX SDPA backends still reject segment IDs, so routing through them would mean materializing a dense 8192x8192 mask; the TE THD path avoids that.

GH200 results, NGC JAX 26.04, B=1 S=8192 Hq=32 Hkv=8 D=128 bf16, loader-like lengths [815, 1609, 322, 1115, 1909, 622, 1415, 129], 256 pad-tail tokens:

  • gpu_te max_segments_per_seq=64: value max 0.015625 mean 0.000278661; grad max 0.0625 mean 0.000474363; fwd 0.295 ms; fwd+bwd 1.252 ms.
  • gpu_te max_segments_per_seq=8: fwd 0.272 ms; fwd+bwd 1.201 ms.
  • Megatron-style PyTorch TE fused THD loader comparison in NGC PyTorch 26.04: fwd 0.265 ms; fwd+bwd 1.142 ms.

Validation: `./infra/pre-commit.py --all-files --fix`; `pytest lib/levanter/tests/grug/test_attention.py -q`; `py_compile` for the benchmark scripts. The benchmark timing loop now synchronizes once after the measured JAX loop so it is comparable to the CUDA-event PyTorch TE script.
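The single-sync timing pattern mentioned above is roughly this (hypothetical harness; `step` stands in for the jitted forward or forward+backward):

```python
import time

import jax


def time_ms(step, *args, iters=100):
    out = step(*args)
    jax.block_until_ready(out)  # compile and warm up outside the timed region
    t0 = time.perf_counter()
    for _ in range(iters):
        out = step(*args)
    # Synchronize once after the whole loop, comparable to recording CUDA
    # events around the loop in the PyTorch TE script.
    jax.block_until_ready(out)
    return (time.perf_counter() - t0) / iters * 1e3
```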

dlwh force-pushed the codex/grug-attn-8192-segments branch from 6f35a23 to 086d7e6 on May 15, 2026 08:01
dlwh (Member Author) commented May 15, 2026

🤖 B>1 THD follow-up:

I measured the dynamic loader-like B=2, S=8192 case on GH200 with 8 segments per row and 256 padded tail tokens per row. The checked-in JAX TE path stays correct and does not materialize a full mask, but it currently falls back to one verified THD call per batch row because JAX TE's natural batched THD descriptor is wrong for this case.
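Schematically, the row fallback is just one single-row THD call per batch element (`attn_row` stands in for the verified single-row JAX TE call; names illustrative):

```python
import jax.numpy as jnp


def thd_row_fallback(q, k, v, seg_ids, attn_row):
    """q, k, v: [B, S, H, D]; seg_ids: [B, S]. Unrolled over the static batch
    dimension, so the compiled program contains one THD call per row."""
    rows = [attn_row(q[i], k[i], v[i], seg_ids[i]) for i in range(q.shape[0])]
    return jnp.stack(rows, axis=0)
```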

Numbers:

  • JAX TE row fallback, max_segments_per_seq=8: forward 0.950 ms, forward+backward 2.893 ms
  • JAX TE row fallback, max_segments_per_seq=64: forward 0.977 ms, forward+backward 2.962 ms
  • Target deltas vs split reference: value max 0.015625, mean 0.000280; q/k/v grad max [0.03125, 0.0625, 0.0625], means [0.000290, 0.000578, 0.000561]
  • PyTorch TE fused THD loader comparison: forward 0.434 ms, forward+backward 2.114 ms

Rejected faster JAX alternatives:

  • Direct batched TE THD descriptor: forward 0.599 ms, fwd+bwd 2.463 ms, but value max 2.890625 and huge grad errors
  • Flattened physical offsets: forward 0.514 ms, but value max 3.01953125
  • Compact valid-token THD stream: exact match to the row fallback, but slower (0.952 -> 1.305 ms forward; 3.758 -> 4.880 ms same-method fwd+bwd)

Conclusion: THD itself is not too expensive, but the JAX TE wrapper cannot currently use the fast Megatron/PyTorch-style batched THD packing correctly. The PR is good numerically and avoids the full mask, but for B=2 it is about 1.37x slower than PyTorch TE fused on fwd+bwd, so it is not within the ~5% Megatron gap target for B>1 yet.

dlwh force-pushed the codex/grug-attn-8192-segments branch from 086d7e6 to 1f1232d on May 15, 2026 08:30
dlwh (Member Author) commented May 15, 2026

🤖 Correction on the B>1 packing question: the bad batched result was our packing, not a TE bug.

The controlled probe showed TE's row-local THD offsets are segment starts padded with -1. For a four-segment row, TE's own from_segment_ids_and_pos path produced [0, 17, 46, 77, -1]; our previous helper produced [0, 17, 46, 77, 88]. For the flattened fallback, offsets likewise need to be physical positions rather than compact cumulative positions, so that inter-row tail padding is skipped.
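A sketch of that offset convention (hypothetical helper): segment-start indices, with unused descriptor slots filled with -1 rather than a trailing end position:

```python
import jax.numpy as jnp


def row_local_thd_offsets(seg_ids_row, max_segments):
    """seg_ids_row: [S] loader IDs for one row, -1 = tail padding. For the
    four-segment probe row this yields [0, 17, 46, 77, -1] at max_segments=5,
    not [0, 17, 46, 77, 88]."""
    valid = seg_ids_row >= 0
    prev = jnp.concatenate([seg_ids_row[:1] - 1, seg_ids_row[:-1]])
    is_start = valid & (seg_ids_row != prev)
    # size/fill_value pad the unused slots with -1, matching TE's convention.
    (starts,) = jnp.nonzero(is_start, size=max_segments, fill_value=-1)
    return starts
```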

I changed the implementation accordingly:

  • Full-causal dynamic segment IDs now use SequenceDescriptor.from_segment_ids_and_pos after converting arbitrary loader IDs to per-row local segment IDs and segment positions (see the sketch after this list).
  • True sliding-window segment attention keeps a flattened physical-offset descriptor so we do not route TE's descriptor construction through a quadratic mask.
  • Added focused tests for local segment-id/position generation, row-local offsets, and flattened offsets that skip inter-row padding.
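A sketch of the loader-ID normalization in the first bullet, assuming [B, S] global doc IDs with -1 tail padding (hypothetical helper, not the PR's exact code):

```python
import jax
import jax.numpy as jnp


def local_ids_and_positions(seg_ids):
    """seg_ids: [B, S] global doc IDs, -1 = padding. Returns per-row 0-based
    local segment IDs and within-segment positions, roughly the conversion
    done before calling SequenceDescriptor.from_segment_ids_and_pos."""
    B, S = seg_ids.shape
    valid = seg_ids >= 0
    prev = jnp.pad(seg_ids[:, :-1], ((0, 0), (1, 0)), constant_values=-2)
    is_start = valid & (seg_ids != prev)
    # Count segment starts to get row-local IDs; keep -1 on padding tokens.
    local_ids = jnp.where(valid, jnp.cumsum(is_start, axis=1) - 1, -1)
    # Position within the current segment = token index minus the index of
    # the most recent segment start.
    idx = jnp.broadcast_to(jnp.arange(S), (B, S))
    start_idx = jax.lax.cummax(jnp.where(is_start, idx, 0), axis=1)
    seg_pos = jnp.where(valid, idx - start_idx, -1)
    return local_ids, seg_pos
```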

Final GH200 B=2 S=8192 Hq=32 Hkv=8 D=128 bf16, 8 dynamic loader-like segments, 256 tail padding per row:

  • JAX TE segment-position descriptor, max_segments_per_seq=8: forward 0.572 ms, forward+backward 2.305 ms
  • JAX TE segment-position descriptor, max_segments_per_seq=64: forward 0.565 ms, forward+backward 2.325 ms
  • Target deltas vs split reference: value max 0.015625, mean 0.000280405; grad max 0.0625, mean 0.000476356
  • PyTorch TE fused THD loader comparison remains forward 0.434 ms, forward+backward 2.114 ms

So the corrected JAX path is about 9% slower than PyTorch TE fused on B=2 fwd+bwd, still outside the original ~5% target, but it is much better than the row fallback and materializes no dense full mask.

dlwh (Member Author) commented May 15, 2026

🤖 Follow-up on the OSS-kernel fallback path and TE descriptor variant:

  • Tokamax Triton range-mask at B=2/S=8192 is correct in the same TE-like BF16 envelope, but fwd+bwd is too slow: fwd 0.692 ms, fwd+bwd 12.638 ms.
  • Kvax blockwise document masks avoid an SxS mask, but are also too slow here: kernel-only fwd 1.915 ms, fwd+bwd 6.148 ms; with blockwise mask construction fwd 1.954 ms, fwd+bwd 6.321 ms.
  • A corrected row-local TE seqlens/offset descriptor validates numerically, but is not faster than the existing segment-position descriptor: fwd 0.593 ms, fwd+bwd 2.329 ms versus the checked-in segment-position result at fwd 0.572 ms, fwd+bwd 2.305 ms.

I pushed commit 7ea011a documenting these comparisons in docs/debug-log-grug-attention-correctness.md. The checked-in implementation remains the faster TE segment-position path for full-causal dynamic segment IDs.

