Skip to content

Commit 63159eb

Browse files
feich-ms and claude committed
webgpu: add indirect dispatch path tests for flash attention
Add two WebGPU GQA tests that exercise PrepareIndirectDispatchProgram:
- WebGPU_SharedKV_IndirectDispatch_Decode: kv_empty + total_sequence_length=0 (decode, past_seq=8); triggers use_seqlen_k=true and use_indirect_dispatch=true via the kv_empty path, and is cross-checked against the CPU reference.
- WebGPU_SharedKV_IndirectDispatch_LargerPast: same path with past_seq=32 to exercise num_total_seq_length_tile > 1 in the tile count arithmetic.

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 1933251 commit 63159eb

1 file changed

Lines changed: 133 additions & 2 deletions

File tree

onnxruntime/test/contrib_ops/group_query_attention_op_test.cc

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2819,8 +2819,7 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) {
28192819
RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kCuda);
28202820
}
28212821

2822-
TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) {
2823-
auto webgpu_ep = DefaultWebGpuExecutionProvider();
2822+
TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { auto webgpu_ep = DefaultWebGpuExecutionProvider();
28242823
if (!webgpu_ep) {
28252824
GTEST_SKIP() << "WebGPU EP not available";
28262825
}
@@ -2872,5 +2871,137 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_W
28722871
RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true);
28732872
}
28742873

