Skip to content

Commit 77d2cac

Browse files
committed
address the comment
1 parent 28db43a commit 77d2cac

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

onnxruntime/core/providers/webnn/builders/impl/attention_helper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,14 @@ inline Status ApplyRotaryEmbedding(
6565
bool position_ids_is_offset,
6666
emscripten::val& output) {
6767
emscripten::val wnn_builder = model_builder.GetBuilder();
68+
ORT_RETURN_IF_NOT(head_size >= rotary_embedding_dim,
69+
"Rotary embedding dimension must be less than or equal to head_size");
6870
const uint32_t half_rotary_embedding_dim = rotary_embedding_dim / 2;
6971

7072
// Split the input to perform the rotary embedding only on a subregion of the tensor if needed.
7173
emscripten::val partial_input0 = input;
7274
emscripten::val partial_input1 = emscripten::val::undefined();
73-
if (head_size != rotary_embedding_dim) {
75+
if (head_size > rotary_embedding_dim) {
7476
const std::vector<uint32_t> splits{rotary_embedding_dim, head_size - rotary_embedding_dim};
7577
emscripten::val split_input_options = emscripten::val::object();
7678
split_input_options.set("label", node_name + "_rotary_split_input");

0 commit comments

Comments
 (0)