#include "contrib_ops/webgpu/bert/attention_common.h"
#include "contrib_ops/webgpu/bert/group_query_attention.h"
#include "contrib_ops/webgpu/bert/rotary_embedding.h"
#include "contrib_ops/webgpu/bert/flash_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

#include "core/providers/webgpu/webgpu_supported_types.h"
@@ -30,6 +31,117 @@ ONNX_OPERATOR_KERNEL_EX(
3031 .InputMemoryType(OrtMemTypeCPUInput, 6 ),
3132 GroupQueryAttention);
3233
34+ Status SplitPackedQKVProgram::GenerateShaderCode (ShaderHelper& sh) const {
35+ const auto & packed_qkv = sh.AddInput (" packed_qkv" , ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
36+ const auto & query = sh.AddOutput (" query" , ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
37+ const auto & key = sh.AddOutput (" key" , ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
38+ const auto & value = sh.AddOutput (" val" , ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
39+ sh.MainFunctionBody () << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices (" global_idx" ) << " ;\n "
40+ << " let input_data = " << packed_qkv.GetByOffset (" global_idx" ) << " ;\n "
41+ << " let index = " << packed_qkv.IndicesGet (" packed_qkv_indices" , " 2" ) << " ;\n "
42+ << " if (index < uniforms.hidden_size) {\n "
43+ << " " << query.SetByIndices (" packed_qkv_indices" , " input_data" ) << " ;\n "
44+ << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n "
45+ << " var key_indices = packed_qkv_indices;\n "
46+ << " " << key.IndicesSet (" key_indices" , " 2" , " u32(index - uniforms.hidden_size)" ) << " ;\n "
47+ << " " << key.SetByIndices (" key_indices" , " input_data" ) << " ;\n "
48+ << " } else {\n "
49+ << " var val_indices = packed_qkv_indices;\n "
50+ << " " << value.IndicesSet (" val_indices" , " 2" , " u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)" ) << " ;\n "
51+ << " " << value.SetByIndices (" val_indices" , " input_data" ) << " ;\n "
52+ << " }" ;
53+ return Status::OK ();
54+ }
55+
56+ Status SplitPackedQKV (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) {
57+ SplitPackedQKVProgram program;
58+ auto input_size = packedQKV->Shape ().Size ();
59+ program
60+ .AddInput ({packedQKV, ProgramTensorMetadataDependency::Rank})
61+ .AddOutputs ({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}})
62+ .AddUniformVariables ({
63+ {static_cast <uint32_t >(params.hidden_size_ )},
64+ {static_cast <uint32_t >(params.kv_hidden_size_ )},
65+ })
66+ .SetDispatchGroupSize ((input_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE);
67+ return context.RunProgram (program);
68+ }
69+
70+ Status GeneratePositionIDsProgram::GenerateShaderCode (ShaderHelper& sh) const {
71+ const auto & output = sh.AddOutput (" output" , ShaderUsage::UseUniform);
72+ const auto & seqlens = sh.AddInput (" seqlens" , ShaderUsage::UseUniform);
73+ sh.MainFunctionBody () << " var pos_id: i32 = 0;\n "
74+ << " let batch_idx = global_idx / uniforms.sequence_length;\n "
75+ << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n "
76+ << " let seqlen = " << seqlens.GetByOffset (" batch_idx" ) << " ;\n " ;
77+ if (is_first_prompt_) {
78+ sh.MainFunctionBody () << " let total_seqlen = seqlen + 1;\n "
79+ << " if (sequence_idx < total_seqlen) {\n "
80+ << " pos_id = sequence_idx;\n "
81+ << " } else {\n "
82+ << " pos_id = 1;\n "
83+ << " }\n "
84+ << " " << output.SetByOffset (" global_idx" , " pos_id" ) << " \n " ;
85+ } else if (is_subsequent_prompt_) {
86+ sh.MainFunctionBody () << " let total_seqlen = seqlen + 1;\n "
87+ << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n "
88+ << " if (past_seqlen + sequence_idx < total_seqlen) {\n "
89+ << " pos_id = past_seqlen + sequence_idx;\n "
90+ << " } else {\n "
91+ << " pos_id = 1;\n "
92+ << " }\n "
93+ << " " << output.SetByOffset (" global_idx" , " pos_id" ) << " \n " ;
94+ } else {
95+ sh.MainFunctionBody () << " if (global_idx < uniforms.batch_size) {\n "
96+ << " " << output.SetByOffset (" global_idx" , " seqlen" ) << " \n "
97+ << " }\n " ;
98+ }
99+ return Status::OK ();
100+ }
101+
102+ Status GeneratePositionIDs (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) {
103+ GeneratePositionIDsProgram program (params.is_first_prompt_ , params.is_subsequent_prompt_ );
104+ auto output_size = params.batch_size_ * params.sequence_length_ ;
105+ program.CacheHint (params.is_first_prompt_ , params.is_subsequent_prompt_ )
106+ .AddInput ({seqlens, ProgramTensorMetadataDependency::Rank})
107+ .AddOutput ({output_tensor, ProgramTensorMetadataDependency::Rank})
108+ .AddUniformVariables ({{static_cast <uint32_t >(params.batch_size_ )}, {static_cast <uint32_t >(params.sequence_length_ )}})
109+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE);
110+ return context.RunProgram (program);
111+ }
112+
113+ Status RunRotaryEmbedding (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) {
114+ const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
115+ const auto head_size = params.head_size_ ;
116+ const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_ ;
117+ const TensorShape global_shape ({params.batch_size_ , params.sequence_length_ , hidden_size / head_size, static_cast <int64_t >(head_size - half_rotary_embedding_dim)});
118+ const auto rank = global_shape.NumDimensions ();
119+ std::vector<uint32_t > global_dims (rank);
120+ std::vector<uint32_t > global_strides (rank);
121+ for (size_t j = 0 ; j < rank; ++j) {
122+ global_dims[j] = gsl::narrow_cast<uint32_t >(global_shape[j]);
123+ global_strides[j] = gsl::narrow_cast<uint32_t >(global_shape.SizeFromDimension (j + 1 ));
124+ }
125+ const auto input_output_strides = std::vector<uint32_t >({gsl::narrow_cast<uint32_t >(input->Shape ().SizeFromDimension (1 )), gsl::narrow_cast<uint32_t >(hidden_size), gsl::narrow_cast<uint32_t >(head_size), 1 });
126+ const auto output_size = gsl::narrow_cast<const uint32_t >(global_shape.Size ());
127+
128+ RotaryEmbeddingProgram program (params.rotary_interleaved_ );
129+ program
130+ .CacheHint (params.rotary_interleaved_ )
131+ .AddInputs ({{input, ProgramTensorMetadataDependency::Rank},
132+ {pos_ids, ProgramTensorMetadataDependency::Rank},
133+ {cos_cache, ProgramTensorMetadataDependency::Rank},
134+ {sin_cache, ProgramTensorMetadataDependency::Rank}})
135+ .AddOutput (output)
136+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
137+ .AddUniformVariables ({{params.scale_ },
138+ {gsl::make_span (global_dims)},
139+ {gsl::make_span (global_strides)},
140+ {gsl::make_span (input_output_strides)}})
141+ .AddIndices (TensorShape{1 , 1 });
142+ return context.RunProgram (program);
143+ }
144+
33145Status GroupQueryAttention::ComputeInternal (onnxruntime::webgpu::ComputeContext& context) const {
34146 const Tensor* query = context.Input <Tensor>(0 );
35147 const Tensor* key = context.Input <Tensor>(1 );
@@ -41,7 +153,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
41153 const Tensor* cos_cache = context.Input <Tensor>(7 );
42154 const Tensor* sin_cache = context.Input <Tensor>(8 );
43155
44- GroupQueryAttentionParameters params;
156+ GroupQueryAttentionParameters params = {} ;
45157 ORT_RETURN_IF_ERROR (group_query_attention_helper::CheckInputs (query,
46158 key,
47159 value,
@@ -57,9 +169,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
57169 scale_,
58170 softcap_));
59171 WebgpuAttentionParameters parameters (params);
60- if (parameters.is_packed_qkv_ ) {
61- ORT_NOT_IMPLEMENTED (" Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep." );
62- }
63172 TensorShapeVector output_shape (3 );
64173 output_shape[0 ] = static_cast <int64_t >(parameters.batch_size_ );
65174 output_shape[1 ] = static_cast <int64_t >(parameters.sequence_length_ );
@@ -75,11 +184,39 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
75184 Tensor* present_value = context.Output (2 , present_kv_shape);
76185 parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw () == present_key->DataRaw () && past_value->DataRaw () == present_value->DataRaw ();
77186
78- if (CanApplyFlashAttention (nullptr /* bias */ , present_key, present_value, parameters, context)) {
187+ if (!do_rotary_ && CanApplyFlashAttention (nullptr /* bias */ , present_key, present_value, parameters, context)) {
79188 return ApplyFlashAttention (query, key, value, nullptr /* attention_bias */ , output, past_key, present_key, past_value,
80189 present_value, parameters, context);
81190 }
82191
192+ Tensor qSplit;
193+ Tensor kSplit ;
194+ Tensor vSplit;
195+ if (parameters.is_packed_qkv_ ) {
196+ qSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.hidden_size_ }));
197+ kSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
198+ vSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
199+ ORT_RETURN_IF_ERROR (SplitPackedQKV (context, parameters, query, &qSplit, &kSplit , &vSplit));
200+ parameters.is_packed_qkv_ = false ;
201+ query = &qSplit;
202+ key = &kSplit ;
203+ value = &vSplit;
204+ }
205+
206+ Tensor qRotary;
207+ Tensor kRotary ;
208+ if (do_rotary_) {
209+ qRotary = context.CreateGPUTensor (query->DataType (), query->Shape ());
210+ kRotary = context.CreateGPUTensor (key->DataType (), key->Shape ());
211+ auto pos_ids_shape = TensorShape ({parameters.batch_size_ , parameters.sequence_length_ });
212+ Tensor pos_ids = context.CreateGPUTensor (DataTypeImpl::GetType<int64_t >(), pos_ids_shape);
213+ ORT_RETURN_IF_ERROR (GeneratePositionIDs (context, parameters, seqlen_k, &pos_ids));
214+ ORT_RETURN_IF_ERROR (RunRotaryEmbedding (context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true ));
215+ ORT_RETURN_IF_ERROR (RunRotaryEmbedding (context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary , /* is_query_input = */ false ));
216+ query = &qRotary;
217+ key = &kRotary ;
218+ }
219+
83220 TensorShapeVector q_new_dims ({parameters.batch_size_ , parameters.num_heads_ ,
84221 parameters.sequence_length_ , parameters.head_size_ });
85222 TensorShape q_new_shape (q_new_dims);
0 commit comments