webgpu: Normalize indirect-dispatch buffer in flash attention #28802
Open
qjia7 wants to merge 1 commit into
The host-side dispatch path uses ProgramManager::NormalizeDispatchGroupSize to split a 1D group count into a 2D layout when it would exceed the WebGPU per-dimension limit (maxComputeWorkgroupsPerDimension, default 65535 per the spec). The indirect-dispatch path, where the dispatch size is computed on the GPU, skipped this normalization, leaving a "TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size." in CopyKVCacheProgram.

Add a shared WGSL helper write_indirect_dispatch(total) that mirrors the host normalizer's 2D (sqrt) tier and emit it via AdditionalImplementation() from both CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram. The helper writes (total, 1, 1) when total <= 65535, otherwise (ceil(sqrt(total)), ceil(sqrt(total)), 1).

Consumer shaders are unchanged: shader_helper always flattens workgroup_id (x, y, z) into a linear workgroup_idx, so the 1D-vs-2D split is transparent to consumers.

The indirect-dispatch total is now structurally aligned with the direct-dispatch formula at the FlashAttentionDecodeQKT call site: batch_size * num_heads * num_total_seq_length_tile. batch_size is read from copy_kv_shape[0] / query_shape[0] without introducing a new uniform.

Verified on phi4-graph-prune with verify_model_correctness.py (3/3 queries match reference) and verify_multi_gen.py (sequential and overlapping multi-generator runs all coherent). Both scripts pass under the production threshold (65535) and under a forced 2D-split threshold (1) used for testing.
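The normalization rule described above can be modeled on the CPU. The following is a minimal C++ sketch, not the WGSL emitted by this PR: the names NormalizeIndirectDispatch, CeilSqrt, and kLimit are hypothetical, and the integer square-root correction is an assumption made here to avoid floating-point rounding at perfect-square boundaries. The limit parameter stands in for the threshold that the tests force to 1.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical CPU model of the WGSL helper; names are illustrative only.
constexpr uint32_t kLimit = 65535;  // per-dimension workgroup limit cited in the PR

// Smallest r with r * r >= n, corrected in integers after a floating-point guess.
inline uint32_t CeilSqrt(uint32_t n) {
  uint32_t r = static_cast<uint32_t>(std::sqrt(static_cast<double>(n)));
  while (static_cast<uint64_t>(r) * r < n) ++r;
  while (r > 0 && static_cast<uint64_t>(r - 1) * (r - 1) >= n) --r;
  return r;
}

// Returns the (x, y, z) triple that would be written to the indirect buffer:
// (total, 1, 1) when total fits in one dimension, else a square 2D grid.
inline std::array<uint32_t, 3> NormalizeIndirectDispatch(uint32_t total,
                                                         uint32_t limit = kLimit) {
  if (total <= limit) {
    return {total, 1u, 1u};
  }
  const uint32_t side = CeilSqrt(total);
  return {side, side, 1u};
}
```

Under this model, NormalizeIndirectDispatch(65536) yields (256, 256, 1), and with the limit forced to 1 a total of 2 yields (2, 2, 1). Totals above 65535^2 produce a side of 65536, which is over the limit; that is the range the 2D tier alone does not cover.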
Summary
- Add a shared WGSL helper write_indirect_dispatch(total) in flash_attention.cc that mirrors ProgramManager::NormalizeDispatchGroupSize's 2D (sqrt) tier on the GPU.
- Emit it from CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram via AdditionalImplementation(); both now write a normalized 1D-or-2D dispatch triple instead of an inline 3D (tiles, num_heads, 1) triple.
- Align total structurally with the direct-dispatch formula at the FlashAttentionDecodeQKT call site: batch_size * num_heads * num_total_seq_length_tile. batch_size is read from existing in-scope shape uniforms (copy_kv_shape[0] / query_shape[0]); no new uniform is introduced.

Motivation
CopyKVCacheProgram::GenerateShaderCode carried "// TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size." because the indirect-dispatch path (where the dispatch size is computed on the GPU) bypassed the host's NormalizeDispatchGroupSize 1D-to-2D split. For sufficiently large total, the indirect dispatch would exceed maxComputeWorkgroupsPerDimension (65535 per the WebGPU spec) and the workgroup-grid call would be invalid.

Replicating the host normalizer on the GPU resolves the TODO without changing any consumer shaders: shader_helper already flattens workgroup_id (x, y, z) into a single linear workgroup_idx, so the 1D-vs-2D split is transparent to FlashAttentionDecodeQKT / FlashAttentionDecodeSplitVxScore.

Additionally, the previous inline writes used (num_total_seq_length_tile, num_heads, 1), missing the batch_size factor that the direct-dispatch path uses (SetDispatchGroupSize(batch_size * num_heads * num_total_seq_length_tile) at line 307). This commit corrects that alignment so the two paths produce identical workgroup counts.

Scope note: the new WGSL helper covers the 2D tier only. The host helper additionally falls back to a 3D (cbrt) layout for total > ~65535^2. The 2D-only mirror is safe up to ~4.29B, far beyond any realistic attention workload.

Test plan
- cmake --build ... --target onnxruntime --config Release: clean.
- grep -aoE 'uniforms\.copy_kv_shape_shape\[0\]|uniforms\.query_shape\[0\] \* uniforms\.num_heads|let limit = [0-9]+u' confirms the new shader strings are compiled into the deployed DLL with limit = 65535u.
- verify_model_correctness.py on phi4-graph-prune: 3/3 queries match reference (math, factual, math).
- verify_multi_gen.py on phi4-graph-prune: 3 sequential prompts produce identical output across runs, 2 overlapping generators produce coherent independent output.
- Forced 2D split with let limit = 1u (verified the binary was actually rebuilt with that value): both verification scripts still PASS, confirming the 2D-split path is correct.
- lintrunner -a: clean, no auto-applied changes.
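The forced 2D-split run depends on a property that can also be checked off-GPU: a (side, side, 1) grid, once flattened to a linear index, must reach every index below total exactly once. The C++ sketch below illustrates that check under stated assumptions. The flattening formula y * num_x + x is assumed for illustration and may differ from what shader_helper generates, the names FlattenWorkgroupId and GridCoversTotal are hypothetical, and how the real consumer shaders treat surplus workgroups is not stated in this PR (the sketch simply skips them).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed flattening for a (num_x, num_y, 1) grid; illustrative, not taken from shader_helper.
inline uint32_t FlattenWorkgroupId(uint32_t x, uint32_t y, uint32_t num_x) {
  return y * num_x + x;
}

// Returns true if a (side, side, 1) grid visits every linear index in
// [0, total) exactly once. Intended for small grids; side * side must fit in 32 bits.
inline bool GridCoversTotal(uint32_t side, uint32_t total) {
  std::vector<uint8_t> seen(total, 0);
  for (uint32_t y = 0; y < side; ++y) {
    for (uint32_t x = 0; x < side; ++x) {
      const uint32_t idx = FlattenWorkgroupId(x, y, side);
      if (idx >= total) continue;      // surplus workgroup from rounding the side up
      if (seen[idx]++) return false;   // an index was visited twice
    }
  }
  for (uint8_t s : seen) {
    if (s != 1) return false;          // an index was never visited
  }
  return true;
}
```

With total = 65537, a side of 257 (the ceiling of the square root) covers every index, while a side of 256 leaves the last index unvisited. This shows why the helper rounds the side up and why a square grid can contain surplus workgroups.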