Skip to content

Commit 771a4d4

Browse files
authored
[webgpu] Fused GeneratePositionIDs into FusedQKRotaryEmbedding (#26400)
### Description

This PR fuses GeneratePositionIDs into FusedQKRotaryEmbedding, which removes one kernel call.

### Motivation and Context

Previously, for GQA, the processing flow was: `SplitPackedQKVProgram -> GeneratePositionIDs -> FusedQKRotaryEmbedding -> FlashAttention`. After this change, the pipeline becomes: `SplitPackedQKVProgram -> FusedQKRotaryEmbedding -> FlashAttention`. On NV5080, the token generation speed improved by ~4% (128 tps -> 133 tps).
1 parent 954bb7b commit 771a4d4

File tree

3 files changed

+10
-68
lines changed

3 files changed

+10
-68
lines changed

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -67,49 +67,6 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu
6767
return context.RunProgram(program);
6868
}
6969

70-
// Emits the shader body for the GeneratePositionIDs program. Each invocation
// loads the `seqlens` entry for its batch and stores a position id into
// `output` at `global_idx`; which formula is emitted depends on
// is_first_prompt_ / is_subsequent_prompt_. Always returns Status::OK().
Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const {
71-
const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform);
72-
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
73-
// Shared prologue: split global_idx into a batch index (quotient by
// sequence_length) and an in-sequence index (remainder), then load the
// seqlens value for that batch.
sh.MainFunctionBody() << " var pos_id: i32 = 0;\n"
74-
<< " let batch_idx = global_idx / uniforms.sequence_length;\n"
75-
<< " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n"
76-
<< " let seqlen = " << seqlens.GetByOffset("batch_idx") << ";\n";
77-
if (is_first_prompt_) {
78-
// First prompt: pos_id is sequence_idx while sequence_idx < seqlen + 1,
// otherwise 1.
sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n"
79-
<< " if (sequence_idx < total_seqlen) {\n"
80-
<< " pos_id = sequence_idx;\n"
81-
<< " } else {\n"
82-
<< " pos_id = 1;\n"
83-
<< " }\n"
84-
<< " " << output.SetByOffset("global_idx", "pos_id") << "\n";
85-
} else if (is_subsequent_prompt_) {
86-
// Subsequent prompt: pos_id is past_seqlen + sequence_idx, where
// past_seqlen = (seqlen + 1) - sequence_length; falls back to 1 when that
// sum is not below seqlen + 1.
sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n"
87-
<< " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n"
88-
<< " if (past_seqlen + sequence_idx < total_seqlen) {\n"
89-
<< " pos_id = past_seqlen + sequence_idx;\n"
90-
<< " } else {\n"
91-
<< " pos_id = 1;\n"
92-
<< " }\n"
93-
<< " " << output.SetByOffset("global_idx", "pos_id") << "\n";
94-
} else {
95-
// Neither flag set: invocations with global_idx < batch_size store the
// loaded seqlen value unchanged.
sh.MainFunctionBody() << " if (global_idx < uniforms.batch_size) {\n"
96-
<< " " << output.SetByOffset("global_idx", "seqlen") << "\n"
97-
<< " }\n";
98-
}
99-
return Status::OK();
100-
}
101-
102-
// Configures and runs GeneratePositionIDsProgram: binds `seqlens` as the input
// and `output_tensor` as the output, passes batch_size_ and sequence_length_
// as uniforms, and returns the status from context.RunProgram.
Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) {
103-
GeneratePositionIDsProgram program(params.is_first_prompt_, params.is_subsequent_prompt_);
104-
// One element per (batch, sequence position); used below to size the dispatch.
auto output_size = params.batch_size_ * params.sequence_length_;
105-
// The two prompt flags select different shader text, so they are also passed
// as the cache hint (presumably to keep the variants apart in the program
// cache - confirm). The dispatch size is output_size / WORKGROUP_SIZE,
// rounded up.
program.CacheHint(params.is_first_prompt_, params.is_subsequent_prompt_)
106-
.AddInput({seqlens, ProgramTensorMetadataDependency::Rank})
107-
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank})
108-
.AddUniformVariables({{static_cast<uint32_t>(params.batch_size_)}, {static_cast<uint32_t>(params.sequence_length_)}})
109-
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
110-
return context.RunProgram(program);
111-
}
112-
11370
// Fused Q/K rotary embedding
11471
Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
11572
const WebgpuAttentionParameters& params,
@@ -120,10 +77,6 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
12077
const Tensor* sin_cache,
12178
Tensor* query_out,
12279
Tensor* key_out) {
123-
Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(),
124-
TensorShape({params.batch_size_, params.sequence_length_}));
125-
ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, params, seqlen_k, &pos_ids));
126-
12780
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
12881
const auto head_size = params.head_size_;
12982

