Conversation
…to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:
1) Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion is now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the
compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
2) Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
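A minimal Python sketch of those index semantics (hypothetical helper for illustration; the real resolution happens inside the cute kernel):

```python
def resolve_pool_slots(h0_indices, h0_out_indices):
    # Negative indices redirect to pool slot 0, the null buffer; reads and
    # writes resolve independently so input/output pool slots can differ.
    read_idx = [i if i >= 0 else 0 for i in h0_indices]
    flat_write_idx = [o if o >= 0 else 0 for o in h0_out_indices]
    return read_idx, flat_write_idx
```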
Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200
Command:
python benchmarks/bench_gdn_decode.py \
--batch-size 1 4 8 16 32 64 128 256 512 \
--num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
--head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.71
4 | 5.89
8 | 9.18
16 | 14.98
32 | 26.56
64 | 48.24
128 | 89.66
256 | 172.35
512 | 337.60
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…stic fix
Three coordinated changes to the BF16 GDN MTP decode path on B200:
1. New kernel `gated_delta_rule_mtp_wide_vec`
(flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
- 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
- vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
- ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
- No TMA, no persistent CTAs; the win is purely wider LSU transactions
- Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64
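The transaction widths quoted above follow from per-thread contiguous bytes (a quick sanity check, not kernel code):

```python
BYTES_PER_BF16 = 2

def ldg_width_bits(vec_bf16_per_thread):
    # vec BF16 elements per thread -> bits per load/store instruction.
    return vec_bf16_per_thread * BYTES_PER_BF16 * 8

# wide_vec: vec=8 -> 128-bit (LDG.128); baseline: vec=4 -> 64-bit (LDG.64)
```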
2. Auto-dispatcher inside `gated_delta_rule_mtp`
(flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
- New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
- Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128`
- T=1 and small-batch callers keep the existing baseline path verbatim,
so the public API is source-and-ABI compatible.
3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
- Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
inline instead of reading from sGB (only T>2 pre-populates sGB). The
inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
work_units.
- Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
covering the recompute latency.
- Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).
Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):
Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 21.20 27.23 33.73 41.23 47.89 55.15 62.11
32 26.37 37.81 49.98 62.93 78.46 91.20 103.73 117.58
64 47.76 70.69 93.76 117.92 146.42 172.80 197.82 225.44
128 90.38 135.17 180.46 226.99 278.78 329.41 378.18 432.17
256 173.65 262.98 351.06 440.75 542.37 641.63 739.60 846.40
512 337.98 516.24 691.10 869.02 ERROR ERROR ERROR ERROR
Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 23.71 30.02 37.84 46.82 54.50 61.95 69.89
32 26.37 42.56 54.75 69.01 85.41 99.74 114.13 128.50
64 47.76 79.18 101.38 129.02 159.44 187.44 216.08 244.03
128 90.38 149.34 193.09 247.38 305.87 361.33 417.20 473.81
256 173.65 289.12 375.31 481.41 596.72 705.98 816.64 930.16
512 337.98 568.32 741.25 952.97 ERROR ERROR ERROR ERROR
Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
2 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
4 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
8 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
16 1.00x 1.12x 1.10x 1.12x 1.14x 1.14x 1.12x 1.13x
32 1.00x 1.13x 1.10x 1.10x 1.09x 1.09x 1.10x 1.09x
64 1.00x 1.12x 1.08x 1.09x 1.09x 1.08x 1.09x 1.08x
128 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
256 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
512 1.00x 1.10x 1.07x 1.10x ERROR ERROR ERROR ERROR
DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 15.3 14.6 18.4 20.0 18.1 19.1 20.1 20.9
2 24.9 24.9 29.8 32.2 30.0 31.2 32.6 33.9
4 36.6 33.1 40.4 42.6 39.9 41.2 42.0 43.2
8 47.3 46.8 50.0 50.4 49.0 49.7 50.0 50.6
16 55.7 60.1 62.5 63.2 62.0 62.3 61.9 61.8
32 64.3 67.4 68.1 67.7 65.2 65.5 65.8 65.3
64 70.9 72.1 72.6 72.3 69.9 69.1 69.0 68.1
128 75.0 75.4 75.5 75.1 73.4 72.5 72.2 71.1
256 78.1 77.6 77.6 77.3 75.5 74.4 73.8 72.6
512 80.2 79.0 78.8 78.4 ERROR ERROR ERROR ERROR
Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).
Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39 / 7b7f1ac regressed because the
_get_bf16_mtp_config helper was dropped in the conflict resolution (upstream
only had _select_tile_v_for_mtp for ILP=8 tile sizing). This commit re-adds
the ILP=4 variant so the PR ships without the B=4 T=2 regression that would
otherwise show in CI.
Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:
1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
- ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
- Minimal signature: single h0_indices (read == write). Split-pool
writes are delegated to ILP=8 by the dispatcher; the ILP=4 path never
needs h0_out_indices, so we skip the extra plumbing.
2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher; mirrors the
existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the
new kernel.
3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper:
- work_units <= 128: (min(16, v_dim), 4)          # B<=2 at HV=64
- seq_len == 2 and work_units <= 256: (tile_v, 4) # covers B=4 HV=64
- else: (tile_v, 8)
The T=2 branch compensates for the ILP=8 pipeline stall from the inline
g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
softplus+exp+log+sigmoid latency).
4. gated_delta_rule_mtp dispatcher:
- When output_state_indices is None: pick (tile_v, ilp_rows) via
_get_bf16_mtp_config and route to the matching launcher.
- When output_state_indices is not None: force ilp_rows=8 so the
split-pool write path (h0_out_indices) stays on the ILP=8 kernel that
supports it.
- Cache key now includes ilp_rows so the two launchers don't collide.
Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI, 5 warmup +
50 bench iters; cache_intermediate=True, disable_state_update=True for T>=2;
T=1 uses state-update ON):
        post-rebase            post ILP=4
B   T   before (us)            re-add (us)
1   1   3.46                   3.46
1   2   5.43                   5.41
2   1   4.26                   4.26
2   2   6.51                   6.50
4   1   5.82                   5.82
4   2   11.36 <- +17 % regr.   9.68 <- back to reference
8   1   9.25                   9.25
8   2   13.95                  13.95
16  1   15.06                  15.06
16  2   21.50 (wide_vec)       21.50 (wide_vec)
Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via the
official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed
tile_v=128 so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512
CTAs on a 148-SM B200 (< 3.5 waves), and the dispatcher falls back to the
baseline ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the
grid gains a V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H)
reach down into small-batch sizes that were previously starved of SM
parallelism.
Changes:
1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
- Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by
`i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from
module-level constants to per-kernel constexprs.
- Launcher computes grid = B * HV * (V / tile_v) and passes tile_v +
num_v_tiles to the kernel.
- Public wrapper `gated_delta_rule_mtp_wide_vec` accepts
`tile_v: int = 128` (keyword-only default; positional callers unchanged).
- Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
- Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128),
so every tile_v variant uses the same per-thread memory-op widths.
- Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128
(verified: max_abs_diff = 0.0 over 252 pytest configs).
2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)`
helper picks tile_v by work_units = B * HV:
- work_units >= 1024 -> tile_v = 128 (unchanged: B >= 16 at HV=64)
- work_units >= 512  -> tile_v = 64  (new: B = 8 at HV=64)
- work_units >= 128  -> tile_v = 32  (new: B = 2..4 at HV=64)
- below              -> None (baseline ILP=4/8 as before)
`_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope constant
kept for external benchmarks that monkey-patch it; actual dispatch uses the
new picker). Dispatcher in `gated_delta_rule_mtp` now calls the picker and
passes tile_v into `gated_delta_rule_mtp_wide_vec`.
3. tests/gdn/test_decode_delta_rule.py —
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` parametrized over
`tile_v in {32, 64, 128}`. num_v_heads restricted to {64} (Qwen3.5
production shape; HV=32 exploratory coverage is not needed for the
small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v = 378.
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI 5 warmup + 50 bench
iters, cache_intermediate=True, disable_state_update=True for T>=2; T=1
unchanged):
        post-3127    post tile_v
B   T   before (us)  picker (us)  speedup
1   2   5.44         5.46         1.00x (picker -> tv=32 ties baseline)
2   2   6.40         6.34         1.01x (picker -> tv=32)
4   2   9.63         8.14         1.18x (picker -> tv=32)
8   2   13.63        12.74        1.07x (picker -> tv=64)
16  2   21.20        21.18        1.00x (tv=128, unchanged)
32  2   37.81        37.89        1.00x (tv=128, unchanged)
64  2   70.69        70.69        1.00x (tv=128, unchanged)
Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to
wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4
kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us
= 1.37x.
Correctness: 378 / 378 new pytest configs pass
(test_gdn_decode_bf16_state_wide_vec_mtp_kernel). Equivalent 1008-config
sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed during
development (not in CI to keep the default matrix size reasonable).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Code Review
This pull request introduces a wide-vector BF16 GDN MTP decode kernel and an ILP=4 variant to optimize performance across various batch sizes and work units. It also refactors the dispatch logic to support pool-based state management with indices and includes updated tests. Feedback from the review focuses on the significant code duplication across the new kernels and the repetitive nature of the memory write-back logic, suggesting that these be refactored into parameterized implementations or loops to improve maintainability.
    h0_source.layout.shape[1],
    h0_source.layout.shape[2],
)
"""MTP kernel (ILP=4) for BF16 state — higher occupancy at small batch."""
This kernel gdn_decode_bf16state_mtp_ilp4_kernel is almost a complete copy of gdn_decode_bf16state_mtp_kernel. This introduces a significant amount of duplicated code, making maintenance difficult. A change in one kernel will likely need to be manually propagated to the other. Please consider refactoring to a common implementation parameterized by ILP_ROWS to improve maintainability.
@cute.kernel
def gdn_wide_vec_kernel(
This new gdn_wide_vec_kernel is another large kernel that shares a substantial amount of logic with gdn_decode_bf16state_mtp_kernel and gdn_decode_bf16state_mtp_ilp4_kernel in gdn_decode_bf16_state.py. This extensive code duplication across three kernels (ILP=8, ILP=4, and wide_vec) poses a significant maintainability risk. Future bug fixes or logic changes will need to be carefully applied to all three places. It would be highly beneficial to explore refactoring options, such as using a code generation approach or a more parameterized base kernel, to reduce this duplication.
wt0 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v0, lane_in_group)
)
wt1 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v1, lane_in_group)
)
wt2 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v2, lane_in_group)
)
wt3 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v3, lane_in_group)
)
wt4 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v4, lane_in_group)
)
wt5 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v5, lane_in_group)
)
wt6 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v6, lane_in_group)
)
wt7 = cute.local_tile(
    h0_source, (1, 1, vec_size), (flat_write_idx, v7, lane_in_group)
)
cute.autovec_copy(r_hb0, wt0)
cute.autovec_copy(r_hb1, wt1)
cute.autovec_copy(r_hb2, wt2)
cute.autovec_copy(r_hb3, wt3)
cute.autovec_copy(r_hb4, wt4)
cute.autovec_copy(r_hb5, wt5)
cute.autovec_copy(r_hb6, wt6)
cute.autovec_copy(r_hb7, wt7)
This block of code for writing the updated state back to global memory is very repetitive. It defines 8 separate tiles (wt0 to wt7) and performs 8 separate copy operations. This can be simplified and made more maintainable by using a loop.
You could consider refactoring the r_hb0 through r_hb7 register tensors into a single 2D tensor, e.g., r_hb = cute.make_rmem_tensor(cute.make_layout((ILP_ROWS, vec_size), ...), ...). Then, the conversion from r_h and the write-back could be done in loops, which would significantly reduce code duplication and improve readability. This pattern of repetition also exists in other parts of the kernel (e.g., loading h).
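For illustration, the loop form this suggestion points at might look roughly like the following untested cute-DSL sketch (`r_hb` as a 2D register tensor, `v_rows` as an indexed container, and the constexpr-range helper are all assumptions, not verified API):

```
# Collapse wt0..wt7 / r_hb0..r_hb7 into a compile-time loop over ILP rows.
for r in cutlass.range_constexpr(ILP_ROWS):
    wt = cute.local_tile(
        h0_source, (1, 1, vec_size), (flat_write_idx, v_rows[r], lane_in_group)
    )
    cute.autovec_copy(r_hb[r], wt)
```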
…l mode)
The T=1 ILP kernel (`gdn_decode_bf16state_ilp_kernel`) uses 4 warps × 32
threads with vec=4 BF16 per thread (LDG.E.64 / STG.E.64 on H state). NCU
captures at B=16/32 T=1 HV=64 show LSU pipe as the dominant bottleneck
(LSU 38.8–48.4 % vs DRAM 29.0–39.4 %), with L1/TEX at 67–75 %. The shipped
wide_vec kernel uses 8 groups × 16 threads / vec=8 BF16 (LDG.E.128 /
STG.E.128), which halves LSU instruction count and reduces L1 wavefronts.
wide_vec's SMEM-precompute phase runs `ceil(T / NUM_WARPS)` passes — at
T=1 that is 1 pass (only warp 0 does real work, others idle), so it
degenerates gracefully. Correctness: bit-exact at bf16 noise floor across
all 20 T=1 test configs (HV in {32, 64}, B in {1..512}, FP32 reference).
Changes:
1. gdn_decode_bf16_state.py::gated_delta_rule — new T=1 dispatch branch
placed BEFORE the small-batch (`B < ILP_BATCH_THRESHOLD`) MTP redirect
so B=8 HV=64 (work_units = 512) can also reach wide_vec at T=1 and get
the LDG.E.128 win (measured 1.05x at B=8). Gated by ALL of:
- K = V = 128 (wide_vec subgroup layout assumes this).
- Pool mode: `initial_state_indices is not None`. Non-pool direct-state
callers stay on the baseline ILP kernel, so the split-pool test path
(baseline) remains bit-exact with the non-pool reference path that
`test_output_state_indices` relies on.
- Single-pool write: `output_state_indices is None` or identical to
`initial_state_indices` (wide_vec has one indices tensor for R/W).
- tile_v >= 64. At T=1 the wide_vec Phase 0 SMEM-precompute overhead is
fixed per CTA while the main loop shrinks with tile_v. tile_v=32
gives only 1 ILP iter per subgroup, insufficient to amortize
Phase 0. Probe at HV=64 T=1 pool mode: tile_v=32 regresses at B=4
(0.91x); tile_v=64 wins at B=8 (1.05x); tile_v=128 wins at B>=16
(1.04-1.06x). So tile_v=32 is masked out at T=1 only.
2. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_t1_kernel`
parametrization extended to also cover num_v_heads=64 (production
Qwen3.5 shape). Total T=1 CI configs: 10 batch x 2 HV = 20.
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, state-update ON, no cache,
pool mode with `initial_state_indices = arange(B)`, CUPTI 5 warmup + 50
bench iters). Baseline column monkey-patches `_select_wide_vec_tile_v`
to None so the ILP/MTP kernel path runs:
B baseline wide_vec T=1 speedup
1 3.49us 3.52us 0.99x (MTP ILP=4; wide_vec gate n/a)
2 4.29us 4.26us 1.01x (MTP ILP=4; wide_vec gate n/a)
4 5.76us 5.76us 1.00x (MTP ILP=8; tile_v=32 masked)
8 9.30us 8.83us 1.05x (NEW: wide_vec tile_v=64)
16 15.01us 14.35us 1.05x (wide_vec tile_v=128)
32 26.53us 25.30us 1.05x
64 48.45us 46.53us 1.04x
128 89.70us 86.00us 1.04x
256 172.22us 165.02us 1.04x
512 337.15us 323.10us 1.04x
Peak DRAM SOL at T=1 climbs from 80.6 % (previous ceiling) to 83.9 % at
B=512. All B >= 8 in pool mode gain 1.04-1.05x. Gains are smaller than
the T>=2 win (~1.10-1.50x) because the T=1 ILP kernel was already closer
to the DRAM roofline — LSU relief has less headroom to translate into
speedup.
Scope notes:
- Non-pool (direct-state) T=1 callers: stay on baseline ILP. Routing
them through wide_vec introduces 1-BF16-ULP output differences vs the
baseline-served split-pool path, which fails test_output_state_indices
at atol=1e-3. Since the test's semantic invariant is that split-pool
and non-pool paths produce bit-identical output, both must use the
same kernel; baseline ILP is the existing common denominator.
- Split-pool writes (distinct output_state_indices tensor): stay on
baseline ILP. wide_vec does not yet plumb h0_out_indices.
- B*HV < 512: falls through to the (relocated) small-batch MTP redirect
and baseline ILP=4/8 kernels. Wide_vec at tile_v=32 regresses here
because Phase 0 overhead isn't amortized.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Force-pushed e7bee6a to dc87e93.
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Install pre-commit by running pip install pre-commit (or use your preferred method).
- Install the hooks with pre-commit install.
- Run pre-commit run --all-files and fix any reported issues.
🧪 Tests
- Tests have been added or updated as needed, and all tests pass (unittest, etc.).
Reviewer Notes