improve gdn mtp bf16 state perf for BS<=8 with LDG.128 #3143
ameynaik-hub wants to merge 5 commits into flashinfer-ai:main
Conversation
…to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:
1) Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion are now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the
compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
2) Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
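The two changes can be sketched in plain Python. This is a hypothetical toy model, not the real flashinfer code: `_build_entry`, `get_cached_kernel`, and `resolve_pool_slots` are illustrative names, and kernel compilation plus tensor allocation are reduced to cheap stand-ins.

```python
_compiled_kernels = {}

def _build_entry(cache_key):
    # Stands in for the one-time work moved out of the hot path in item 1:
    # kernel compilation, torch.arange/torch.empty default tensors, dlpack
    # conversion, and the get_device_capability query.
    return {"kernel": f"kernel<{cache_key}>", "defaults": (0,) * 4}

def get_cached_kernel(cache_key):
    # Mirrors the `cache_key not in _compiled_kernels_*` guard: setup cost
    # is paid once per shape; steady-state calls are a dict lookup.
    if cache_key not in _compiled_kernels:
        _compiled_kernels[cache_key] = _build_entry(cache_key)
    return _compiled_kernels[cache_key]

def resolve_pool_slots(h0_indices, h0_out_indices):
    # Item 2's convention: negative indices redirect to pool slot 0 (the
    # null buffer); reads and writes resolve independently, so one
    # request's input and output pool slots can differ.
    read_idx = [0 if i < 0 else i for i in h0_indices]
    write_idx = [0 if i < 0 else i for i in h0_out_indices]
    return read_idx, write_idx
```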
Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200
Command:
python benchmarks/bench_gdn_decode.py \
--batch-size 1 4 8 16 32 64 128 256 512 \
--num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
--head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.71
4 | 5.89
8 | 9.18
16 | 14.98
32 | 26.56
64 | 48.24
128 | 89.66
256 | 172.35
512 | 337.60
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…stic fix
Three coordinated changes to the BF16 GDN MTP decode path on B200:
1. New kernel `gated_delta_rule_mtp_wide_vec`
(flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
- 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
- vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
- ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
- No TMA, no persistent CTAs; the win is purely wider LSU transactions
- Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64
2. Auto-dispatcher inside `gated_delta_rule_mtp`
(flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
- New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
- Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128`
- T=1 and small-batch callers keep the existing baseline path verbatim,
so the public API is source-and-ABI compatible.
3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
- Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
inline instead of reading from sGB (only T>2 pre-populates sGB). The
inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
work_units.
- Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
covering the recompute latency.
- Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).
Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):
Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 21.20 27.23 33.73 41.23 47.89 55.15 62.11
32 26.37 37.81 49.98 62.93 78.46 91.20 103.73 117.58
64 47.76 70.69 93.76 117.92 146.42 172.80 197.82 225.44
128 90.38 135.17 180.46 226.99 278.78 329.41 378.18 432.17
256 173.65 262.98 351.06 440.75 542.37 641.63 739.60 846.40
512 337.98 516.24 691.10 869.02 ERROR ERROR ERROR ERROR
Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 23.71 30.02 37.84 46.82 54.50 61.95 69.89
32 26.37 42.56 54.75 69.01 85.41 99.74 114.13 128.50
64 47.76 79.18 101.38 129.02 159.44 187.44 216.08 244.03
128 90.38 149.34 193.09 247.38 305.87 361.33 417.20 473.81
256 173.65 289.12 375.31 481.41 596.72 705.98 816.64 930.16
512 337.98 568.32 741.25 952.97 ERROR ERROR ERROR ERROR
Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
2 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
4 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
8 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
16 1.00x 1.12x 1.10x 1.12x 1.14x 1.14x 1.12x 1.13x
32 1.00x 1.13x 1.10x 1.10x 1.09x 1.09x 1.10x 1.09x
64 1.00x 1.12x 1.08x 1.09x 1.09x 1.08x 1.09x 1.08x
128 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
256 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
512 1.00x 1.10x 1.07x 1.10x ERROR ERROR ERROR ERROR
DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 15.3 14.6 18.4 20.0 18.1 19.1 20.1 20.9
2 24.9 24.9 29.8 32.2 30.0 31.2 32.6 33.9
4 36.6 33.1 40.4 42.6 39.9 41.2 42.0 43.2
8 47.3 46.8 50.0 50.4 49.0 49.7 50.0 50.6
16 55.7 60.1 62.5 63.2 62.0 62.3 61.9 61.8
32 64.3 67.4 68.1 67.7 65.2 65.5 65.8 65.3
64 70.9 72.1 72.6 72.3 69.9 69.1 69.0 68.1
128 75.0 75.4 75.5 75.1 73.4 72.5 72.2 71.1
256 78.1 77.6 77.6 77.3 75.5 74.4 73.8 72.6
512 80.2 79.0 78.8 78.4 ERROR ERROR ERROR ERROR
Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).
Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39 / 7b7f1ac regressed because the
_get_bf16_mtp_config helper was dropped in the conflict resolution (upstream
only had _select_tile_v_for_mtp for ILP=8 tile sizing). This commit re-adds
the ILP=4 variant so the PR ships without the B=4 T=2 regression that would
otherwise show in CI.
Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:
1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
   ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
   - ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
   - Minimal signature: single h0_indices (read == write). Split-pool
     writes are delegated to ILP=8 by the dispatcher; the ILP=4 path never
     needs h0_out_indices, so we skip the extra plumbing.
2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher; mirrors the
   existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the new
   kernel.
3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper:
   - work_units <= 128: (min(16, v_dim), 4)        # B<=2 at HV=64
   - seq_len == 2 and work_units <= 256: (tile_v, 4)  # covers B=4 HV=64
   - else: (tile_v, 8)
   The T=2 branch compensates for the ILP=8 pipeline stall from the inline
   g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
   kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
   softplus+exp+log+sigmoid latency).
4. gated_delta_rule_mtp dispatcher:
   - When output_state_indices is None: pick (tile_v, ilp_rows) via
     _get_bf16_mtp_config and route to the matching launcher.
   - When output_state_indices is not None: force ilp_rows=8 so the
     split-pool write path (h0_out_indices) stays on the ILP=8 kernel that
     supports it.
   - Cache key now includes ilp_rows so the two launchers don't collide.
Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI, 5 warmup +
50 bench iters; cache_intermediate=True, disable_state_update=True for T>=2;
T=1 uses state-update ON):
            post-rebase           post ILP=4
B   T       before (us)           re-add (us)
1   1       3.46                  3.46
1   2       5.43                  5.41
2   1       4.26                  4.26
2   2       6.51                  6.50
4   1       5.82                  5.82
4   2       11.36 <- +17 % regr.  9.68 <- back to reference
8   1       9.25                  9.25
8   2       13.95                 13.95
16  1       15.06                 15.06
16  2       21.50 (wide_vec)      21.50 (wide_vec)
Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via the
official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed
tile_v=128 so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512
CTAs on a 148-SM B200 (< 3.5 waves), and the dispatcher falls back to the
baseline ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the
grid gains a V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H)
reach down into small-batch sizes that were previously starved of SM
parallelism.
Changes:
1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
   - Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by
     `i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from
     module-level constants to per-kernel constexprs.
   - Launcher computes grid = B * HV * (V / tile_v) and passes tile_v +
     num_v_tiles to the kernel.
   - Public wrapper `gated_delta_rule_mtp_wide_vec` accepts
     `tile_v: int = 128` (keyword-only default; positional callers
     unchanged).
   - Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
   - Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128),
     so every tile_v variant uses the same per-thread memory-op widths.
   - Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128
     (verified: max_abs_diff = 0.0 over 252 pytest configs).
2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)`
   helper picks tile_v by work_units = B * HV:
     work_units >= 1024 -> tile_v = 128 (unchanged: B >= 16 at HV=64)
     work_units >= 512  -> tile_v = 64  (new: B = 8 at HV=64)
     work_units >= 128  -> tile_v = 32  (new: B = 2..4 at HV=64)
     below              -> None (baseline ILP=4/8 as before)
   `_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope
   constant kept for external benchmarks that monkey-patch it; actual
   dispatch uses the new picker). Dispatcher in `gated_delta_rule_mtp` now
   calls the picker and passes tile_v into `gated_delta_rule_mtp_wide_vec`.
3. tests/gdn/test_decode_delta_rule.py —
   `test_gdn_decode_bf16_state_wide_vec_mtp_kernel` parametrized over
   `tile_v in {32, 64, 128}`. num_v_heads restricted to {64} (Qwen3.5
   production shape; HV=32 exploratory coverage is not needed for the
   small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v
   = 378.
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI, 5 warmup + 50 bench
iters, cache_intermediate=True, disable_state_update=True for T>=2; T=1
unchanged):
            post-3127     post tile_v
B   T       before (us)   picker (us)   speedup
1   2       5.44          5.46          1.00x (picker -> tv=32 ties baseline)
2   2       6.40          6.34          1.01x (picker -> tv=32)
4   2       9.63          8.14          1.18x (picker -> tv=32)
8   2       13.63         12.74         1.07x (picker -> tv=64)
16  2       21.20         21.18         1.00x (tv=128, unchanged)
32  2       37.81         37.89         1.00x (tv=128, unchanged)
64  2       70.69         70.69         1.00x (tv=128, unchanged)
Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to
wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4
kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us =
1.37x.
Correctness: 378 / 378 new pytest configs pass
(test_gdn_decode_bf16_state_wide_vec_mtp_kernel). Equivalent 1008-config
sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed during
development (not in CI to keep the default matrix size reasonable).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Code Review
This pull request enhances the GDN decode implementation by adding support for pool-based state management with separate input and output indices and introducing two new kernel variants: an ILP=4 MTP kernel for small batches and a wide-vector MTP kernel for large work units. Feedback primarily targets the wide-vector implementation, highlighting the need to maintain consistent in-place state updates when caching is enabled and to align the function's return type with the existing production API. Additionally, suggestions were made to optimize shared memory allocation and improve global memory throughput by vectorizing scalar loads and stores.
if cutlass.const_expr(
    not disable_state_update and not cache_intermediate_states
):
Skipping the final state write-back when cache_intermediate_states is enabled breaks consistency with other kernels and the expected behavior of in-place state updates. The initial_state_source should always be updated if disable_state_update is False, regardless of whether intermediate states are being cached.
Suggested change:
- if cutlass.const_expr(
-     not disable_state_update and not cache_intermediate_states
- ):
+ if cutlass.const_expr(
+     not disable_state_update
+ ):
    scale: Optional[float] = None,
    output: Optional[torch.Tensor] = None,
    tile_v: int = 128,
) -> torch.Tensor:
if not intermediate_states.is_contiguous():
    intermediate_states = intermediate_states.contiguous()
# Skip the redundant final writeback when caching is on.
effective_disable_final = True
    cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16
)
sGB = smem.allocate_tensor(
    cutlass.Float32, cute.make_layout((T, 3), stride=(3, 1)), 16
v_val0 = cutlass.Float32(v[(i_n, i_t, i_hv, v0)])
v_val1 = cutlass.Float32(v[(i_n, i_t, i_hv, v1)])
v_val2 = cutlass.Float32(v[(i_n, i_t, i_hv, v2)])
v_val3 = cutlass.Float32(v[(i_n, i_t, i_hv, v3)])
o[(i_n, i_t, i_hv, v0)] = cutlass.BFloat16(o0)
o[(i_n, i_t, i_hv, v1)] = cutlass.BFloat16(o1)
o[(i_n, i_t, i_hv, v2)] = cutlass.BFloat16(o2)
o[(i_n, i_t, i_hv, v3)] = cutlass.BFloat16(o3)