Skip to content

Commit 83d11b5

Browse files
authored
[WebNN] Support more features for GQA (#27234)
Add support for GroupQueryAttention with: - do_rotary=true (cos_cache/sin_cache inputs) - Packed QKV (optional key/value inputs) - Optional past_key/past_value for prefill mode - Remove fp16->fp32 casting workaround Add ApplyRotaryEmbedding helper function. Fix decode stage by using qkv_sequence_length to distinguish prefill vs decode, and use runtime seqlens_k instead of static past_sequence_length for rotary position calculation.
1 parent a8ff3f3 commit 83d11b5

File tree

4 files changed

+642
-288
lines changed

4 files changed

+642
-288
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s
5252
| GlobalLpPool| ai.onnx(7+) | l2Pool2d | Only supports 4-D input, 'p' value is 2 |
5353
| Greater | ai.onnx(7-8, 9-12, 13+) | greater | |
5454
| GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | |
55-
| GroupQueryAttention | com.microsoft(1+) | add, cast, concat, constant, cumulativeSum, div, expand, lesser, matmul, reshape, scatterND, softmax, transpose, where | Only supports input total_sequence_length is constant and past_sequence_length of past kv equals to present_sequence_length of present kv. Does not support cos_cache and sin_cache inputs |
55+
| GroupQueryAttention | com.microsoft(1+) | add, cast, concat, constant, cumulativeSum, div, expand, lesser, matmul, reshape, scatterND, softmax, transpose, where | Only supports input total_sequence_length is constant and past_sequence_length of past kv equals to present_sequence_length of present kv. |
5656
| GRU | ai.onnx(7-13, 14-21, 22+) | gru | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
5757
| HardSigmoid | ai.onnx(7+) | hardSigmoid | |
5858
| HardSwish | ai.onnx(14+) | hardSwish | |

onnxruntime/core/providers/webnn/builders/impl/attention_helper.h

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,254 @@
44

55
#pragma once
66

7+
#include "core/providers/webnn/builders/helper.h"
8+
79
namespace onnxruntime {
810
namespace webnn {
11+
/*
12+
RotaryEmbedding Helper: Apply rotary positional embedding to input tensor.
13+
This helper function implements rotary embedding that can be reused by GQA and RotaryEmbedding ops.
14+
15+
The decomposed graph is referenced from DML EP at:
16+
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
17+
18+
Input CosCache PositionIds SinCache
19+
| | | |
20+
| | +--------+-----------+ |
21+
Split | | | |
22+
| | Gather Gather
23+
+-------+ | | |
24+
| | | |
25+
| Identity----------+ | |
26+
| | | | |
27+
| | | | |
28+
| --Split-- | | |
29+
| \ / | +-----------------+ |
30+
| \ / | | |
31+
| \ / Mul |
32+
| \ / | |
33+
| X | |
34+
| / \ | |
35+
| / \ | |
36+
| Join | |
37+
| | | |
38+
| | +---------------------------------------------------------+
39+
| | | |
40+
| Mul |
41+
| | |
42+
| +-----+ +------+
43+
| | |
44+
| Add
45+
| |
46+
+-------------+ |
47+
| |
48+
Join
49+
*/
50+
inline Status ApplyRotaryEmbedding(
51+
ModelBuilder& model_builder,
52+
const std::string& node_name,
53+
emscripten::val input, // Shape: [batch_size, sequence_length, num_heads, head_size]
54+
emscripten::val cos_cache, // Shape: [max_sequence_length, head_size / 2]
55+
emscripten::val sin_cache, // Shape: [max_sequence_length, head_size / 2]
56+
emscripten::val position_ids, // Shape: [batch_size, sequence_length] or [1]
57+
int32_t input_data_type,
58+
uint32_t batch_size,
59+
uint32_t sequence_length,
60+
uint32_t num_heads,
61+
uint32_t head_size,
62+
uint32_t rotary_embedding_dim,
63+
bool interleaved,
64+
bool has_position_ids,
65+
bool position_ids_is_offset,
66+
emscripten::val& output) {
67+
emscripten::val wnn_builder = model_builder.GetBuilder();
68+
ORT_RETURN_IF_NOT(head_size >= rotary_embedding_dim,
69+
"Rotary embedding dimension must be less than or equal to head_size");
70+
const uint32_t half_rotary_embedding_dim = rotary_embedding_dim / 2;
71+
72+
// Split the input to perform the rotary embedding only on a subregion of the tensor if needed.
73+
emscripten::val partial_input0 = input;
74+
emscripten::val partial_input1 = emscripten::val::undefined();
75+
if (head_size > rotary_embedding_dim) {
76+
const std::vector<uint32_t> splits{rotary_embedding_dim, head_size - rotary_embedding_dim};
77+
emscripten::val split_input_options = emscripten::val::object();
78+
split_input_options.set("label", node_name + "_rotary_split_input");
79+
split_input_options.set("axis", 3);
80+
emscripten::val split = wnn_builder.call<emscripten::val>(
81+
"split", input, emscripten::val::array(splits), split_input_options);
82+
partial_input0 = split[0];
83+
partial_input1 = split[1];
84+
}
85+
86+
// Split the partial input0 data into 2 equal parts.
87+
const std::vector<uint32_t> new_partial_input0_shape =
88+
interleaved ? std::vector<uint32_t>({batch_size, sequence_length, num_heads, half_rotary_embedding_dim, 2})
89+
: std::vector<uint32_t>({batch_size, sequence_length, num_heads, 2, half_rotary_embedding_dim});
90+
emscripten::val reshape_partial_input0_options = emscripten::val::object();
91+
reshape_partial_input0_options.set("label", node_name + "_rotary_reshape_partial_input0");
92+
partial_input0 = wnn_builder.call<emscripten::val>(
93+
"reshape", partial_input0, emscripten::val::array(new_partial_input0_shape), reshape_partial_input0_options);
94+
95+
// Split partial input0.
96+
const int split_axis = interleaved ? 4 : 3;
97+
emscripten::val split_partial_input0_options = emscripten::val::object();
98+
split_partial_input0_options.set("label", node_name + "_rotary_split_partial_input0");
99+
split_partial_input0_options.set("axis", split_axis);
100+
emscripten::val split_partial_input0 = wnn_builder.call<emscripten::val>(
101+
"split", partial_input0, 2, split_partial_input0_options);
102+
103+
// Swap the two halves and join them together.
104+
emscripten::val concat_partial_input0_options = emscripten::val::object();
105+
concat_partial_input0_options.set("label", node_name + "_rotary_concat_partial_input0");
106+
emscripten::val concated_partial_input0 = wnn_builder.call<emscripten::val>(
107+
"concat", split_partial_input0.call<emscripten::val>("reverse"), split_axis, concat_partial_input0_options);
108+
109+
emscripten::val gather_position_ids = position_ids;
110+
if (position_ids_is_offset) {
111+
// Generate a sequence from 0 to sequence_length and add the offset to it.
112+
const std::vector<uint32_t> position_ids_range_shape = {1, sequence_length};
113+
std::string typed_array_name = "BigInt64Array";
114+
int position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
115+
const bool is_int64_supported = model_builder.IsInt64Supported();
116+
if (!is_int64_supported) {
117+
typed_array_name = "Int32Array";
118+
position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT32;
119+
}
120+
emscripten::val position_ids_range_buffer = emscripten::val::global(typed_array_name.c_str()).new_(sequence_length);
121+
for (uint32_t i = 0; i < sequence_length; i++) {
122+
position_ids_range_buffer.set(i, is_int64_supported ? emscripten::val::global("BigInt")(i) : emscripten::val(i));
123+
}
124+
emscripten::val position_ids_range_desc = emscripten::val::object();
125+
position_ids_range_desc.set("shape", emscripten::val::array(position_ids_range_shape));
126+
position_ids_range_desc.set("dimensions", emscripten::val::array(position_ids_range_shape));
127+
ORT_RETURN_IF_NOT(SetWebnnDataType(position_ids_range_desc, position_ids_data_type),
128+
"WebNN backend does not support data type: ", position_ids_data_type);
129+
emscripten::val position_ids_range = wnn_builder.call<emscripten::val>(
130+
"constant", position_ids_range_desc, position_ids_range_buffer);
131+
emscripten::val position_ids_add_range_options = emscripten::val::object();
132+
position_ids_add_range_options.set("label", node_name + "_rotary_position_ids_add_range");
133+
gather_position_ids = wnn_builder.call<emscripten::val>(
134+
"add", position_ids, position_ids_range, position_ids_add_range_options);
135+
}
136+
137+
// Gather the cosine/sine values based on the position_ids (if it presents).
138+
emscripten::val gather_cos = cos_cache;
139+
emscripten::val gather_sin = sin_cache;
140+
if (has_position_ids) {
141+
emscripten::val gather_cos_options = emscripten::val::object();
142+
emscripten::val gather_sin_options = emscripten::val::object();
143+
gather_cos_options.set("label", node_name + "_rotary_gather_cos");
144+
gather_sin_options.set("label", node_name + "_rotary_gather_sin");
145+
gather_cos_options.set("axis", 0);
146+
gather_sin_options.set("axis", 0);
147+
gather_cos = wnn_builder.call<emscripten::val>("gather", gather_cos, gather_position_ids, gather_cos_options);
148+
gather_sin = wnn_builder.call<emscripten::val>("gather", gather_sin, gather_position_ids, gather_sin_options);
149+
} else {
150+
// When position_ids is not provided, slice the cos/sin cache to get the first sequence_length rows.
151+
// cos_cache/sin_cache shape: [max_sequence_length, half_rotary_embedding_dim]
152+
// After slice: [sequence_length, half_rotary_embedding_dim]
153+
emscripten::val slice_cos_options = emscripten::val::object();
154+
emscripten::val slice_sin_options = emscripten::val::object();
155+
slice_cos_options.set("label", node_name + "_rotary_slice_cos");
156+
slice_sin_options.set("label", node_name + "_rotary_slice_sin");
157+
const std::vector<uint32_t> slice_starts = {0, 0};
158+
const std::vector<uint32_t> slice_sizes = {sequence_length, half_rotary_embedding_dim};
159+
gather_cos = wnn_builder.call<emscripten::val>("slice", gather_cos,
160+
emscripten::val::array(slice_starts),
161+
emscripten::val::array(slice_sizes),
162+
slice_cos_options);
163+
gather_sin = wnn_builder.call<emscripten::val>("slice", gather_sin,
164+
emscripten::val::array(slice_starts),
165+
emscripten::val::array(slice_sizes),
166+
slice_sin_options);
167+
}
168+
169+
// Reshape and broadcast them to match the number of heads of the input data.
170+
const std::vector<uint32_t> reshaped_cos_sin_shape =
171+
interleaved ? std::vector<uint32_t>({batch_size, sequence_length, 1, half_rotary_embedding_dim, 1})
172+
: std::vector<uint32_t>({batch_size, sequence_length, 1, 1, half_rotary_embedding_dim});
173+
emscripten::val reshape_gather_cos_options = emscripten::val::object();
174+
emscripten::val reshape_gather_sin_options = emscripten::val::object();
175+
reshape_gather_cos_options.set("label", node_name + "_rotary_reshape_gather_cos");
176+
reshape_gather_sin_options.set("label", node_name + "_rotary_reshape_gather_sin");
177+
gather_cos = wnn_builder.call<emscripten::val>(
178+
"reshape", gather_cos, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_options);
179+
gather_sin = wnn_builder.call<emscripten::val>(
180+
"reshape", gather_sin, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_sin_options);
181+
182+
// Multiply the non-rotated data with the cosine and the rotated data with the sine.
183+
emscripten::val mul_cos_options = emscripten::val::object();
184+
mul_cos_options.set("label", node_name + "_rotary_mul_cos");
185+
emscripten::val mul_cos = wnn_builder.call<emscripten::val>(
186+
"mul", partial_input0, gather_cos, mul_cos_options);
187+
emscripten::val mul_sin_options = emscripten::val::object();
188+
mul_sin_options.set("label", node_name + "_rotary_mul_sin");
189+
emscripten::val mul_sin = wnn_builder.call<emscripten::val>(
190+
"mul", concated_partial_input0, gather_sin, mul_sin_options);
191+
192+
// Create a vector that contains the sign values {-1, 1}.
193+
emscripten::val sign_buffer = emscripten::val::undefined();
194+
const std::vector<uint32_t> sign_shape = interleaved ? std::vector<uint32_t>({1, 1, 1, 2})
195+
: std::vector<uint32_t>({1, 1, 2, 1});
196+
emscripten::val sign_constant_desc = emscripten::val::object();
197+
sign_constant_desc.set("shape", emscripten::val::array(sign_shape));
198+
sign_constant_desc.set("dimensions", emscripten::val::array(sign_shape));
199+
ORT_RETURN_IF_NOT(SetWebnnDataType(sign_constant_desc, input_data_type),
200+
"WebNN backend does not support data type: ", input_data_type);
201+
if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
202+
sign_buffer = emscripten::val::global("Float32Array").new_(2);
203+
sign_buffer.set(0, -1.0f);
204+
sign_buffer.set(1, 1.0f);
205+
} else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
206+
if (model_builder.IsFloat16ArrayAvailable()) {
207+
sign_buffer = emscripten::val::global("Float16Array").new_(2);
208+
sign_buffer.set(0, -1.0f);
209+
sign_buffer.set(1, 1.0f);
210+
} else {
211+
sign_buffer = emscripten::val::global("Uint16Array").new_(2);
212+
sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f));
213+
sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f));
214+
}
215+
} else {
216+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type for rotary embedding: ",
217+
input_data_type);
218+
}
219+
emscripten::val sign_constant = wnn_builder.call<emscripten::val>("constant", sign_constant_desc, sign_buffer);
220+
221+
// Multiply the broadcasted sign values with the rotated input.
222+
emscripten::val mul_sign_options = emscripten::val::object();
223+
mul_sign_options.set("label", node_name + "_rotary_mul_sign");
224+
mul_sin = wnn_builder.call<emscripten::val>("mul", mul_sin, sign_constant, mul_sign_options);
225+
226+
// Reshape mul_cos and mul_sin to (batch_size, sequence_length, num_heads, rotary_embedding_dim).
227+
const std::vector<uint32_t> reshaped_mul_cos_sin_shape =
228+
{batch_size, sequence_length, num_heads, rotary_embedding_dim};
229+
emscripten::val reshape_mul_cos_sin_options = emscripten::val::object();
230+
reshape_mul_cos_sin_options.set("label", node_name + "_rotary_reshape_mul_cos_sin");
231+
mul_cos = wnn_builder.call<emscripten::val>(
232+
"reshape", mul_cos, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options);
233+
mul_sin = wnn_builder.call<emscripten::val>(
234+
"reshape", mul_sin, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options);
235+
236+
// Add the multiplied cos and sin values together.
237+
emscripten::val add_mul_cos_sin_options = emscripten::val::object();
238+
add_mul_cos_sin_options.set("label", node_name + "_rotary_add_mul_cos_sin");
239+
output = wnn_builder.call<emscripten::val>(
240+
"add", mul_cos, mul_sin, add_mul_cos_sin_options);
241+
242+
// Join the added values with the rest of the input.
243+
if (head_size != rotary_embedding_dim) {
244+
emscripten::val concat_back_input_options = emscripten::val::object();
245+
concat_back_input_options.set("label", node_name + "_rotary_concat_back_input");
246+
emscripten::val concat_inputs = emscripten::val::array();
247+
concat_inputs.call<void>("push", output);
248+
concat_inputs.call<void>("push", partial_input1);
249+
output = wnn_builder.call<emscripten::val>("concat", concat_inputs, 3, concat_back_input_options);
250+
}
251+
252+
return Status::OK();
253+
}
254+
9255
/*
10256
ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention
11257
inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)

0 commit comments

Comments
 (0)