Skip to content

Commit ae501ee

Browse files
[Native WebGPU EP] Add packedQKV and do_rotary attribute support to GroupQueryAttention operator (#23386)
### Description
Add packed QKV inputs and the `do_rotary` attribute to the WebGPU GroupQueryAttention (GQA) operator.

### Motivation and Context
Packed QKV inputs and the `do_rotary` attribute are required by certain models.
1 parent f22ee08 commit ae501ee

File tree

2 files changed

+165
-5
lines changed

2 files changed

+165
-5
lines changed

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "contrib_ops/webgpu/bert/attention_common.h"
66
#include "contrib_ops/webgpu/bert/group_query_attention.h"
77
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
8+
#include "contrib_ops/webgpu/bert/rotary_embedding.h"
89
#include "contrib_ops/webgpu/bert/flash_attention.h"
910

1011
#include "core/providers/webgpu/webgpu_supported_types.h"
@@ -30,6 +31,117 @@ ONNX_OPERATOR_KERNEL_EX(
3031
.InputMemoryType(OrtMemTypeCPUInput, 6),
3132
GroupQueryAttention);
3233

34+
// Emits the WGSL that scatters one element of the packed QKV input into the
// matching query / key / value output tensor. The routing key is the
// innermost (dim 2) index of the element: [0, hidden_size) goes to Q,
// [hidden_size, hidden_size + kv_hidden_size) to K, and the rest to V,
// with the hidden index rebased for K and V.
Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
  const auto& qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
  const auto& q_out = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
  const auto& k_out = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
  const auto& v_out = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);

  auto& body = sh.MainFunctionBody();
  // Recover this element's multi-dimensional indices and load its value.
  body << " let packed_qkv_indices = " << qkv.OffsetToIndices("global_idx") << ";\n";
  body << " let input_data = " << qkv.GetByOffset("global_idx") << ";\n";
  body << " let index = " << qkv.IndicesGet("packed_qkv_indices", "2") << ";\n";
  // Route on the hidden index; K and V get the index rebased to their own range.
  body << " if (index < uniforms.hidden_size) {\n";
  body << " " << q_out.SetByIndices("packed_qkv_indices", "input_data") << ";\n";
  body << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n";
  body << " var key_indices = packed_qkv_indices;\n";
  body << " " << k_out.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n";
  body << " " << k_out.SetByIndices("key_indices", "input_data") << ";\n";
  body << " } else {\n";
  body << " var val_indices = packed_qkv_indices;\n";
  body << " " << v_out.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n";
  body << " " << v_out.SetByIndices("val_indices", "input_data") << ";\n";
  body << " }";
  return Status::OK();
}
55+
56+
// Splits a packed QKV tensor of shape (B, S, hidden_size + 2 * kv_hidden_size)
// into separate query/key/value GPU tensors by running SplitPackedQKVProgram
// over every element of the packed input.
//
// @param context   WebGPU compute context used to dispatch the program.
// @param params    Attention parameters supplying hidden_size_/kv_hidden_size_.
// @param packedQKV Packed input tensor (read-only).
// @param query/key/val  Pre-allocated output tensors receiving the split data.
// @return Status of the program dispatch.
Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) {
  SplitPackedQKVProgram program;
  // One shader invocation per packed element.
  const auto input_size = packedQKV->Shape().Size();
  program
      .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank})
      .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}})
      .AddUniformVariables({
          {static_cast<uint32_t>(params.hidden_size_)},
          {static_cast<uint32_t>(params.kv_hidden_size_)},
      })
      // input_size is int64_t; narrow explicitly instead of relying on an
      // implicit (and silently truncating) conversion at the call boundary.
      .SetDispatchGroupSize(gsl::narrow_cast<uint32_t>((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE));
  return context.RunProgram(program);
}
69+
70+
// Emits the WGSL that computes rotary-embedding position ids. The shader
// shape depends on the prompt phase captured at construction time:
// first prompt, subsequent prompt, or token generation (the else branch,
// which writes one value per batch).
Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const {
  const auto& pos_out = sh.AddOutput("output", ShaderUsage::UseUniform);
  const auto& seqlens_in = sh.AddInput("seqlens", ShaderUsage::UseUniform);

  auto& body = sh.MainFunctionBody();
  // Common prologue: locate this invocation's batch/sequence slot and read
  // the per-batch sequence length.
  body << " var pos_id: i32 = 0;\n";
  body << " let batch_idx = global_idx / uniforms.sequence_length;\n";
  body << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n";
  body << " let seqlen = " << seqlens_in.GetByOffset("batch_idx") << ";\n";

  if (is_first_prompt_) {
    // First prompt: position id is the index within the sequence (1 past the end).
    body << " let total_seqlen = seqlen + 1;\n";
    body << " if (sequence_idx < total_seqlen) {\n";
    body << " pos_id = sequence_idx;\n";
    body << " } else {\n";
    body << " pos_id = 1;\n";
    body << " }\n";
    body << " " << pos_out.SetByOffset("global_idx", "pos_id") << "\n";
  } else if (is_subsequent_prompt_) {
    // Subsequent prompt: offset the in-sequence index by the past length.
    body << " let total_seqlen = seqlen + 1;\n";
    body << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n";
    body << " if (past_seqlen + sequence_idx < total_seqlen) {\n";
    body << " pos_id = past_seqlen + sequence_idx;\n";
    body << " } else {\n";
    body << " pos_id = 1;\n";
    body << " }\n";
    body << " " << pos_out.SetByOffset("global_idx", "pos_id") << "\n";
  } else {
    // Generation: write the per-batch seqlen, one entry per batch.
    body << " if (global_idx < uniforms.batch_size) {\n";
    body << " " << pos_out.SetByOffset("global_idx", "seqlen") << "\n";
    body << " }\n";
  }
  return Status::OK();
}
101+
102+
// Builds and dispatches the GeneratePositionIDs program, producing one
// position id per (batch, sequence) element in output_tensor. The prompt
// flags feed both the program behavior and its cache hint, since they change
// the generated shader.
Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) {
  const auto total_elements = params.batch_size_ * params.sequence_length_;
  GeneratePositionIDsProgram program(params.is_first_prompt_, params.is_subsequent_prompt_);
  program.CacheHint(params.is_first_prompt_, params.is_subsequent_prompt_);
  program.AddInput({seqlens, ProgramTensorMetadataDependency::Rank});
  program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank});
  program.AddUniformVariables({{static_cast<uint32_t>(params.batch_size_)},
                               {static_cast<uint32_t>(params.sequence_length_)}});
  program.SetDispatchGroupSize((total_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  return context.RunProgram(program);
}
112+
113+
// Applies rotary position embedding to `input` (query when is_query_input is
// true, key otherwise), writing the rotated tensor to `output`. The logical
// iteration space and its strides are passed to the shared
// RotaryEmbeddingProgram as uniforms.
Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) {
  const auto half_rotary_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
  const auto head_size = params.head_size_;
  const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_;

  // Logical dispatch shape: (batch, sequence, num_heads, head_size - half_rotary_dim).
  const TensorShape global_shape({params.batch_size_,
                                  params.sequence_length_,
                                  hidden_size / head_size,
                                  static_cast<int64_t>(head_size - half_rotary_dim)});

  // Flatten the logical shape into uint32 dims/strides for the uniforms.
  const auto rank = global_shape.NumDimensions();
  std::vector<uint32_t> global_dims(rank);
  std::vector<uint32_t> global_strides(rank);
  for (size_t dim = 0; dim < rank; ++dim) {
    global_dims[dim] = gsl::narrow_cast<uint32_t>(global_shape[dim]);
    global_strides[dim] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(dim + 1));
  }

  // Element strides of the physical (B, S, N, H) input/output layout.
  const std::vector<uint32_t> input_output_strides{
      gsl::narrow_cast<uint32_t>(input->Shape().SizeFromDimension(1)),
      gsl::narrow_cast<uint32_t>(hidden_size),
      gsl::narrow_cast<uint32_t>(head_size),
      1};
  const auto output_size = gsl::narrow_cast<uint32_t>(global_shape.Size());

  RotaryEmbeddingProgram program(params.rotary_interleaved_);
  program.CacheHint(params.rotary_interleaved_);
  program.AddInputs({{input, ProgramTensorMetadataDependency::Rank},
                     {pos_ids, ProgramTensorMetadataDependency::Rank},
                     {cos_cache, ProgramTensorMetadataDependency::Rank},
                     {sin_cache, ProgramTensorMetadataDependency::Rank}});
  program.AddOutput(output);
  program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  program.AddUniformVariables({{params.scale_},
                               {gsl::make_span(global_dims)},
                               {gsl::make_span(global_strides)},
                               {gsl::make_span(input_output_strides)}});
  program.AddIndices(TensorShape{1, 1});
  return context.RunProgram(program);
}
144+
33145
Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
34146
const Tensor* query = context.Input<Tensor>(0);
35147
const Tensor* key = context.Input<Tensor>(1);
@@ -41,7 +153,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
41153
const Tensor* cos_cache = context.Input<Tensor>(7);
42154
const Tensor* sin_cache = context.Input<Tensor>(8);
43155

