@@ -67,49 +67,6 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu
6767 return context.RunProgram (program);
6868}
6969
70- Status GeneratePositionIDsProgram::GenerateShaderCode (ShaderHelper& sh) const {
71- const auto & output = sh.AddOutput (" output" , ShaderUsage::UseUniform);
72- const auto & seqlens = sh.AddInput (" seqlens" , ShaderUsage::UseUniform);
73- sh.MainFunctionBody () << " var pos_id: i32 = 0;\n "
74- << " let batch_idx = global_idx / uniforms.sequence_length;\n "
75- << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n "
76- << " let seqlen = " << seqlens.GetByOffset (" batch_idx" ) << " ;\n " ;
77- if (is_first_prompt_) {
78- sh.MainFunctionBody () << " let total_seqlen = seqlen + 1;\n "
79- << " if (sequence_idx < total_seqlen) {\n "
80- << " pos_id = sequence_idx;\n "
81- << " } else {\n "
82- << " pos_id = 1;\n "
83- << " }\n "
84- << " " << output.SetByOffset (" global_idx" , " pos_id" ) << " \n " ;
85- } else if (is_subsequent_prompt_) {
86- sh.MainFunctionBody () << " let total_seqlen = seqlen + 1;\n "
87- << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n "
88- << " if (past_seqlen + sequence_idx < total_seqlen) {\n "
89- << " pos_id = past_seqlen + sequence_idx;\n "
90- << " } else {\n "
91- << " pos_id = 1;\n "
92- << " }\n "
93- << " " << output.SetByOffset (" global_idx" , " pos_id" ) << " \n " ;
94- } else {
95- sh.MainFunctionBody () << " if (global_idx < uniforms.batch_size) {\n "
96- << " " << output.SetByOffset (" global_idx" , " seqlen" ) << " \n "
97- << " }\n " ;
98- }
99- return Status::OK ();
100- }
101-
102- Status GeneratePositionIDs (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) {
103- GeneratePositionIDsProgram program (params.is_first_prompt_ , params.is_subsequent_prompt_ );
104- auto output_size = params.batch_size_ * params.sequence_length_ ;
105- program.CacheHint (params.is_first_prompt_ , params.is_subsequent_prompt_ )
106- .AddInput ({seqlens, ProgramTensorMetadataDependency::Rank})
107- .AddOutput ({output_tensor, ProgramTensorMetadataDependency::Rank})
108- .AddUniformVariables ({{static_cast <uint32_t >(params.batch_size_ )}, {static_cast <uint32_t >(params.sequence_length_ )}})
109- .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE);
110- return context.RunProgram (program);
111- }
112-
11370// Fused Q/K rotary embedding
11471Status RunFusedQKRotaryEmbedding (onnxruntime::webgpu::ComputeContext& context,
11572 const WebgpuAttentionParameters& params,
@@ -120,10 +77,6 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
12077 const Tensor* sin_cache,
12178 Tensor* query_out,
12279 Tensor* key_out) {
123- Tensor pos_ids = context.CreateGPUTensor (DataTypeImpl::GetType<int64_t >(),
124- TensorShape ({params.batch_size_ , params.sequence_length_ }));
125- ORT_RETURN_IF_ERROR (GeneratePositionIDs (context, params, seqlen_k, &pos_ids));
126-
12780 const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
12881 const auto head_size = params.head_size_ ;
12982
@@ -171,7 +124,7 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
171124 .AddInputs ({
172125 {query_in, ProgramTensorMetadataDependency::Rank},
173126 {key_in, ProgramTensorMetadataDependency::Rank},
174- {&pos_ids , ProgramTensorMetadataDependency::Rank},
127+ {seqlen_k , ProgramTensorMetadataDependency::Rank},
175128 {cos_cache, ProgramTensorMetadataDependency::Rank},
176129 {sin_cache, ProgramTensorMetadataDependency::Rank},
177130 })
@@ -188,8 +141,7 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
188141 {gsl::make_span (k_global_dims)},
189142 {gsl::make_span (k_input_output_strides)},
190143 {q_domain_size},
191- })
192- .AddIndices (TensorShape{1 , 1 });
144+ });
193145
194146 return context.RunProgram (program);
195147}
0 commit comments