
Commit f7fd3b5

[webgpu] Register GQA based on graph capture (microsoft#26384)
This pull request registers GQA conditionally, depending on whether total_sequence_length resides on the GPU. It resolves the issue that a MemcpyToHost node is generated when graph capture is enabled (refer to microsoft#25868), and it is the last functional piece needed to support graph capture in the WebGPU EP in ORT. The main changes ensure that when graph capture is enabled, sequence-length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models. In this PR, we still obtain the total sequence length from the `seqlen_k` tensor rather than the `total_seqlen_tensor` tensor, to stay consistent with other parts of the code. In a follow-up PR, we can refactor all call sites to use `total_seqlen_tensor` directly instead of `seqlen_k` when graph capture is enabled.
1 parent 3a6a4c2 commit f7fd3b5
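
For illustration, a minimal sketch of how an execution provider's kernel registration could consume the factory this commit introduces (see group_query_attention.cc below). The RegisterGqaKernel wrapper and the way enable_graph_capture reaches it are assumptions made for the sketch; only CreateGroupQueryAttentionKernelInfo and its behavior come from this commit.

// Hypothetical registration-site sketch, not part of this commit's diff.
// When graph capture is enabled, the kernel def no longer pins input 6
// (total_sequence_length) to CPU memory, so ORT does not insert a MemcpyToHost
// node for it; the shaders read the length from seqlen_k instead.
Status RegisterGqaKernel(onnxruntime::KernelRegistry& registry, bool enable_graph_capture) {
  using onnxruntime::contrib::webgpu::CreateGroupQueryAttentionKernelInfo;
  onnxruntime::KernelCreateInfo info = CreateGroupQueryAttentionKernelInfo(enable_graph_capture);
  return registry.Register(std::move(info));
}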

9 files changed: +104 -57 lines

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h

Lines changed: 14 additions & 14 deletions
@@ -251,12 +251,14 @@ Status CheckInputs(const T* query,
                        "seqlens_k must be shape (batch_size).");
   }
 
-  // Set present sequence length from input total_seqlen tensor
   if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "total_sequence_length tensor must be of one element.");
   }
-  int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
+
+  // When graph capture is enabled, total_seqlen is on GPU and cannot be read. Skip validation.
+  const bool is_total_seqlen_on_cpu = (total_seqlen->Location().device.Type() == OrtDevice::CPU);
+  int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
   int rotary_dim = 0;
@@ -267,22 +269,20 @@ Status CheckInputs(const T* query,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
   }
 
+  // Skip prompt type detection when total_seqlen is on GPU (graph capture mode)
   bool is_subsequent_prompt = false;
-  if (sequence_length > 1 && sequence_length != total_sequence_length) {
-    if (batch_size != 1) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "batch_size must be 1 when sequence_length > 1 and past context is given.");
+  bool is_first_prompt = false;
+  if (is_total_seqlen_on_cpu) {
+    if (sequence_length > 1 && sequence_length != total_sequence_length) {
+      if (batch_size != 1) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "batch_size must be 1 when sequence_length > 1 and past context is given.");
+      }
+      is_subsequent_prompt = true;
     }
-    is_subsequent_prompt = true;
-  }
 
-  bool is_first_prompt;
-  if (is_subsequent_prompt) {
-    is_first_prompt = false;  // irrelevant for interactive decoding
-  } else {
-    // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
     is_first_prompt = (sequence_length == total_sequence_length);
-    if (!is_first_prompt && sequence_length != 1) {
+    if (!is_subsequent_prompt && !is_first_prompt && sequence_length != 1) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "sequence_length shall be 1 when it is not prompt.");
     }

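A condensed view of the guard pattern the hunk above introduces, written as a standalone helper. The helper name is illustrative and not part of the commit; the real change inlines the check in CheckInputs.

// Illustrative helper mirroring the CheckInputs change above. When graph capture is
// enabled the total_seqlen tensor is GPU-resident, so its contents cannot be read on
// the host; return 0 so the caller can skip host-side validation that needs the value.
inline int32_t ReadTotalSeqLenOrZero(const onnxruntime::Tensor& total_seqlen) {
  const bool on_cpu = (total_seqlen.Location().device.Type() == OrtDevice::CPU);
  return on_cpu ? *total_seqlen.Data<int32_t>() : 0;
}
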
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 22 additions & 10 deletions
@@ -31,9 +31,11 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
   const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
   const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
+  if (use_seqlen_k_) {
+    shader.AddInput("seqlen_k", ShaderUsage::None);
+  }
   // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
   if (prepare_indirect_dispatch_) {
-    shader.AddInput("seqlen_k", ShaderUsage::None);
     shader.AddOutput("indirect_buffer", ShaderUsage::None);
   }
 
@@ -43,7 +45,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                              " let sequence_id = output_indices[2];\n"
                              " let num_head_id = output_indices[1];\n"
                              " let batch = output_indices[0];\n";
-  if (prepare_indirect_dispatch_) {
+  if (use_seqlen_k_) {
     shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
   } else {
     shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
@@ -105,9 +107,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
 
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
+  bool use_seqlen_k = (seqlen_k != nullptr);
 
   CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
-                             prepare_indirect_dispatch};
+                             prepare_indirect_dispatch, use_seqlen_k};
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -119,7 +122,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
                        {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}});
   }
 
-  if (prepare_indirect_dispatch) {
+  if (use_seqlen_k) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
   }
 
@@ -137,7 +140,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   program.AddIndices(std::move(copy_kv_shape));
   program.SetDispatchGroupSize(static_cast<uint32_t>((copy_size + 63) / 64))
       .SetWorkgroupSize(64)
-      .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
+      .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch, use_seqlen_k)
      .AddUniformVariables({{static_cast<uint32_t>(copy_size)},
                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                            {static_cast<uint32_t>(parameters.kv_sequence_length_)},
@@ -167,6 +170,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
+  if (use_seqlen_k_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
@@ -176,7 +182,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                              WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
                              WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
                              WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
-                             WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_));
+                             WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
+                             WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
 }
 
 Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
@@ -349,10 +356,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
   const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
 
+  const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
+
   if (parameters.sequence_length_ > 1) {
     const uint32_t tile_size = 64;
     // For encode path, use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -364,24 +373,27 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                  parameters.head_size_,
                                  parameters.num_heads_,
                                  parameters.is_unidirectional_,
-                                 is_nvidia};
+                                 is_nvidia,
+                                 use_seqlen_k};
     program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                        {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
                        {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
     if (has_attention_bias) {
       program.AddInputs({{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
     }
+    if (use_seqlen_k) {
+      program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::None}});
+    }
     program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
     const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                   : parameters.scale_;
     const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
     program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
         .SetWorkgroupSize(tile_size)
-        .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia)
+        .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
        .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                              {static_cast<uint32_t>(parameters.total_sequence_length_)},
                              {static_cast<uint32_t>(present_sequence_length)},
-                             {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
                              {static_cast<uint32_t>(parameters.n_reps)},
                              {alpha},
                              {num_seq_tile}});

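The key host-side decision above is use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(). Below is a small sketch of the same gating written as a helper; the function name is hypothetical and only the condition comes from the diff.

// Illustrative helper: forward seqlen_k to the shaders only when graph capture is
// enabled; otherwise the shaders keep reading total_sequence_length from uniforms.
inline const onnxruntime::Tensor* SeqLenKForShaders(onnxruntime::webgpu::ComputeContext& context,
                                                    const onnxruntime::Tensor* seqlen_k) {
  return (seqlen_k != nullptr && context.IsGraphCaptureEnabled()) ? seqlen_k : nullptr;
}

Note that use_seqlen_k is also appended to both CacheHint(...) calls above: the flag changes the generated WGSL (extra seqlen_k/seqlens_k binding, GPU-side length read), so the two shader variants must not share a cached pipeline.
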
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 8 additions & 5 deletions
@@ -18,8 +18,8 @@ using namespace onnxruntime::webgpu;
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
   CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
-                     bool prepare_indirect_dispatch = false)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
+                     bool prepare_indirect_dispatch = false, bool use_seqlen_k = false)
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -34,6 +34,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
   bool has_past_;
   bool kv_BNSH_;
   bool prepare_indirect_dispatch_;
+  bool use_seqlen_k_;
 };
 
 class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -45,23 +46,24 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                        int qkv_head_size,
                        int qkv_num_heads,
                        bool is_unidirectional,
-                       bool is_nvidia)
+                       bool is_nvidia,
+                       bool use_seqlen_k = false)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
         is_qualcomm_(is_qualcomm),
         is_fp16_(is_fp16),
         qkv_head_size_(qkv_head_size),
         qkv_num_heads_(qkv_num_heads),
         is_unidirectional_(is_unidirectional),
-        is_nvidia_(is_nvidia) {
+        is_nvidia_(is_nvidia),
+        use_seqlen_k_(use_seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32},
                                           {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
@@ -74,6 +76,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   int qkv_num_heads_;
   bool is_unidirectional_;
   bool is_nvidia_;
+  bool use_seqlen_k_;
 };
 
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {

onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template

Lines changed: 27 additions & 12 deletions
@@ -6,10 +6,23 @@
 #param prefer_subgroupshuffle
 #param qkv_head_size
 #param qkv_num_heads
+#param use_seqlen_k
 
 const head_size : u32 = qkv_head_size;
 const num_heads : u32 = qkv_num_heads;
 
+#if use_seqlen_k
+// When graph capture is enabled, total_sequence_length is read from GPU buffer
+fn get_total_sequence_length() -> u32 {
+  return u32(seqlens_k[0]) + 1u;
+}
+#else
+// When graph capture is disabled, total_sequence_length comes from uniforms
+fn get_total_sequence_length() -> u32 {
+  return uniforms.total_sequence_length;
+}
+#endif
+
 #if is_fp16
 const min_value = q_element_t(-65504.0);
 #else
@@ -45,17 +58,17 @@ fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
   let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec;
   for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
-    let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < uniforms.total_sequence_length);
+    let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
     k_tile[slot][idx % head_size_vec] = val;
   }
 }
 
-fn loadv(v_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
+fn loadv(v_start : u32, head_idx : u32, local_idx : u32, v_step : u32) {
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
   let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec;
-  for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
+  for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
-    let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < uniforms.total_sequence_length);
+    let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
     v_tile[slot][idx % head_size_vec] = val;
   }
 }
@@ -93,12 +106,12 @@ fn writeo(o_idx_global : u32, head_idx : u32) {
 #if has_attention_bias
 fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
   // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
-  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.total_sequence_length) {
+  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= get_total_sequence_length()) {
     return vec4<q_element_t>(0);
   }
-  let offset_base = head_idx * uniforms.new_sequence_length * uniforms.total_sequence_length + q_idx_global * uniforms.total_sequence_length;
+  let offset_base = head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
   let offset = offset_base + k_idx_global;
-  let offset_max = offset_base + uniforms.total_sequence_length;
+  let offset_max = offset_base + get_total_sequence_length();
   let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
   let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
   let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
@@ -141,16 +154,18 @@ $MAIN {
 
   var previous_max : q_element_t = min_value;
   var previous_denom : q_element_t = 0;
+  let total_sequence_length = get_total_sequence_length();
 
 #if is_unidirectional
   // If attention is unidirectional, set the loop bound to enforce causal masking.
-  let max_causal_len_for_workgroup = uniforms.past_sequence_length +
+  let past_sequence_length = total_sequence_length - uniforms.new_sequence_length;
+  let max_causal_len_for_workgroup = past_sequence_length +
                                      (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
-  let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup);
-  let seq_causal_length = uniforms.past_sequence_length + q_idx_global + 1;
+  let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup);
+  let seq_causal_length = past_sequence_length + q_idx_global + 1;
 #else
-  let loop_bound = uniforms.total_sequence_length;
-  let seq_causal_length = uniforms.total_sequence_length;
+  let loop_bound = total_sequence_length;
+  let seq_causal_length = total_sequence_length;
 #endif
 
   for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) {

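With the past_sequence_length uniform removed (see flash_attention.h, and the dropped total_sequence_length_ - kv_sequence_length_ uniform value in flash_attention.cc), the template derives the past length itself. Below is a host-side mirror of that derivation, shown only to make the invariant explicit; the helper is not part of the commit.

// Illustrative: the value previously uploaded as the past_sequence_length uniform can be
// recomputed from the remaining uniforms, which is why the uniform could be dropped.
inline uint32_t DerivePastSequenceLength(uint32_t total_sequence_length,
                                         uint32_t new_sequence_length) {
  return total_sequence_length - new_sequence_length;
}
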
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 23 additions & 12 deletions
@@ -19,18 +19,6 @@ namespace onnxruntime {
 namespace contrib {
 namespace webgpu {
 
-ONNX_OPERATOR_KERNEL_EX(
-    GroupQueryAttention,
-    kMSDomain,
-    1,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T", WebGpuSupportedFloatTypes())
-        .MayInplace(3, 1)
-        .MayInplace(4, 2)
-        .InputMemoryType(OrtMemTypeCPUInput, 6),
-    GroupQueryAttention);
-
 Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
   const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
   const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
@@ -270,6 +258,29 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
                             present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
 }
 
+KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {
+  KernelDefBuilder builder;
+  builder.SetName("GroupQueryAttention")
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .Provider(kWebGpuExecutionProvider)
+      .TypeConstraint("T", WebGpuSupportedFloatTypes())
+      .MayInplace(3, 1)
+      .MayInplace(4, 2);
+
+  // Only set InputMemoryType to CPU when graph capture is disabled
+  if (!enable_graph_capture) {
+    builder.InputMemoryType(OrtMemTypeCPUInput, 6);
+  }
+
+  return KernelCreateInfo(
+      builder.Build(),
+      [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
+        out = std::make_unique<GroupQueryAttention>(info);
+        return Status::OK();
+      });
+}
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ class GroupQueryAttention final : public WebGpuKernel {
   Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
 };
 
+KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture);
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