2874+
// ---------------------------------------------------------------------------
2875+
// Indirect dispatch tests (PrepareIndirectDispatchProgram)
2876+
//
2877+
// These tests pass total_sequence_length=0 to trigger the kv_empty indirect
2878+
// dispatch path: use_seqlen_k=true so seqlen_k tensor is used on the GPU, and
2879+
// use_indirect_dispatch=true so PrepareIndirectDispatchProgram sizes the flash-
2880+
// attention dispatch from the GPU-resident seqlen_k value instead of the zero
2881+
// CPU value. Correctness is verified by cross-checking against CPU with the
2882+
// real total.
2883+
// ---------------------------------------------------------------------------
2884+
2885+
// WebGPU: kv_empty + total_sequence_length=0, decode (q_seq=1).
2886+
// Exercises PrepareIndirectDispatchProgram for the kv_empty indirect dispatch path.
2887+
TEST(GroupQueryAttentionTest, WebGPU_SharedKV_IndirectDispatch_Decode) {
2888+
auto webgpu_ep = DefaultWebGpuExecutionProvider();
2889+
if (!webgpu_ep) {
2890+
GTEST_SKIP() << "WebGPU EP not available";
2891+
}
2892+
2893+
constexpr int batch_size = 1;
2894+
constexpr int q_seq_len = 1;
2895+
constexpr int past_seq_len = 8;
2896+
constexpr int num_heads = 2;
2897+
constexpr int kv_num_heads = 1;
2898+
constexpr int head_size = 8;
2899+
constexpr int hidden_size = num_heads * head_size;
2900+
constexpr int kv_hidden_size = kv_num_heads * head_size;
2901+
2902+
std::vector<float> query_data(batch_size * q_seq_len * hidden_size);
2903+
std::vector<float> past_key_data(batch_size * kv_num_heads * past_seq_len * head_size);
2904+
std::vector<float> past_value_data(batch_size * kv_num_heads * past_seq_len * head_size);
2905+
for (size_t i = 0; i < query_data.size(); i++) query_data[i] = 0.1f * static_cast<float>(i % 7 + 1);
2906+
for (size_t i = 0; i < past_key_data.size(); i++) past_key_data[i] = 0.2f * static_cast<float>(i % 5 + 1);
2907+
for (size_t i = 0; i < past_value_data.size(); i++) past_value_data[i] = 0.3f * static_cast<float>(i % 3 + 1);
2908+
2909+
// WebGPU run: total_sequence_length=0 forces use_seqlen_k=true and
2910+
// use_indirect_dispatch=true; PrepareIndirectDispatchProgram sizes the
2911+
// dispatch from seqlen_k[0] on the GPU rather than the zero CPU value.
2912+
OpTester webgpu_tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
2913+
webgpu_tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
2914+
webgpu_tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
2915+
webgpu_tester.AddInput<float>("query", {batch_size, q_seq_len, hidden_size}, query_data);
2916+
webgpu_tester.AddInput<float>("key", {batch_size, 0, kv_hidden_size}, {});
2917+
webgpu_tester.AddInput<float>("value", {batch_size, 0, kv_hidden_size}, {});
2918+
webgpu_tester.AddInput<float>("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
2919+
webgpu_tester.AddInput<float>("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
2920+
webgpu_tester.AddInput<int32_t>("seqlens_k", {batch_size}, {static_cast<int32_t>(past_seq_len - 1)});
2921+
webgpu_tester.AddInput<int32_t>("total_sequence_length", {1}, {0}); // 0 → indirect dispatch path
2922+
webgpu_tester.AddOptionalInputEdge<float>(); // cos_cache
2923+
webgpu_tester.AddOptionalInputEdge<float>(); // sin_cache
2924+
webgpu_tester.AddOptionalInputEdge<int64_t>(); // position_ids
2925+
webgpu_tester.AddOptionalInputEdge<float>(); // attention_bias
2926+
webgpu_tester.AddOptionalInputEdge<float>(); // head_sink
2927+
const int output_size = batch_size * q_seq_len * hidden_size;
2928+
const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
2929+
webgpu_tester.AddOutput<float>("output", {batch_size, q_seq_len, hidden_size}, std::vector<float>(output_size, 0.0f));
2930+
webgpu_tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float>(present_size, 0.0f));
2931+
webgpu_tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float>(present_size, 0.0f));
2932+
webgpu_tester.SetOutputTolerance(1e6f);
2933+
std::vector<std::unique_ptr<IExecutionProvider>> webgpu_eps;
2934+
webgpu_eps.push_back(DefaultWebGpuExecutionProvider());
2935+
webgpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &webgpu_eps);
2936+
const float* webgpu_out = webgpu_tester.GetFetches()[0].Get<Tensor>().Data<float>();
2937+
std::vector<float> webgpu_output(webgpu_out, webgpu_out + output_size);
2938+
2939+
// CPU reference: use real total_sequence_length so CPU path is correct.
2940+
auto cpu_output = RunGQASharedKV(
2941+
batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data,
2942+
num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false);
2943+
2944+
ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_IndirectDispatch_Decode_WebGPU_vs_CPU");
2945+
}
2946+
2947+
// WebGPU: kv_empty + total_sequence_length=0, larger past (exercises tile boundary).
2948+
// past_seq_len=32 ensures num_total_seq_length_tile > 1, validating the tile
2949+
// count arithmetic in PrepareIndirectDispatchProgram more thoroughly.
2950+
TEST(GroupQueryAttentionTest, WebGPU_SharedKV_IndirectDispatch_LargerPast) {
2951+
auto webgpu_ep = DefaultWebGpuExecutionProvider();
2952+
if (!webgpu_ep) {
2953+
GTEST_SKIP() << "WebGPU EP not available";
2954+
}
2955+
2956+
constexpr int batch_size = 1;
2957+
constexpr int q_seq_len = 1;
2958+
constexpr int past_seq_len = 32;
2959+
constexpr int num_heads = 2;
2960+
constexpr int kv_num_heads = 1;
2961+
constexpr int head_size = 8;
2962+
constexpr int hidden_size = num_heads * head_size;
2963+
constexpr int kv_hidden_size = kv_num_heads * head_size;
2964+
2965+
std::vector<float> query_data(batch_size * q_seq_len * hidden_size);
2966+
std::vector<float> past_key_data(batch_size * kv_num_heads * past_seq_len * head_size);
2967+
std::vector<float> past_value_data(batch_size * kv_num_heads * past_seq_len * head_size);
2968+
for (size_t i = 0; i < query_data.size(); i++) query_data[i] = 0.1f * static_cast<float>(i % 7 + 1);
2969+
for (size_t i = 0; i < past_key_data.size(); i++) past_key_data[i] = 0.2f * static_cast<float>(i % 5 + 1);
2970+
for (size_t i = 0; i < past_value_data.size(); i++) past_value_data[i] = 0.3f * static_cast<float>(i % 3 + 1);
2971+
2972+
OpTester webgpu_tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
2973+
webgpu_tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
2974+
webgpu_tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
2975+
webgpu_tester.AddInput<float>("query", {batch_size, q_seq_len, hidden_size}, query_data);
2976+
webgpu_tester.AddInput<float>("key", {batch_size, 0, kv_hidden_size}, {});
2977+
webgpu_tester.AddInput<float>("value", {batch_size, 0, kv_hidden_size}, {});
2978+
webgpu_tester.AddInput<float>("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
2979+
webgpu_tester.AddInput<float>("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
2980+
webgpu_tester.AddInput<int32_t>("seqlens_k", {batch_size}, {static_cast<int32_t>(past_seq_len - 1)});
2981+
webgpu_tester.AddInput<int32_t>("total_sequence_length", {1}, {0}); // 0 → indirect dispatch path
2982+
webgpu_tester.AddOptionalInputEdge<float>(); // cos_cache
2983+
webgpu_tester.AddOptionalInputEdge<float>(); // sin_cache
2984+
webgpu_tester.AddOptionalInputEdge<int64_t>(); // position_ids
2985+
webgpu_tester.AddOptionalInputEdge<float>(); // attention_bias
2986+
webgpu_tester.AddOptionalInputEdge<float>(); // head_sink
2987+
const int output_size = batch_size * q_seq_len * hidden_size;
2988+
const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
2989+
webgpu_tester.AddOutput<float>("output", {batch_size, q_seq_len, hidden_size}, std::vector<float>(output_size, 0.0f));
2990+
webgpu_tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float>(present_size, 0.0f));
2991+
webgpu_tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float>(present_size, 0.0f));
2992+
webgpu_tester.SetOutputTolerance(1e6f);
2993+
std::vector<std::unique_ptr<IExecutionProvider>> webgpu_eps;
2994+
webgpu_eps.push_back(DefaultWebGpuExecutionProvider());
2995+
webgpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &webgpu_eps);
2996+
const float* webgpu_out = webgpu_tester.GetFetches()[0].Get<Tensor>().Data<float>();
2997+
std::vector<float> webgpu_output(webgpu_out, webgpu_out + output_size);
2998+
2999+
auto cpu_output = RunGQASharedKV(
3000+
batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data,
3001+
num_heads, kv_num_heads, head_size, /*use_cuda=*/false, /*use_webgpu=*/false);
3002+
3003+
ExpectOutputsMatch(webgpu_output, cpu_output, 0.05f, "SharedKV_IndirectDispatch_LargerPast_WebGPU_vs_CPU");
3004+
}
3005+
28753006
} // namespace test
28763007
} // namespace onnxruntime

0 commit comments

Comments
 (0)