@@ -69,10 +69,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
6969 return context.RunProgram (program);
7070};
7171
72- void InitVarStub (std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt ) {
72+ void InitVarStub (std::ostringstream& ss, const Tensor* seqlen_k) {
7373 if (seqlen_k != nullptr ) {
7474 ss << " total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n " ;
75- ss << " var past_sequence_length: u32 = " << (is_first_prompt ? " 0 " : " total_sequence_length - sequence_length" ) << " ;\n " ;
75+ ss << " var past_sequence_length: u32 = select( total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0) ;\n " ;
7676 } else {
7777 ss << " let past_sequence_length = uniforms.past_sequence_length;\n " ;
7878 }
@@ -106,7 +106,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
106106 << " let sequence_length = uniforms.M;\n "
107107 << " var total_sequence_length = uniforms.N;\n " ;
108108 std::ostringstream oss;
109- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
109+ InitVarStub (oss, seqlen_k_);
110110 shader.MainFunctionBody () << oss.str ();
111111 shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112112 if (has_present_key_) {
@@ -121,7 +121,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
121121 " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n "
122122 " var idx = TILE_SIZE * local_id.y + local_id.x;\n " ;
123123
124- if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
124+ if ((feed_past_key_ && has_present_key_) || ( past_present_share_buffer_ && !is_first_prompt_) ) {
125125 shader.MainFunctionBody () << " if (n + local_id.y < past_sequence_length) {\n "
126126 << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.past_sequence_length * uniforms.K;\n "
127127 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_key" : " past_key" ) << " [pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n "
@@ -213,7 +213,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
213213 {static_cast <uint32_t >(past_sequence_length)},
214214 {static_cast <uint32_t >(parameters.kv_sequence_length_ )},
215215 {static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
216- {static_cast <uint32_t >(parameters.n_reps )}})
216+ {static_cast <uint32_t >(parameters.n_reps )},
217+ {static_cast <uint32_t >(parameters.is_first_prompt_ ? 1 : 0 )}})
217218 .SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
218219
219220 return context.RunProgram (program);
@@ -231,7 +232,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
231232 << " let sequence_length = uniforms.sequence_length;\n "
232233 << " var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << " ;\n " ;
233234 std::ostringstream oss;
234- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
235+ InitVarStub (oss, seqlen_k_);
235236 shader.MainFunctionBody () << oss.str ()
236237 << " let local_offset = local_idx * uniforms.elements_per_thread;\n "
237238 << " let offset = (global_idx / " << work_group_size_ << " ) * uniforms.total_sequence_length_comp + local_offset;\n "
@@ -285,20 +286,21 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
285286 }
286287 const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1 ) / work_group_size;
287288
288- InPlaceSoftmaxProgram program{" InPlaceSoftmax" , work_group_size, components, is_first_prompt, seqlen_k};
289+ InPlaceSoftmaxProgram program{" InPlaceSoftmax" , work_group_size, components, seqlen_k};
289290 if (seqlen_k != nullptr ) {
290291 program.AddInput ({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
291292 }
292293 program.AddOutputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
293- .CacheHint (work_group_size, is_first_prompt )
294+ .CacheHint (work_group_size)
294295 .SetDispatchGroupSize (1 , sequence_length, batch_size * num_heads)
295296 .SetWorkgroupSize (work_group_size)
296297 .AddUniformVariables ({{static_cast <uint32_t >(batch_size)},
297298 {static_cast <uint32_t >(num_heads)},
298299 {static_cast <uint32_t >(past_sequence_length)},
299300 {static_cast <uint32_t >(sequence_length)},
300301 {static_cast <uint32_t >(total_sequence_length_comp)},
301- {static_cast <uint32_t >(elementsPerThread)}});
302+ {static_cast <uint32_t >(elementsPerThread)},
303+ {static_cast <uint32_t >(is_first_prompt ? 1 : 0 )}});
302304
303305 return context.RunProgram (program);
304306}
@@ -327,7 +329,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
327329 << " let sequence_length = uniforms.M;\n "
328330 << " var total_sequence_length = uniforms.K;\n " ;
329331 std::ostringstream oss;
330- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
332+ InitVarStub (oss, seqlen_k_);
331333 shader.MainFunctionBody () << oss.str ();
332334 shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
333335 if (has_present_value_) {
@@ -342,12 +344,12 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
342344 << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n "
343345 << " var idx = TILE_SIZE * local_id.y + local_id.x;\n " ;
344346
345- if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
347+ if ((feed_past_value_ && has_present_value_) || ( past_present_share_buffer_ && !is_first_prompt_) ) {
346348 shader.MainFunctionBody () << " if (w + local_id.y < past_sequence_length) {\n "
347349 << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.past_sequence_length + n;\n "
348350 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_value" : " past_value" ) << " [pastValueOffset + (w + local_id.y) * uniforms.N];\n "
349351 << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
350- << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms. past_sequence_length) * uniforms.N];\n "
352+ << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n "
351353 << " }\n " ;
352354 } else {
353355 shader.MainFunctionBody () << " if (w + local_id.y < uniforms.kv_sequence_length) {\n "
@@ -425,7 +427,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
425427 {static_cast <uint32_t >(past_sequence_length)},
426428 {static_cast <uint32_t >(parameters.kv_sequence_length_ )},
427429 {static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
428- {static_cast <uint32_t >(parameters.n_reps )}})
430+ {static_cast <uint32_t >(parameters.n_reps )},
431+ {static_cast <uint32_t >(parameters.is_first_prompt_ )}})
429432 .SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
430433
431434 return context.RunProgram (program);
0 commit comments