@@ -171,7 +124,7 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
171124
.AddInputs({
172125
{query_in, ProgramTensorMetadataDependency::Rank},
173126
{key_in, ProgramTensorMetadataDependency::Rank},
174-
{&pos_ids, ProgramTensorMetadataDependency::Rank},
127+
{seqlen_k, ProgramTensorMetadataDependency::Rank},
175128
{cos_cache, ProgramTensorMetadataDependency::Rank},
176129
{sin_cache, ProgramTensorMetadataDependency::Rank},
177130
})
@@ -188,8 +141,7 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
188141
{gsl::make_span(k_global_dims)},
189142
{gsl::make_span(k_input_output_strides)},
190143
{q_domain_size},
191-
})
192-
.AddIndices(TensorShape{1, 1});
144+
});
193145

194146
return context.RunProgram(program);
195147
}

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,6 @@ namespace webgpu {
1414

1515
using namespace onnxruntime::webgpu;
1616

17-
// WebGPU program named "GeneratePositionIDs". The two constructor flags are
// stored as members and select which shader variant GenerateShaderCode emits.
class GeneratePositionIDsProgram final : public Program<GeneratePositionIDsProgram> {
18-
public:
19-
GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt) : Program{"GeneratePositionIDs"}, is_first_prompt_(is_first_prompt), is_subsequent_prompt_(is_subsequent_prompt) {}
20-
21-
// Builds the shader source for this program.
Status GenerateShaderCode(ShaderHelper& sh) const override;
22-
23-
// Uniforms exposed to the shader: batch_size and sequence_length, both Uint32.
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32});
24-
25-
private:
26-
// Set once by the constructor and never modified in this class.
bool is_first_prompt_;
27-
bool is_subsequent_prompt_;
28-
};
29-
3017
class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
3118
public:
3219
SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}

onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,12 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
5454
// Inputs
5555
const auto& q_input = shader.AddInput("q_input", ShaderUsage::UseUniform);
5656
const auto& k_input = shader.AddInput("k_input", ShaderUsage::UseUniform);
57-
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform);
57+
const auto& seqlens = shader.AddInput("seqlens", ShaderUsage::UseUniform);
5858
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
5959
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
6060
// Outputs
6161
const auto& q_output = shader.AddOutput("q_output", ShaderUsage::UseUniform);
6262
const auto& k_output = shader.AddOutput("k_output", ShaderUsage::UseUniform);
63-
// Indices helper
64-
const auto& dummy_indices = shader.AddIndices("dummy_indices", ShaderUsage::None);
6563

6664
const auto interleaved_str = interleaved_ ? "true" : "false";
6765

@@ -70,8 +68,13 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
7068
<< " let half_rotary_dim = uniforms.cos_cache_shape[1];\n"
7169
<< " let bsnh = global_idx / uniforms.q_global_stride % uniforms.q_global_shape;\n"
7270
<< " if (bsnh[3] < half_rotary_dim) {\n"
73-
<< " let pos_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", dummy_indices) << ";\n"
74-
<< " let position_id = u32(" << position_ids.GetByOffset("pos_ids_idx") << ") + select(0u, bsnh[1], pos_ids_idx == 0u);\n"
71+
<< " let batch_idx = bsnh[0];\n"
72+
<< " let sequence_idx = bsnh[1];\n"
73+
<< " let seqlen_i = " << seqlens.GetByOffset("batch_idx") << ";\n"
74+
<< " let seqlen = u32(seqlen_i);\n"
75+
<< " let total_seqlen = seqlen + 1u;\n"
76+
<< " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n"
77+
<< " let position_id = past_seqlen + sequence_idx;\n"
7578
<< " let cos_v = " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
7679
<< " let sin_v = " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
7780
<< " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"

0 commit comments

Comments
 (0)