WebGPU: Add indirect dispatch for flash attention graph capture #29236

Draft
feich-ms wants to merge 6 commits into main from user/feich/gemma4_webgpu_gc_support

@feich-ms feich-ms commented Jun 23, 2026

Contributor

Summary

  • Adds `PrepareIndirectDispatchProgram` to compute dispatch group sizes on the GPU from `seqlen_k` when `total_sequence_length` is unavailable on the CPU
  • Extends indirect dispatch to kv_empty layers (shared KV layers where `total_sequence_length_=0` would cause `dispatch(0)` crashes)
  • Enables graph capture for models with dynamic sequence lengths (e.g., Gemma4), achieving ~25% decode throughput improvement on WebGPU
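The sizing arithmetic described above can be sketched on the CPU as follows. This is a minimal illustration, not the PR's shader: `DispatchSize`, `CeilDiv`, `NormalizeDispatch`, and `DispatchFromSeqlenK` are hypothetical names, the tile size and head count are placeholder values, and the sketch assumes `seqlen_k` holds the total sequence length minus one and that the workgroup count is `num_heads * tiles`. It shows why a size derived from `seqlen_k` is never zero, and how an oversized 1D count could be folded into two dimensions.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical CPU-side mirror of the arithmetic a GPU "prepare" pass could run.
struct DispatchSize {
  uint32_t x;
  uint32_t y;
  uint32_t z;
};

// Number of tiles needed to cover `length` elements (overflow-safe ceiling division).
constexpr uint32_t CeilDiv(uint32_t length, uint32_t tile_size) {
  return length / tile_size + (length % tile_size != 0 ? 1u : 0u);
}

// Folds a 1D workgroup count into 2D when it exceeds the per-dimension limit.
// Assumes workgroups <= limit * limit, so that y also stays within the limit.
inline DispatchSize NormalizeDispatch(uint32_t workgroups, uint32_t limit) {
  if (workgroups <= limit) {
    return {workgroups, 1u, 1u};
  }
  const uint32_t y = CeilDiv(workgroups, limit);
  const uint32_t x = CeilDiv(workgroups, y);  // x <= limit and x * y >= workgroups
  return {x, y, 1u};
}

// Assumption: seqlen_k == total_sequence_length - 1, so the total is at least 1
// and the resulting tile count can never be zero.
inline DispatchSize DispatchFromSeqlenK(int32_t seqlen_k, uint32_t tile_size,
                                        uint32_t num_heads, uint32_t limit) {
  const uint32_t total = static_cast<uint32_t>(std::max(seqlen_k, int32_t{0})) + 1u;
  const uint32_t tiles = CeilDiv(total, tile_size);
  return NormalizeDispatch(num_heads * tiles, limit);
}
```

With an illustrative tile size of 64, four heads, and a limit of 65535 (the WebGPU default for `maxComputeWorkgroupsPerDimension`), `DispatchFromSeqlenK(0, 64, 4, 65535)` yields a dispatch of 4 x 1 x 1 rather than 0.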

Changes

  • `flash_attention.cc`:
    • New `PrepareIndirectDispatchProgram::GenerateShaderCode` for GPU-side dispatch sizing on kv_empty layers
    • Extracted `AppendNormalizeDispatchShader()` helper shared by `CopyKVCacheProgram` and `PrepareIndirectDispatchProgram` to eliminate duplicated tile-count + `normalize_dispatch_group_size` logic
    • `use_indirect_dispatch` expressed as `use_seqlen_k && (share_buffer || kv_empty)` to make the subset relationship explicit
    • Tightened `use_seqlen_k` / `use_indirect_dispatch` guards from `== 0` to `<= 0` for defensive handling of negative `total_sequence_length_`
  • `flash_attention.h`: Added `PrepareIndirectDispatchProgram` class declaration
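The gating described in the list above can be read as two predicates. The sketch below is an assumption about their shape, not the code in `flash_attention.cc`: `DispatchGating`, `ComputeGating`, and the `has_seqlen_k` parameter are hypothetical, and only the `use_seqlen_k && (share_buffer || kv_empty)` expression and the `<= 0` guard are taken from the description.

```cpp
// Hypothetical container for the two flags named in the PR description.
struct DispatchGating {
  bool use_seqlen_k;
  bool use_indirect_dispatch;
};

// Assumption: seqlen_k is consulted only when the CPU-side total is unusable.
// `has_seqlen_k` is an illustrative parameter, not a name from the PR.
inline DispatchGating ComputeGating(bool has_seqlen_k, int total_sequence_length,
                                    bool share_buffer, bool kv_empty) {
  // "<= 0" instead of "== 0" also treats a negative total as unavailable.
  const bool use_seqlen_k = has_seqlen_k && total_sequence_length <= 0;
  // Indirect dispatch is a subset of the use_seqlen_k cases.
  const bool use_indirect_dispatch = use_seqlen_k && (share_buffer || kv_empty);
  return {use_seqlen_k, use_indirect_dispatch};
}
```

Written this way, `use_indirect_dispatch` can never be true while `use_seqlen_k` is false, which is the subset relationship the description refers to.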

Test plan

  • Verified with Gemma4 INT4 model: ~95-105 tok/s with GC=ON vs ~75 tok/s with GC=OFF
  • Tested multiple prompts (short/long) with graph capture enabled — no crashes, correct output
  • Verified warmup (capture run) followed by replay produces consistent results

@feich-ms feich-ms requested a review from Copilot June 26, 2026 03:19
@feich-ms feich-ms added the ep:WebGPU ort-web webgpu provider label Jun 26, 2026
@feich-ms feich-ms marked this pull request as ready for review June 26, 2026 03:23

Copilot AI left a comment

Contributor

Pull request overview

This PR improves the WebGPU FlashAttention decode path's compatibility with graph capture and dynamic sequence lengths by enabling GPU-side computation of indirect dispatch group sizes from `seqlen_k` (including the kv_empty/shared-KV case that previously could produce `dispatch(0)`).

Changes:

  • Added a dedicated `PrepareIndirectDispatchProgram` shader to compute normalized indirect dispatch dimensions on GPU from `seqlen_k`.
  • Expanded `use_seqlen_k` / `use_indirect_dispatch` gating so indirect dispatch is also available when `total_sequence_length_` is unavailable on CPU (including kv_empty layers).
  • Ensured the kv_empty path prepares the indirect dispatch buffer even when `CopyKVCache` is skipped.
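The last point can be pictured as a small scheduling decision. The sketch below is an assumed reading of the control flow, not the PR's implementation: `PlanDecodePrograms` is hypothetical, `"FlashAttentionDecode"` is a placeholder label for the attention pass, and it assumes that `CopyKVCacheProgram` fills the indirect buffer itself whenever it runs.

```cpp
#include <string>
#include <vector>

// Returns the programs that would be scheduled for one decode step,
// in order, under the assumptions stated above.
inline std::vector<std::string> PlanDecodePrograms(bool kv_empty,
                                                   bool use_indirect_dispatch) {
  std::vector<std::string> programs;
  if (!kv_empty) {
    // Assumed to also write the indirect dispatch buffer when it is in use.
    programs.push_back("CopyKVCacheProgram");
  } else if (use_indirect_dispatch) {
    // The KV copy is skipped, so a dedicated pass must fill the buffer.
    programs.push_back("PrepareIndirectDispatchProgram");
  }
  programs.push_back("FlashAttentionDecode");  // placeholder name
  return programs;
}
```

Under this reading, the only case that needs the new program is a kv_empty layer with indirect dispatch enabled, since no other pass would populate the buffer before the attention dispatch reads it.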

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| `onnxruntime/contrib_ops/webgpu/bert/flash_attention.h` | Declares `PrepareIndirectDispatchProgram` and its uniform interface. |
| `onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc` | Implements the new program and wires indirect-dispatch preparation into the kv_empty path and broadened gating logic. |

@feich-ms feich-ms requested a review from qjia7 June 26, 2026 08:02
feich-ms and others added 3 commits June 26, 2026 16:23
- Extract AppendNormalizeDispatchShader() helper so CopyKVCacheProgram
  and PrepareIndirectDispatchProgram share one copy of the tile-count
  and normalize_dispatch_group_size call instead of duplicating it.
- Express use_indirect_dispatch as use_seqlen_k && (share_buffer || kv_empty)
  to make the subset relationship explicit and eliminate the repeated condition.
- Tighten use_seqlen_k / use_indirect_dispatch guards from == 0 to <= 0
  to handle a negative total_sequence_length_ defensively.
- Add comment on the WGSL template's normalize call pointing back to the
  C++ helper so the two stay in sync.

Co-Authored-By: Claude <noreply@anthropic.com>
Add two WebGPU GQA tests that exercise PrepareIndirectDispatchProgram:
- WebGPU_SharedKV_IndirectDispatch_Decode: kv_empty + total_sequence_length=0
  (decode, past_seq=8), triggers use_seqlen_k=true and use_indirect_dispatch=true
  via the kv_empty path, cross-checked against CPU reference.
- WebGPU_SharedKV_IndirectDispatch_LargerPast: same path with past_seq=32 to
  exercise num_total_seq_length_tile > 1 in the tile count arithmetic.

Co-Authored-By: Claude <noreply@anthropic.com>
@feich-ms feich-ms force-pushed the user/feich/gemma4_webgpu_gc_support branch from 6ae8f1f to 63159eb June 26, 2026 08:31

@github-actions github-actions Bot left a comment

Contributor

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
feich-ms and others added 3 commits June 26, 2026 16:55
Co-Authored-By: Claude <noreply@anthropic.com>
Replace deprecated bool use_cuda/use_webgpu params with GqaTargetEp::kCpu.

Co-Authored-By: Claude <noreply@anthropic.com>
OpTester cannot enable graph capture, so use_indirect_dispatch is never
triggered. Rewrite the tests to exercise the kv_empty path directly with
a real positive total_sequence_length instead of 0.

Co-Authored-By: Claude <noreply@anthropic.com>
@feich-ms feich-ms marked this pull request as draft June 26, 2026 11:00
Labels

ep:WebGPU ort-web webgpu provider

2 participants