44-
GroupQueryAttentionParameters params;
156+
GroupQueryAttentionParameters params = {};
45157
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
46158
key,
47159
value,
@@ -57,9 +169,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
57169
scale_,
58170
softcap_));
59171
WebgpuAttentionParameters parameters(params);
60-
if (parameters.is_packed_qkv_) {
61-
ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep.");
62-
}
63172
TensorShapeVector output_shape(3);
64173
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
65174
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
@@ -75,11 +184,39 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
75184
Tensor* present_value = context.Output(2, present_kv_shape);
76185
parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
77186

78-
if (CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) {
187+
if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) {
79188
return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value,
80189
present_value, parameters, context);
81190
}
82191

192+
Tensor qSplit;
193+
Tensor kSplit;
194+
Tensor vSplit;
195+
if (parameters.is_packed_qkv_) {
196+
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
197+
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
198+
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
199+
ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit));
200+
parameters.is_packed_qkv_ = false;
201+
query = &qSplit;
202+
key = &kSplit;
203+
value = &vSplit;
204+
}
205+
206+
Tensor qRotary;
207+
Tensor kRotary;
208+
if (do_rotary_) {
209+
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
210+
kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
211+
auto pos_ids_shape = TensorShape({parameters.batch_size_, parameters.sequence_length_});
212+
Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape);
213+
ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlen_k, &pos_ids));
214+
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true));
215+
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_query_input = */ false));
216+
query = &qRotary;
217+
key = &kRotary;
218+
}
219+
83220
TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
84221
parameters.sequence_length_, parameters.head_size_});
85222
TensorShape q_new_shape(q_new_dims);

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,29 @@ namespace webgpu {
1414

1515
using namespace onnxruntime::webgpu;
1616

17+
// Compute program that fills a tensor with rotary-embedding position ids.
// The generated shader differs by prompt phase, so both flags must also be
// part of the program cache key (callers pass them to CacheHint).
class GeneratePositionIDsProgram final : public Program<GeneratePositionIDsProgram> {
 public:
  GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt)
      : Program{"GeneratePositionIDs"},
        is_first_prompt_(is_first_prompt),
        is_subsequent_prompt_(is_subsequent_prompt) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32});

 private:
  bool is_first_prompt_;       // true while processing the first prompt
  bool is_subsequent_prompt_;  // true for follow-up prompts (not generation)
};
29+
30+
// Compute program that splits a packed QKV tensor into separate query, key,
// and value tensors; the hidden_size/kv_hidden_size uniforms define where the
// Q/K/V sections begin within the packed hidden dimension.
class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
 public:
  SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32},
                                          {"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
};
39+
1740
class GroupQueryAttention final : public WebGpuKernel {
1841
public:
1942
GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) {

0 commit comments

Comments
 (0)