webgpu: Normalize indirect-dispatch buffer in flash attention#28802

Open
qjia7 wants to merge 1 commit into main from webgpu-shared-write-indirect-dispatch

qjia7 commented Jun 5, 2026

Summary

  • Add shared WGSL helper write_indirect_dispatch(total) in flash_attention.cc that mirrors ProgramManager::NormalizeDispatchGroupSize's 2D (sqrt) tier on the GPU.
  • Emit the helper from CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram via AdditionalImplementation(); both now write a normalized 1D-or-2D dispatch triple instead of an inline 3D (tiles, num_heads, 1) triple.
  • Align the indirect-dispatch total structurally with the direct-dispatch formula at the FlashAttentionDecodeQKT call site: batch_size * num_heads * num_total_seq_length_tile. batch_size is read from existing in-scope shape uniforms (copy_kv_shape[0] / query_shape[0]) — no new uniform introduced.

Motivation

CopyKVCacheProgram::GenerateShaderCode carried the comment // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size. because the indirect-dispatch path (where the dispatch size is computed on the GPU) bypassed the host's NormalizeDispatchGroupSize 1D-to-2D split. For a sufficiently large workload, a dimension of the GPU-written dispatch triple would exceed maxComputeWorkgroupsPerDimension (65535, the default limit in the WebGPU spec) and the indirect dispatch would be invalid.

Replicating the host normalizer on the GPU resolves the TODO without changing any consumer shaders: shader_helper already flattens workgroup_id (x, y, z) into a single linear workgroup_idx, so the 1D-vs-2D split is transparent to FlashAttentionDecodeQKT / FlashAttentionDecodeSplitVxScore.

Additionally, the previous inline writes used (num_total_seq_length_tile, num_heads, 1) — missing the batch_size factor that the direct-dispatch path uses (SetDispatchGroupSize(batch_size * num_heads * num_total_seq_length_tile) at line 307). This commit corrects that alignment so the two paths produce identical workgroup counts.
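To make the mismatch concrete, here is a minimal sketch comparing the workgroup count implied by the previous inline triple with the direct-dispatch formula. DecodeShape and both function names are hypothetical; only the two formulas come from the description above.

```cpp
#include <cstdint>

// Hypothetical container for the three factors named in the PR description.
struct DecodeShape {
  uint32_t batch_size;
  uint32_t num_heads;
  uint32_t num_total_seq_length_tile;
};

// Workgroups launched by the previous inline triple
// (num_total_seq_length_tile, num_heads, 1): the product omits batch_size.
uint64_t OldIndirectWorkgroups(const DecodeShape& s) {
  return static_cast<uint64_t>(s.num_total_seq_length_tile) * s.num_heads * 1u;
}

// Workgroups requested by the direct-dispatch path:
// batch_size * num_heads * num_total_seq_length_tile.
uint64_t DirectWorkgroups(const DecodeShape& s) {
  return static_cast<uint64_t>(s.batch_size) * s.num_heads *
         s.num_total_seq_length_tile;
}
```

With batch_size = 1 the two counts coincide, which is presumably why the difference did not show up in single-batch runs; for any larger batch the old triple launches fewer workgroups than the direct path.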

Scope note: the new WGSL helper covers the 2D tier only. The host helper additionally falls back to a 3D (cbrt) layout for total > ~65535^2. The 2D-only mirror is safe up to ~4.29B, far beyond any realistic attention workload.
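The 2D tier and its bound can be modeled on the CPU as follows. This is a sketch of the rule as described in this PR, (total, 1, 1) at or below the limit and (ceil(sqrt(total)), ceil(sqrt(total)), 1) above it; it is not the actual WGSL helper or the host implementation. NormalizeDispatch2D is a hypothetical name, and the integer correction around std::sqrt is an illustrative choice.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical CPU model of the 2D (sqrt) tier. `limit` stands in for
// maxComputeWorkgroupsPerDimension: 65535 in production, or 1 to force
// the split as done temporarily in the test plan.
std::array<uint32_t, 3> NormalizeDispatch2D(uint32_t total,
                                            uint32_t limit = 65535u) {
  assert(total > 0);
  if (total <= limit) {
    return {total, 1u, 1u};  // fits in one dimension
  }
  // Integer-exact ceil(sqrt(total)): take the floating-point estimate, then
  // adjust so that (side - 1)^2 < total <= side^2.
  uint64_t side = static_cast<uint64_t>(std::sqrt(static_cast<double>(total)));
  while (side * side < total) ++side;
  while (side > 1 && (side - 1) * (side - 1) >= total) --side;
  return {static_cast<uint32_t>(side), static_cast<uint32_t>(side), 1u};
}
```

The side length stays within 65535 exactly when total <= 65535^2 = 4,294,836,225. Above that the square side becomes 65536, which is where the host's 3D tier would take over and this 2D-only model stops being valid.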

Test plan

  • Build: incremental cmake --build ... --target onnxruntime --config Release completes cleanly.
  • DLL binary inspection: grep -aoE 'uniforms\.copy_kv_shape_shape\[0\]|uniforms\.query_shape\[0\] \* uniforms\.num_heads|let limit = [0-9]+u' confirms the new shader strings are compiled into the deployed DLL with limit = 65535u.
  • Single-generator correctness on phi4-graph-prune: 3/3 queries match reference (math, factual, math).
  • Multi-generator correctness on phi4-graph-prune: 3 sequential prompts produce identical output across runs, 2 overlapping generators produce coherent independent output.
  • Forced 2D-split exercise: temporarily ran with let limit = 1u (verified the binary was actually rebuilt with that value); both verification scripts still PASS, confirming the 2D-split path is correct.
  • lintrunner -a: clean, no auto-applied changes.

The host-side dispatch path uses ProgramManager::NormalizeDispatchGroupSize
to split a 1D group count into a 2D layout when it would exceed the WebGPU
per-dimension limit (maxComputeWorkgroupsPerDimension, 65535 per the spec).
The indirect-dispatch path, where the dispatch size is computed on the GPU,
skipped this normalization, leaving a "TODO: Add NormalizeDispatchGroupSize
logic here to avoid exceeding max dispatch size." in CopyKVCacheProgram.

Add a shared WGSL helper write_indirect_dispatch(total) that mirrors the
host normalizer's 2D (sqrt) tier and emit it via AdditionalImplementation()
from both CopyKVCacheProgram and
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram. The helper writes
(total, 1, 1) when total <= 65535, otherwise
(ceil(sqrt(total)), ceil(sqrt(total)), 1). Consumer shaders are unchanged:
shader_helper always flattens workgroup_id (x, y, z) into a linear
workgroup_idx, so the 1D-vs-2D split is transparent to consumers.

The indirect-dispatch total is now structurally aligned with the direct-
dispatch formula at the FlashAttentionDecodeQKT call site:
batch_size * num_heads * num_total_seq_length_tile. batch_size is read
from copy_kv_shape[0] / query_shape[0] without introducing a new uniform.

Verified on phi4-graph-prune: verify_model_correctness.py (3/3 queries
match reference) and verify_multi_gen.py (sequential and overlapping
multi-generator runs all coherent), both under the production threshold
(65535) and a forced 2D-split threshold (1) used for testing.