Commit dc87e93
perf(gdn): route T=1 decode through wide_vec at large work_units (pool mode)
The T=1 ILP kernel (`gdn_decode_bf16state_ilp_kernel`) uses 4 warps × 32
threads with vec=4 BF16 per thread (LDG.E.64 / STG.E.64 on H state). NCU
captures at B=16/32 T=1 HV=64 show LSU pipe as the dominant bottleneck
(LSU 38.8–48.4 % vs DRAM 29.0–39.4 %), with L1/TEX at 67–75 %. The shipped
wide_vec kernel uses 8 groups × 16 threads / vec=8 BF16 (LDG.E.128 /
STG.E.128), which halves LSU instruction count and reduces L1 wavefronts.
wide_vec's SMEM-precompute phase runs `ceil(T / NUM_WARPS)` passes — at
T=1 that is 1 pass (only warp 0 does real work, others idle), so it
degenerates gracefully. Correctness: outputs match the FP32 reference to
within the BF16 noise floor across all 20 T=1 test configs (HV in
{32, 64}, B in {1..512}).
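
The instruction-count argument above can be checked with rough arithmetic. The sketch below assumes one 128 x 128 BF16 state per (batch, head) pair and 2 bytes per BF16 element. The helper names are illustrative, and NUM_WARPS = 4 is an assumption derived from 8 groups x 16 threads = 128 threads; none of this is taken from the kernel source.

```python
import math

BF16_BYTES = 2  # one BF16 element occupies 2 bytes


def access_width_bits(vec: int) -> int:
    """Width of one per-thread load/store that moves `vec` BF16 values."""
    return vec * BF16_BYTES * 8


def vector_accesses(k: int, v: int, vec: int) -> int:
    """Thread-level vector accesses (summed over all threads) to read a k x v state once."""
    return (k * v) // vec


def precompute_passes(t: int, num_warps: int) -> int:
    """ceil(T / NUM_WARPS) passes of the SMEM-precompute phase."""
    return math.ceil(t / num_warps)


ilp_accesses = vector_accesses(128, 128, vec=4)    # ILP kernel, 64-bit accesses
wide_accesses = vector_accesses(128, 128, vec=8)   # wide_vec, 128-bit accesses

print(access_width_bits(4), access_width_bits(8))  # 64 128
print(ilp_accesses, wide_accesses)                 # 4096 2048
print(precompute_passes(t=1, num_warps=4))         # 1
```

The same ratio holds for stores, so the count of load/store instructions issued per state update is halved regardless of how the accesses are distributed across warps.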

Changes:

1. gdn_decode_bf16_state.py::gated_delta_rule: new T=1 dispatch branch
   placed BEFORE the small-batch (`B < ILP_BATCH_THRESHOLD`) MTP redirect
   so B=8 HV=64 (work_units = 512) can also reach wide_vec at T=1 and get
   the LDG.E.128 win (measured 1.05x at B=8). Gated by ALL of:
   - K = V = 128 (wide_vec subgroup layout assumes this).
   - Pool mode: `initial_state_indices is not None`. Non-pool direct-state
     callers stay on the baseline ILP kernel, so the split-pool test path
     (baseline) remains bit-exact with the non-pool reference path that
     `test_output_state_indices` relies on.
   - Single-pool write: `output_state_indices is None` or identical to
     `initial_state_indices` (wide_vec has one indices tensor for R/W).
   - tile_v >= 64. At T=1 the wide_vec Phase 0 SMEM-precompute overhead is
     fixed per CTA while the main loop shrinks with tile_v. tile_v=32
     gives only 1 ILP iter per subgroup, insufficient to amortize
     Phase 0. Probe at HV=64 T=1 pool mode: tile_v=32 regresses at B=4
     (0.91x); tile_v=64 wins at B=8 (1.05x); tile_v=128 wins at B>=16
     (1.04-1.06x). So tile_v=32 is masked out at T=1 only.
2. tests/gdn/test_decode_delta_rule.py: `test_gdn_decode_bf16_state_t1_kernel`
   parametrization extended to also cover num_v_heads=64 (production
   Qwen3.5 shape). Total T=1 CI configs: 10 batch x 2 HV = 20.
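
The gate in item 1 can be summarized as a single predicate. This is a minimal sketch, not the code in `gated_delta_rule`: the function name, its signature, and the two constants are hypothetical, `tile_v` is passed in rather than derived, and "identical" indices are modeled as Python object identity, which is an assumption about how the real check is written.

```python
from typing import Optional

WIDE_VEC_HEAD_DIM = 128  # wide_vec subgroup layout assumes K = V = 128
MIN_T1_TILE_V = 64       # tile_v=32 is masked out at T=1


def use_wide_vec_at_t1(
    t: int,
    k: int,
    v: int,
    tile_v: Optional[int],
    initial_state_indices: Optional[object],
    output_state_indices: Optional[object],
) -> bool:
    """Return True only if every T=1 wide_vec gate condition holds."""
    if t != 1:
        return False  # this branch only covers T=1 decode
    shape_ok = k == WIDE_VEC_HEAD_DIM and v == WIDE_VEC_HEAD_DIM
    pool_mode = initial_state_indices is not None
    # Assumption: "identical" means the same tensor object is passed twice.
    single_pool_write = (
        output_state_indices is None
        or output_state_indices is initial_state_indices
    )
    tile_ok = tile_v is not None and tile_v >= MIN_T1_TILE_V
    return shape_ok and pool_mode and single_pool_write and tile_ok


idx = [0, 1, 2, 3]  # stand-in for an indices tensor
print(use_wide_vec_at_t1(1, 128, 128, 64, idx, None))    # True
print(use_wide_vec_at_t1(1, 128, 128, 32, idx, None))    # False: tile_v too small
print(use_wide_vec_at_t1(1, 128, 128, 128, None, None))  # False: not pool mode
```

Any caller that fails the predicate falls through to the existing small-batch MTP redirect and baseline ILP kernels.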
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, state-update ON, no cache,
pool mode with `initial_state_indices = arange(B)`, CUPTI 5 warmup + 50
bench iters). Baseline column monkey-patches `_select_wide_vec_tile_v`
to None so the ILP/MTP kernel path runs:

    B   baseline   wide_vec T=1   speedup   notes
    1     3.49us         3.52us     0.99x   (MTP ILP=4; wide_vec gate n/a)
    2     4.29us         4.26us     1.01x   (MTP ILP=4; wide_vec gate n/a)
    4     5.76us         5.76us     1.00x   (MTP ILP=8; tile_v=32 masked)
    8     9.30us         8.83us     1.05x   (NEW: wide_vec tile_v=64)
   16    15.01us        14.35us     1.05x   (wide_vec tile_v=128)
   32    26.53us        25.30us     1.05x
   64    48.45us        46.53us     1.04x
  128    89.70us        86.00us     1.04x
  256   172.22us       165.02us     1.04x
  512   337.15us       323.10us     1.04x

Peak DRAM SOL at T=1 climbs from 80.6 % (previous ceiling) to 83.9 % at
B=512. All B >= 8 in pool mode gain 1.04-1.05x. Gains are smaller than
the T>=2 win (~1.10-1.50x) because the T=1 ILP kernel was already closer
to the DRAM roofline — LSU relief has less headroom to translate into
speedup.
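
For reference, the speedup column is the baseline latency divided by the wide_vec latency, rounded to two decimals. The short sketch below recomputes it from the timings listed in the table; the variable names are illustrative only.

```python
# (B, baseline_us, wide_vec_us) copied from the perf table above.
timings = [
    (1, 3.49, 3.52),
    (2, 4.29, 4.26),
    (4, 5.76, 5.76),
    (8, 9.30, 8.83),
    (16, 15.01, 14.35),
    (32, 26.53, 25.30),
    (64, 48.45, 46.53),
    (128, 89.70, 86.00),
    (256, 172.22, 165.02),
    (512, 337.15, 323.10),
]

# Speedup > 1.0 means the wide_vec path is faster than the baseline.
speedups = {b: round(base / wide, 2) for b, base, wide in timings}

for b, s in speedups.items():
    print(f"B={b:>3}  speedup={s:.2f}x")
```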
Scope notes:
- Non-pool (direct-state) T=1 callers: stay on baseline ILP. Routing
them through wide_vec introduces 1-BF16-ULP output differences vs the
baseline-served split-pool path, which fails test_output_state_indices
at atol=1e-3. Since the test's semantic invariant is that split-pool
and non-pool paths produce bit-identical output, both must use the
same kernel; baseline ILP is the existing common denominator.
- Split-pool writes (distinct output_state_indices tensor): stay on
baseline ILP. wide_vec does not yet plumb h0_out_indices.
- B*HV < 512: falls through to the (relocated) small-batch MTP redirect
and baseline ILP=4/8 kernels. Wide_vec at tile_v=32 regresses here
because Phase 0 overhead isn't amortized.
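
To see why a 1-BF16-ULP difference can exceed atol=1e-3, it helps to compute the BF16 spacing at a few magnitudes. This is a rough sketch assuming normal-range values, 7 explicit mantissa bits in BF16, and a purely absolute tolerance (the test's relative tolerance is not stated here); `bf16_ulp` is a hypothetical helper, not part of the test suite.

```python
import math

BF16_EXPLICIT_MANTISSA_BITS = 7  # BF16: 1 sign, 8 exponent, 7 mantissa bits
ATOL = 1e-3                      # tolerance quoted for test_output_state_indices


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent BF16 values near x (normal range only)."""
    if x == 0.0:
        raise ValueError("ULP of zero is not covered by this sketch")
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return math.ldexp(1.0, (exp - 1) - BF16_EXPLICIT_MANTISSA_BITS)


for value in (0.2, 0.25, 1.0):
    ulp = bf16_ulp(value)
    print(f"|x|={value}: ulp={ulp}, exceeds atol: {ulp > ATOL}")
# |x|=0.2: ulp=0.0009765625, exceeds atol: False
# |x|=0.25: ulp=0.001953125, exceeds atol: True
# |x|=1.0: ulp=0.0078125, exceeds atol: True
```

Under these assumptions, any output element with magnitude of at least 0.25 that differs by one ULP between the two kernels already breaks the absolute tolerance, which is consistent with keeping both paths on the same kernel.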
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>

1 parent 5975885 commit dc87e93
2 files changed
Lines changed: 58 additions & 4 deletions