@@ -2819,8 +2819,7 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_CUDA) {
28192819 RunBatchedRightPaddedRotaryPrefillForEP (GqaTargetEp::kCuda );
28202820}
28212821
2822- TEST (GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) {
2823- auto webgpu_ep = DefaultWebGpuExecutionProvider ();
2822+ TEST (GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { auto webgpu_ep = DefaultWebGpuExecutionProvider ();
28242823 if (!webgpu_ep) {
28252824 GTEST_SKIP () << " WebGPU EP not available" ;
28262825 }
@@ -2872,5 +2871,137 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_W
28722871 RunBatchedRightPaddedRotaryPrefillForEP (GqaTargetEp::kWebGpu , {4 , 2 , 6 }, /* smooth_softmax=*/ true );
28732872}
28742873
2874+ // ---------------------------------------------------------------------------
2875+ // Indirect dispatch tests (PrepareIndirectDispatchProgram)
2876+ //
2877+ // These tests pass total_sequence_length=0 to trigger the kv_empty indirect
2878+ // dispatch path: use_seqlen_k=true so seqlen_k tensor is used on the GPU, and
2879+ // use_indirect_dispatch=true so PrepareIndirectDispatchProgram sizes the flash-
2880+ // attention dispatch from the GPU-resident seqlen_k value instead of the zero
2881+ // CPU value. Correctness is verified by cross-checking against CPU with the
2882+ // real total.
2883+ // ---------------------------------------------------------------------------
2884+
2885+ // WebGPU: kv_empty + total_sequence_length=0, decode (q_seq=1).
2886+ // Exercises PrepareIndirectDispatchProgram for the kv_empty indirect dispatch path.
2887+ TEST (GroupQueryAttentionTest, WebGPU_SharedKV_IndirectDispatch_Decode) {
2888+ auto webgpu_ep = DefaultWebGpuExecutionProvider ();
2889+ if (!webgpu_ep) {
2890+ GTEST_SKIP () << " WebGPU EP not available" ;
2891+ }
2892+
2893+ constexpr int batch_size = 1 ;
2894+ constexpr int q_seq_len = 1 ;
2895+ constexpr int past_seq_len = 8 ;
2896+ constexpr int num_heads = 2 ;
2897+ constexpr int kv_num_heads = 1 ;
2898+ constexpr int head_size = 8 ;
2899+ constexpr int hidden_size = num_heads * head_size;
2900+ constexpr int kv_hidden_size = kv_num_heads * head_size;
2901+
2902+ std::vector<float > query_data (batch_size * q_seq_len * hidden_size);
2903+ std::vector<float > past_key_data (batch_size * kv_num_heads * past_seq_len * head_size);
2904+ std::vector<float > past_value_data (batch_size * kv_num_heads * past_seq_len * head_size);
2905+ for (size_t i = 0 ; i < query_data.size (); i++) query_data[i] = 0 .1f * static_cast <float >(i % 7 + 1 );
2906+ for (size_t i = 0 ; i < past_key_data.size (); i++) past_key_data[i] = 0 .2f * static_cast <float >(i % 5 + 1 );
2907+ for (size_t i = 0 ; i < past_value_data.size (); i++) past_value_data[i] = 0 .3f * static_cast <float >(i % 3 + 1 );
2908+
2909+ // WebGPU run: total_sequence_length=0 forces use_seqlen_k=true and
2910+ // use_indirect_dispatch=true; PrepareIndirectDispatchProgram sizes the
2911+ // dispatch from seqlen_k[0] on the GPU rather than the zero CPU value.
2912+ OpTester webgpu_tester (" GroupQueryAttention" , 1 , onnxruntime::kMSDomain );
2913+ webgpu_tester.AddAttribute <int64_t >(" num_heads" , static_cast <int64_t >(num_heads));
2914+ webgpu_tester.AddAttribute <int64_t >(" kv_num_heads" , static_cast <int64_t >(kv_num_heads));
2915+ webgpu_tester.AddInput <float >(" query" , {batch_size, q_seq_len, hidden_size}, query_data);
2916+ webgpu_tester.AddInput <float >(" key" , {batch_size, 0 , kv_hidden_size}, {});
2917+ webgpu_tester.AddInput <float >(" value" , {batch_size, 0 , kv_hidden_size}, {});
2918+ webgpu_tester.AddInput <float >(" past_key" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
2919+ webgpu_tester.AddInput <float >(" past_value" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
2920+ webgpu_tester.AddInput <int32_t >(" seqlens_k" , {batch_size}, {static_cast <int32_t >(past_seq_len - 1 )});
2921+ webgpu_tester.AddInput <int32_t >(" total_sequence_length" , {1 }, {0 }); // 0 → indirect dispatch path
2922+ webgpu_tester.AddOptionalInputEdge <float >(); // cos_cache
2923+ webgpu_tester.AddOptionalInputEdge <float >(); // sin_cache
2924+ webgpu_tester.AddOptionalInputEdge <int64_t >(); // position_ids
2925+ webgpu_tester.AddOptionalInputEdge <float >(); // attention_bias
2926+ webgpu_tester.AddOptionalInputEdge <float >(); // head_sink
2927+ const int output_size = batch_size * q_seq_len * hidden_size;
2928+ const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
2929+ webgpu_tester.AddOutput <float >(" output" , {batch_size, q_seq_len, hidden_size}, std::vector<float >(output_size, 0 .0f ));
2930+ webgpu_tester.AddOutput <float >(" present_key" , {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float >(present_size, 0 .0f ));
2931+ webgpu_tester.AddOutput <float >(" present_value" , {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float >(present_size, 0 .0f ));
2932+ webgpu_tester.SetOutputTolerance (1e6f);
2933+ std::vector<std::unique_ptr<IExecutionProvider>> webgpu_eps;
2934+ webgpu_eps.push_back (DefaultWebGpuExecutionProvider ());
2935+ webgpu_tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &webgpu_eps);
2936+ const float * webgpu_out = webgpu_tester.GetFetches ()[0 ].Get <Tensor>().Data <float >();
2937+ std::vector<float > webgpu_output (webgpu_out, webgpu_out + output_size);
2938+
2939+ // CPU reference: use real total_sequence_length so CPU path is correct.
2940+ auto cpu_output = RunGQASharedKV (
2941+ batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data,
2942+ num_heads, kv_num_heads, head_size, /* use_cuda=*/ false , /* use_webgpu=*/ false );
2943+
2944+ ExpectOutputsMatch (webgpu_output, cpu_output, 0 .05f , " SharedKV_IndirectDispatch_Decode_WebGPU_vs_CPU" );
2945+ }
2946+
2947+ // WebGPU: kv_empty + total_sequence_length=0, larger past (exercises tile boundary).
2948+ // past_seq_len=32 ensures num_total_seq_length_tile > 1, validating the tile
2949+ // count arithmetic in PrepareIndirectDispatchProgram more thoroughly.
2950+ TEST (GroupQueryAttentionTest, WebGPU_SharedKV_IndirectDispatch_LargerPast) {
2951+ auto webgpu_ep = DefaultWebGpuExecutionProvider ();
2952+ if (!webgpu_ep) {
2953+ GTEST_SKIP () << " WebGPU EP not available" ;
2954+ }
2955+
2956+ constexpr int batch_size = 1 ;
2957+ constexpr int q_seq_len = 1 ;
2958+ constexpr int past_seq_len = 32 ;
2959+ constexpr int num_heads = 2 ;
2960+ constexpr int kv_num_heads = 1 ;
2961+ constexpr int head_size = 8 ;
2962+ constexpr int hidden_size = num_heads * head_size;
2963+ constexpr int kv_hidden_size = kv_num_heads * head_size;
2964+
2965+ std::vector<float > query_data (batch_size * q_seq_len * hidden_size);
2966+ std::vector<float > past_key_data (batch_size * kv_num_heads * past_seq_len * head_size);
2967+ std::vector<float > past_value_data (batch_size * kv_num_heads * past_seq_len * head_size);
2968+ for (size_t i = 0 ; i < query_data.size (); i++) query_data[i] = 0 .1f * static_cast <float >(i % 7 + 1 );
2969+ for (size_t i = 0 ; i < past_key_data.size (); i++) past_key_data[i] = 0 .2f * static_cast <float >(i % 5 + 1 );
2970+ for (size_t i = 0 ; i < past_value_data.size (); i++) past_value_data[i] = 0 .3f * static_cast <float >(i % 3 + 1 );
2971+
2972+ OpTester webgpu_tester (" GroupQueryAttention" , 1 , onnxruntime::kMSDomain );
2973+ webgpu_tester.AddAttribute <int64_t >(" num_heads" , static_cast <int64_t >(num_heads));
2974+ webgpu_tester.AddAttribute <int64_t >(" kv_num_heads" , static_cast <int64_t >(kv_num_heads));
2975+ webgpu_tester.AddInput <float >(" query" , {batch_size, q_seq_len, hidden_size}, query_data);
2976+ webgpu_tester.AddInput <float >(" key" , {batch_size, 0 , kv_hidden_size}, {});
2977+ webgpu_tester.AddInput <float >(" value" , {batch_size, 0 , kv_hidden_size}, {});
2978+ webgpu_tester.AddInput <float >(" past_key" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_key_data);
2979+ webgpu_tester.AddInput <float >(" past_value" , {batch_size, kv_num_heads, past_seq_len, head_size}, past_value_data);
2980+ webgpu_tester.AddInput <int32_t >(" seqlens_k" , {batch_size}, {static_cast <int32_t >(past_seq_len - 1 )});
2981+ webgpu_tester.AddInput <int32_t >(" total_sequence_length" , {1 }, {0 }); // 0 → indirect dispatch path
2982+ webgpu_tester.AddOptionalInputEdge <float >(); // cos_cache
2983+ webgpu_tester.AddOptionalInputEdge <float >(); // sin_cache
2984+ webgpu_tester.AddOptionalInputEdge <int64_t >(); // position_ids
2985+ webgpu_tester.AddOptionalInputEdge <float >(); // attention_bias
2986+ webgpu_tester.AddOptionalInputEdge <float >(); // head_sink
2987+ const int output_size = batch_size * q_seq_len * hidden_size;
2988+ const int present_size = batch_size * kv_num_heads * past_seq_len * head_size;
2989+ webgpu_tester.AddOutput <float >(" output" , {batch_size, q_seq_len, hidden_size}, std::vector<float >(output_size, 0 .0f ));
2990+ webgpu_tester.AddOutput <float >(" present_key" , {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float >(present_size, 0 .0f ));
2991+ webgpu_tester.AddOutput <float >(" present_value" , {batch_size, kv_num_heads, past_seq_len, head_size}, std::vector<float >(present_size, 0 .0f ));
2992+ webgpu_tester.SetOutputTolerance (1e6f);
2993+ std::vector<std::unique_ptr<IExecutionProvider>> webgpu_eps;
2994+ webgpu_eps.push_back (DefaultWebGpuExecutionProvider ());
2995+ webgpu_tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &webgpu_eps);
2996+ const float * webgpu_out = webgpu_tester.GetFetches ()[0 ].Get <Tensor>().Data <float >();
2997+ std::vector<float > webgpu_output (webgpu_out, webgpu_out + output_size);
2998+
2999+ auto cpu_output = RunGQASharedKV (
3000+ batch_size, q_seq_len, past_seq_len, query_data, past_key_data, past_value_data,
3001+ num_heads, kv_num_heads, head_size, /* use_cuda=*/ false , /* use_webgpu=*/ false );
3002+
3003+ ExpectOutputsMatch (webgpu_output, cpu_output, 0 .05f , " SharedKV_IndirectDispatch_LargerPast_WebGPU_vs_CPU" );
3004+ }
3005+
28753006} // namespace test
28763007} // namespace onnxruntime
0 commit comments