
Commit dc87e93

ameynaik-hub and claude committed
perf(gdn): route T=1 decode through wide_vec at large work_units (pool mode)
The T=1 ILP kernel (`gdn_decode_bf16state_ilp_kernel`) uses 4 warps × 32 threads with vec=4 BF16 per thread (LDG.E.64 / STG.E.64 on the H state). NCU captures at B=16/32, T=1, HV=64 show the LSU pipe as the dominant bottleneck (LSU 38.8–48.4% vs DRAM 29.0–39.4%), with L1/TEX at 67–75%. The shipped wide_vec kernel uses 8 groups × 16 threads with vec=8 BF16 (LDG.E.128 / STG.E.128), which halves the LSU instruction count and reduces L1 wavefronts. wide_vec's SMEM-precompute phase runs `ceil(T / NUM_WARPS)` passes; at T=1 that is 1 pass (only warp 0 does real work, the others idle), so it degenerates gracefully.

Correctness: bit-exact at the bf16 noise floor across all 20 T=1 test configs (HV in {32, 64}, B in {1..512}, FP32 reference).

Changes:

1. gdn_decode_bf16_state.py::gated_delta_rule — new T=1 dispatch branch placed BEFORE the small-batch (`B < ILP_BATCH_THRESHOLD`) MTP redirect so that B=8, HV=64 (work_units = 512) can also reach wide_vec at T=1 and get the LDG.E.128 win (measured 1.05x at B=8). Gated by ALL of:
   - K = V = 128 (the wide_vec subgroup layout assumes this).
   - Pool mode: `initial_state_indices is not None`. Non-pool direct-state callers stay on the baseline ILP kernel, so the split-pool test path (baseline) remains bit-exact with the non-pool reference path that `test_output_state_indices` relies on.
   - Single-pool write: `output_state_indices is None` or identical to `initial_state_indices` (wide_vec has one indices tensor for R/W).
   - tile_v >= 64. At T=1 the wide_vec Phase 0 SMEM-precompute overhead is fixed per CTA while the main loop shrinks with tile_v; tile_v=32 gives only 1 ILP iteration per subgroup, insufficient to amortize Phase 0. Probe at HV=64, T=1, pool mode: tile_v=32 regresses at B=4 (0.91x); tile_v=64 wins at B=8 (1.05x); tile_v=128 wins at B>=16 (1.04–1.06x). So tile_v=32 is masked out at T=1 only.
2. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_t1_kernel` parametrization extended to also cover num_v_heads=64 (the production Qwen3.5 shape).
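The gate-plus-tile_v dispatch described above can be sketched in plain Python. This is a hypothetical stand-in, not the real `_select_wide_vec_tile_v`; the thresholds simply mirror the probe numbers quoted in this message (at HV=64: tile_v=64 wins at B=8, i.e. 512 work_units, and tile_v=128 from B>=16), and `pool_mode`/`single_pool_write` are illustrative boolean parameters for the tensor-based gates:

```python
def select_t1_tile_v(B, HV, K=128, V=128, pool_mode=True, single_pool_write=True):
    """Hypothetical sketch of the T=1 wide_vec dispatch gates.

    Returns the tile_v to launch wide_vec with, or None to fall through to
    the baseline ILP/MTP path. Thresholds mirror the HV=64 probe numbers in
    the commit message; the real heuristic lives in _select_wide_vec_tile_v.
    """
    # Hard gates: layout, pool mode, single-pool write.
    if not (K == 128 and V == 128 and pool_mode and single_pool_write):
        return None
    work_units = B * HV
    if work_units < 512:
        # Small batch: tile_v=32 would be the only option and it regresses
        # (0.91x at B=4) because Phase 0 overhead is not amortized.
        return None
    # tile_v=64 at the 512 work_unit boundary (B=8, HV=64), 128 above it.
    return 64 if work_units < 1024 else 128
```

At B=8, HV=64 this returns 64 (the new 1.05x case), at B>=16 it returns 128, and any failed gate routes back to the baseline kernels.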
Total T=1 CI configs: 10 batch × 2 HV = 20.

Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, state-update ON, no cache, pool mode with `initial_state_indices = arange(B)`, CUPTI 5 warmup + 50 bench iters). The baseline column monkey-patches `_select_wide_vec_tile_v` to None so the ILP/MTP kernel path runs:

B     baseline    wide_vec    T=1 speedup
1     3.49us      3.52us      0.99x  (MTP ILP=4; wide_vec gate n/a)
2     4.29us      4.26us      1.01x  (MTP ILP=4; wide_vec gate n/a)
4     5.76us      5.76us      1.00x  (MTP ILP=8; tile_v=32 masked)
8     9.30us      8.83us      1.05x  (NEW: wide_vec tile_v=64)
16    15.01us     14.35us     1.05x  (wide_vec tile_v=128)
32    26.53us     25.30us     1.05x
64    48.45us     46.53us     1.04x
128   89.70us     86.00us     1.04x
256   172.22us    165.02us    1.04x
512   337.15us    323.10us    1.04x

Peak DRAM SOL at T=1 climbs from 80.6% (the previous ceiling) to 83.9% at B=512. All B >= 8 in pool mode gain 1.04–1.05x. Gains are smaller than the T>=2 win (~1.10–1.50x) because the T=1 ILP kernel was already closer to the DRAM roofline, so LSU relief has less headroom to translate into speedup.

Scope notes:
- Non-pool (direct-state) T=1 callers stay on baseline ILP. Routing them through wide_vec introduces 1-BF16-ULP output differences vs the baseline-served split-pool path, which fails test_output_state_indices at atol=1e-3. Since the test's semantic invariant is that the split-pool and non-pool paths produce bit-identical output, both must use the same kernel; baseline ILP is the existing common denominator.
- Split-pool writes (a distinct output_state_indices tensor) stay on baseline ILP. wide_vec does not yet plumb h0_out_indices.
- B*HV < 512 falls through to the (relocated) small-batch MTP redirect and the baseline ILP=4/8 kernels. wide_vec at tile_v=32 regresses here because Phase 0 overhead isn't amortized.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
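The "halves LSU instruction count" claim follows from simple arithmetic on the H-state footprint. A back-of-envelope model (assuming the K=V=128 BF16 state layout stated above; q/k/v traffic and the thread mapping are ignored, since only the vector width changes between the two kernels):

```python
# LSU instruction count to stream one head's H state (K x V BF16 elements)
# at two vector widths. BF16 is 2 bytes, so LDG.E.64 moves 4 elements per
# instruction and LDG.E.128 moves 8.
K, V = 128, 128
BYTES_PER_BF16 = 2
h_bytes = K * V * BYTES_PER_BF16   # 32 KiB of state per head

def ldg_instructions(bytes_per_ldg: int) -> int:
    return h_bytes // bytes_per_ldg

ilp_ldg = ldg_instructions(8)      # LDG.E.64 path (vec=4 BF16)
wide_ldg = ldg_instructions(16)    # LDG.E.128 path (vec=8 BF16)
print(ilp_ldg, wide_ldg)           # wide_vec issues half as many LSU ops
```

This halving shows up directly in the NCU LSU-pipe numbers; the speedup is smaller than 2x because DRAM traffic is unchanged and the kernel was already near the DRAM roofline.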
1 parent 5975885 commit dc87e93

2 files changed

Lines changed: 58 additions & 4 deletions


flashinfer/gdn_kernels/gdn_decode_bf16_state.py

Lines changed: 57 additions & 3 deletions
@@ -3077,9 +3077,63 @@ def gated_delta_rule(
     if output_state_indices is not None and output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)
 
-    # Small batch: route through MTP kernel (T=1 path).
-    # The cooprow kernel has known correctness issues at small batch sizes (e.g. B=2).
-    # The MTP kernel's T=1 path uses the same ILP-style computation and is well-tested.
+    # Wide_vec T=1 fast path. Per NCU captures the T=1 ILP kernel is LSU-bound
+    # at B=16/32 (LSU > DRAM by ~10 pp, L1/TEX ~68 %); wide_vec's LDG.E.128 on
+    # H halves LSU traffic and reduces L1 wavefronts. Handled via wide_vec's
+    # SMEM-precompute phase (ceil(T/NUM_WARPS)=1 pass at T=1).
+    #
+    # Gates (all must hold):
+    # - K=V=128 (wide_vec subgroup layout assumes this).
+    # - Pool mode (initial_state_indices is not None). Non-pool direct-state
+    #   callers stay on the baseline ILP kernel, so that the split-pool test
+    #   path (baseline) stays bit-exact with the non-pool reference path.
+    #   Split-pool isn't supported by wide_vec yet, so both must use baseline.
+    # - output_state_indices is None or identical to initial_state_indices
+    #   (single-pool write; wide_vec has one indices tensor for R/W).
+    # - tile_v >= 64. At T=1 the wide_vec Phase 0 precompute overhead is
+    #   fixed per CTA while the main loop shrinks with tile_v; tile_v=32
+    #   gives only 1 ILP iter per subgroup, insufficient to amortize
+    #   Phase 0. Measured at HV=64: tile_v=32 regresses at B=4 (0.91x);
+    #   tile_v=64 wins at B=8 (1.05x). Evaluated ABOVE the small-batch
+    #   MTP redirect so B=8 can hit the wide_vec path (B<16 otherwise
+    #   redirects to `gated_delta_rule_mtp` which doesn't dispatch wide_vec
+    #   at T=1).
+    _single_pool_write = (
+        output_state_indices is None or output_state_indices is initial_state_indices
+    )
+    wv_tile_v = (
+        _select_wide_vec_tile_v(B, HV, V)
+        if (K == 128 and initial_state_indices is not None and _single_pool_write)
+        else None
+    )
+    if wv_tile_v is not None and wv_tile_v < 64:
+        wv_tile_v = None  # tile_v=32 at T=1 loses to baseline ILP
+    if wv_tile_v is not None:
+        from .gdn_decode_bf16_state_wide_vec import gated_delta_rule_mtp_wide_vec
+
+        return gated_delta_rule_mtp_wide_vec(
+            A_log=A_log,
+            a=a,
+            dt_bias=dt_bias,
+            softplus_beta=softplus_beta,
+            softplus_threshold=softplus_threshold,
+            q=q,
+            k=k,
+            v=v,
+            b=b,
+            initial_state_source=initial_state_source,
+            initial_state_indices=initial_state_indices,
+            intermediate_states_buffer=None,  # T=1 has no cache
+            disable_state_update=False,  # T=1 default: write final state
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+            scale=scale,
+            output=output,
+            tile_v=wv_tile_v,
+        )
+
+    # Small-batch redirect: wide_vec didn't trigger (B*HV too small or non-pool
+    # call). Route through the MTP kernel's T=1 path — cooprow has known
+    # correctness issues at small batch.
     if B < ILP_BATCH_THRESHOLD:
         return gated_delta_rule_mtp(
             A_log=A_log,
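One subtlety worth noting in the single-pool gate: `output_state_indices is initial_state_indices` checks object identity, not value equality, so a caller that passes an equal but separately allocated indices tensor still falls back to baseline ILP. A plain-Python illustration (lists standing in for index tensors; variable names are illustrative):

```python
# `is` returns True only for the literal same object (or for None via the
# first disjunct); an element-wise-equal copy does not satisfy the gate.
initial = list(range(8))        # stand-in for initial_state_indices
same_obj = initial              # same object -> single-pool write
equal_copy = list(range(8))     # equal values, distinct object -> split-pool

def single_pool_write(output_indices, initial_indices):
    return output_indices is None or output_indices is initial_indices

print(single_pool_write(same_obj, initial))    # True  -> wide_vec eligible
print(single_pool_write(equal_copy, initial))  # False -> baseline ILP
print(single_pool_write(None, initial))        # True
```

Identity is the conservative choice here: comparing tensor contents would cost a device sync on the dispatch path, and wide_vec genuinely has only one indices tensor for read and write.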

tests/gdn/test_decode_delta_rule.py

Lines changed: 1 addition & 1 deletion
@@ -1742,7 +1742,7 @@ def _test_gdn_decode_bf16_state_t1_kernel(
 @pytest.mark.parametrize("head_size", [128])
 @pytest.mark.parametrize(
     "num_q_heads, num_k_heads, num_v_heads",
-    [(16, 16, 32)],
+    [(16, 16, 32), (16, 16, 64)],
 )
 @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512])
 @pytest.mark.parametrize("dtype", ["bfloat16"])
