diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index ea88f291e5597..5a9ba3d357f12 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -52,7 +52,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | GlobalLpPool| ai.onnx(7+) | l2Pool2d | Only supports 4-D input, 'p' value is 2 | | Greater | ai.onnx(7-8, 9-12, 13+) | greater | | | GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | | -| GroupQueryAttention | com.microsoft(1+) | add, cast, concat, constant, cumulativeSum, div, expand, lesser, matmul, reshape, scatterND, softmax, transpose, where | Only supports input total_sequence_length is constant and past_sequence_length of past kv equals to present_sequence_length of present kv. Does not support cos_cache and sin_cache inputs | +| GroupQueryAttention | com.microsoft(1+) | add, cast, concat, constant, cumulativeSum, div, expand, lesser, matmul, reshape, scatterND, softmax, transpose, where | Only supports input total_sequence_length is constant and past_sequence_length of past kv equals to present_sequence_length of present kv. | | GRU | ai.onnx(7-13, 14-21, 22+) | gru | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | HardSigmoid | ai.onnx(7+) | hardSigmoid | | | HardSwish | ai.onnx(14+) | hardSwish | | diff --git a/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h b/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h index a0251406fc36b..e4ad884edac84 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h +++ b/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h @@ -4,8 +4,254 @@ #pragma once +#include "core/providers/webnn/builders/helper.h" + namespace onnxruntime { namespace webnn { +/* + RotaryEmbedding Helper: Apply rotary positional embedding to input tensor. + This helper function implements rotary embedding that can be reused by GQA and RotaryEmbedding ops. 
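+
+  For each pair of values (x0, x1) taken from the rotated sub-region of a head (adjacent values
+  when 'interleaved' is true, values half_rotary_embedding_dim apart otherwise), the result is:
+    out0 = x0 * cos(theta) - x1 * sin(theta)
+    out1 = x1 * cos(theta) + x0 * sin(theta)
+  where cos(theta) and sin(theta) are read from cos_cache/sin_cache at the token's position.
+  WebNN has no dedicated rotary embedding op, so the graph below builds this computation from
+  split, concat, gather, mul and add primitives.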
+ + The decomposed graph is referenced from DML EP at: + onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp + + Input CosCache PositionIds SinCache + | | | | + | | +--------+-----------+ | + Split | | | | + | | Gather Gather + +-------+ | | | + | | | | + | Identity----------+ | | + | | | | | + | | | | | + | --Split-- | | | + | \ / | +-----------------+ | + | \ / | | | + | \ / Mul | + | \ / | | + | X | | + | / \ | | + | / \ | | + | Join | | + | | | | + | | +---------------------------------------------------------+ + | | | | + | Mul | + | | | + | +-----+ +------+ + | | | + | Add + | | + +-------------+ | + | | + Join +*/ +inline Status ApplyRotaryEmbedding( + ModelBuilder& model_builder, + const std::string& node_name, + emscripten::val input, // Shape: [batch_size, sequence_length, num_heads, head_size] + emscripten::val cos_cache, // Shape: [max_sequence_length, head_size / 2] + emscripten::val sin_cache, // Shape: [max_sequence_length, head_size / 2] + emscripten::val position_ids, // Shape: [batch_size, sequence_length] or [1] + int32_t input_data_type, + uint32_t batch_size, + uint32_t sequence_length, + uint32_t num_heads, + uint32_t head_size, + uint32_t rotary_embedding_dim, + bool interleaved, + bool has_position_ids, + bool position_ids_is_offset, + emscripten::val& output) { + emscripten::val wnn_builder = model_builder.GetBuilder(); + ORT_RETURN_IF_NOT(head_size >= rotary_embedding_dim, + "Rotary embedding dimension must be less than or equal to head_size"); + const uint32_t half_rotary_embedding_dim = rotary_embedding_dim / 2; + + // Split the input to perform the rotary embedding only on a subregion of the tensor if needed. + emscripten::val partial_input0 = input; + emscripten::val partial_input1 = emscripten::val::undefined(); + if (head_size > rotary_embedding_dim) { + const std::vector splits{rotary_embedding_dim, head_size - rotary_embedding_dim}; + emscripten::val split_input_options = emscripten::val::object(); + split_input_options.set("label", node_name + "_rotary_split_input"); + split_input_options.set("axis", 3); + emscripten::val split = wnn_builder.call( + "split", input, emscripten::val::array(splits), split_input_options); + partial_input0 = split[0]; + partial_input1 = split[1]; + } + + // Split the partial input0 data into 2 equal parts. + const std::vector new_partial_input0_shape = + interleaved ? std::vector({batch_size, sequence_length, num_heads, half_rotary_embedding_dim, 2}) + : std::vector({batch_size, sequence_length, num_heads, 2, half_rotary_embedding_dim}); + emscripten::val reshape_partial_input0_options = emscripten::val::object(); + reshape_partial_input0_options.set("label", node_name + "_rotary_reshape_partial_input0"); + partial_input0 = wnn_builder.call( + "reshape", partial_input0, emscripten::val::array(new_partial_input0_shape), reshape_partial_input0_options); + + // Split partial input0. + const int split_axis = interleaved ? 4 : 3; + emscripten::val split_partial_input0_options = emscripten::val::object(); + split_partial_input0_options.set("label", node_name + "_rotary_split_partial_input0"); + split_partial_input0_options.set("axis", split_axis); + emscripten::val split_partial_input0 = wnn_builder.call( + "split", partial_input0, 2, split_partial_input0_options); + + // Swap the two halves and join them together. 
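+  // For the non-interleaved layout this produces [x_second_half, x_first_half]; combined with the
+  // sine multiplication and the {-1, 1} sign constant created below, it yields the
+  // (-x1 * sin, x0 * sin) term of the rotation. The interleaved layout does the same swap per
+  // adjacent pair.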
+ emscripten::val concat_partial_input0_options = emscripten::val::object(); + concat_partial_input0_options.set("label", node_name + "_rotary_concat_partial_input0"); + emscripten::val concated_partial_input0 = wnn_builder.call( + "concat", split_partial_input0.call("reverse"), split_axis, concat_partial_input0_options); + + emscripten::val gather_position_ids = position_ids; + if (position_ids_is_offset) { + // Generate a sequence from 0 to sequence_length and add the offset to it. + const std::vector position_ids_range_shape = {1, sequence_length}; + std::string typed_array_name = "BigInt64Array"; + int position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; + const bool is_int64_supported = model_builder.IsInt64Supported(); + if (!is_int64_supported) { + typed_array_name = "Int32Array"; + position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } + emscripten::val position_ids_range_buffer = emscripten::val::global(typed_array_name.c_str()).new_(sequence_length); + for (uint32_t i = 0; i < sequence_length; i++) { + position_ids_range_buffer.set(i, is_int64_supported ? emscripten::val::global("BigInt")(i) : emscripten::val(i)); + } + emscripten::val position_ids_range_desc = emscripten::val::object(); + position_ids_range_desc.set("shape", emscripten::val::array(position_ids_range_shape)); + position_ids_range_desc.set("dimensions", emscripten::val::array(position_ids_range_shape)); + ORT_RETURN_IF_NOT(SetWebnnDataType(position_ids_range_desc, position_ids_data_type), + "WebNN backend does not support data type: ", position_ids_data_type); + emscripten::val position_ids_range = wnn_builder.call( + "constant", position_ids_range_desc, position_ids_range_buffer); + emscripten::val position_ids_add_range_options = emscripten::val::object(); + position_ids_add_range_options.set("label", node_name + "_rotary_position_ids_add_range"); + gather_position_ids = wnn_builder.call( + "add", position_ids, position_ids_range, position_ids_add_range_options); + } + + // Gather the cosine/sine values based on the position_ids (if it presents). + emscripten::val gather_cos = cos_cache; + emscripten::val gather_sin = sin_cache; + if (has_position_ids) { + emscripten::val gather_cos_options = emscripten::val::object(); + emscripten::val gather_sin_options = emscripten::val::object(); + gather_cos_options.set("label", node_name + "_rotary_gather_cos"); + gather_sin_options.set("label", node_name + "_rotary_gather_sin"); + gather_cos_options.set("axis", 0); + gather_sin_options.set("axis", 0); + gather_cos = wnn_builder.call("gather", gather_cos, gather_position_ids, gather_cos_options); + gather_sin = wnn_builder.call("gather", gather_sin, gather_position_ids, gather_sin_options); + } else { + // When position_ids is not provided, slice the cos/sin cache to get the first sequence_length rows. 
+ // cos_cache/sin_cache shape: [max_sequence_length, half_rotary_embedding_dim] + // After slice: [sequence_length, half_rotary_embedding_dim] + emscripten::val slice_cos_options = emscripten::val::object(); + emscripten::val slice_sin_options = emscripten::val::object(); + slice_cos_options.set("label", node_name + "_rotary_slice_cos"); + slice_sin_options.set("label", node_name + "_rotary_slice_sin"); + const std::vector slice_starts = {0, 0}; + const std::vector slice_sizes = {sequence_length, half_rotary_embedding_dim}; + gather_cos = wnn_builder.call("slice", gather_cos, + emscripten::val::array(slice_starts), + emscripten::val::array(slice_sizes), + slice_cos_options); + gather_sin = wnn_builder.call("slice", gather_sin, + emscripten::val::array(slice_starts), + emscripten::val::array(slice_sizes), + slice_sin_options); + } + + // Reshape and broadcast them to match the number of heads of the input data. + const std::vector reshaped_cos_sin_shape = + interleaved ? std::vector({batch_size, sequence_length, 1, half_rotary_embedding_dim, 1}) + : std::vector({batch_size, sequence_length, 1, 1, half_rotary_embedding_dim}); + emscripten::val reshape_gather_cos_options = emscripten::val::object(); + emscripten::val reshape_gather_sin_options = emscripten::val::object(); + reshape_gather_cos_options.set("label", node_name + "_rotary_reshape_gather_cos"); + reshape_gather_sin_options.set("label", node_name + "_rotary_reshape_gather_sin"); + gather_cos = wnn_builder.call( + "reshape", gather_cos, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_options); + gather_sin = wnn_builder.call( + "reshape", gather_sin, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_sin_options); + + // Multiply the non-rotated data with the cosine and the rotated data with the sine. + emscripten::val mul_cos_options = emscripten::val::object(); + mul_cos_options.set("label", node_name + "_rotary_mul_cos"); + emscripten::val mul_cos = wnn_builder.call( + "mul", partial_input0, gather_cos, mul_cos_options); + emscripten::val mul_sin_options = emscripten::val::object(); + mul_sin_options.set("label", node_name + "_rotary_mul_sin"); + emscripten::val mul_sin = wnn_builder.call( + "mul", concated_partial_input0, gather_sin, mul_sin_options); + + // Create a vector that contains the sign values {-1, 1}. + emscripten::val sign_buffer = emscripten::val::undefined(); + const std::vector sign_shape = interleaved ? 
std::vector({1, 1, 1, 2}) + : std::vector({1, 1, 2, 1}); + emscripten::val sign_constant_desc = emscripten::val::object(); + sign_constant_desc.set("shape", emscripten::val::array(sign_shape)); + sign_constant_desc.set("dimensions", emscripten::val::array(sign_shape)); + ORT_RETURN_IF_NOT(SetWebnnDataType(sign_constant_desc, input_data_type), + "WebNN backend does not support data type: ", input_data_type); + if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + sign_buffer = emscripten::val::global("Float32Array").new_(2); + sign_buffer.set(0, -1.0f); + sign_buffer.set(1, 1.0f); + } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + if (model_builder.IsFloat16ArrayAvailable()) { + sign_buffer = emscripten::val::global("Float16Array").new_(2); + sign_buffer.set(0, -1.0f); + sign_buffer.set(1, 1.0f); + } else { + sign_buffer = emscripten::val::global("Uint16Array").new_(2); + sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); + sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type for rotary embedding: ", + input_data_type); + } + emscripten::val sign_constant = wnn_builder.call("constant", sign_constant_desc, sign_buffer); + + // Multiply the broadcasted sign values with the rotated input. + emscripten::val mul_sign_options = emscripten::val::object(); + mul_sign_options.set("label", node_name + "_rotary_mul_sign"); + mul_sin = wnn_builder.call("mul", mul_sin, sign_constant, mul_sign_options); + + // Reshape mul_cos and mul_sin to (batch_size, sequence_length, num_heads, rotary_embedding_dim). + const std::vector reshaped_mul_cos_sin_shape = + {batch_size, sequence_length, num_heads, rotary_embedding_dim}; + emscripten::val reshape_mul_cos_sin_options = emscripten::val::object(); + reshape_mul_cos_sin_options.set("label", node_name + "_rotary_reshape_mul_cos_sin"); + mul_cos = wnn_builder.call( + "reshape", mul_cos, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options); + mul_sin = wnn_builder.call( + "reshape", mul_sin, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options); + + // Add the multiplied cos and sin values together. + emscripten::val add_mul_cos_sin_options = emscripten::val::object(); + add_mul_cos_sin_options.set("label", node_name + "_rotary_add_mul_cos_sin"); + output = wnn_builder.call( + "add", mul_cos, mul_sin, add_mul_cos_sin_options); + + // Join the added values with the rest of the input. 
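+  // Only needed when rotary_embedding_dim < head_size: partial_input1 holds the tail of each
+  // head that is passed through unrotated.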
+ if (head_size != rotary_embedding_dim) { + emscripten::val concat_back_input_options = emscripten::val::object(); + concat_back_input_options.set("label", node_name + "_rotary_concat_back_input"); + emscripten::val concat_inputs = emscripten::val::array(); + concat_inputs.call("push", output); + concat_inputs.call("push", partial_input1); + output = wnn_builder.call("concat", concat_inputs, 3, concat_back_input_options); + } + + return Status::OK(); +} + /* ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index a29fbdb91e79f..5bab962b238c0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -65,12 +65,18 @@ std::vector repeat_sequence(int32_t sequence_length, int32_t kv_num_hea Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length N is number of attention heads, kv_N is number of attention heads for kv, H is head size G is group size, and G=N/kv_N, W=N*H, h=Sqrt(H). - GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length - Notes: cos_cache, sin_cache inputs are not supported. If the data type of the inputs (qkv and past kv) is float16, - we cast them to float32 to ensure data precision. + GQA inputs: query, key(optional), value(optional), past_key(optional), past_value(optional), + seqlens_k, total_sequence_length, cos_cache(optional), sin_cache(optional), position_ids(optional) + Notes: + - key, value, past_key, past_value can be empty (optional inputs). + - When key/value are empty, query contains packed QKV. + - When past_key/past_value are empty, this is the first token (prefill mode). + - When do_rotary is true, cos_cache and sin_cache must be provided. 
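+        - Example: with num_heads = 32, kv_num_heads = 8 and head_size = 128, a packed query has last
+          dimension (32 + 2 * 8) * 128 = 6144, and head_size is recovered as 6144 / (32 + 2 * 8) = 128.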
query key value | | | + (RotaryEmb) (RotaryEmb) | + | | | Reshape Reshape Reshape (B,S,H,N) seqlens_k | | | / | | | past_value | (scatter_indices*) | @@ -95,27 +101,107 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + const int32_t local_window_size = helper.Get("local_window_size", -1); + const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0); + const uint32_t num_heads = helper.Get("num_heads", 0); + const bool do_rotary = static_cast(helper.Get("do_rotary", 0)); + const bool rotary_interleaved = static_cast(helper.Get("rotary_interleaved", 0)); + + // Check if optional inputs exist + const bool has_key = TensorExists(input_defs, 1); + const bool has_value = TensorExists(input_defs, 2); + const bool has_past_key = TensorExists(input_defs, 3); + const bool has_past_value = TensorExists(input_defs, 4); + const bool has_cos_cache = TensorExists(input_defs, 7); + const bool has_sin_cache = TensorExists(input_defs, 8); + const bool has_position_ids = TensorExists(input_defs, 9); + emscripten::val query_input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val key_input = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val value_input = model_builder.GetOperand(input_defs[2]->Name()); - emscripten::val past_key_input = model_builder.GetOperand(input_defs[3]->Name()); - emscripten::val past_value_input = model_builder.GetOperand(input_defs[4]->Name()); + emscripten::val key_input = has_key ? model_builder.GetOperand(input_defs[1]->Name()) : emscripten::val::undefined(); + emscripten::val value_input = has_value ? model_builder.GetOperand(input_defs[2]->Name()) : emscripten::val::undefined(); + emscripten::val past_key_input = has_past_key ? model_builder.GetOperand(input_defs[3]->Name()) : emscripten::val::undefined(); + emscripten::val past_value_input = has_past_value ? model_builder.GetOperand(input_defs[4]->Name()) : emscripten::val::undefined(); emscripten::val seqlens_k_input = model_builder.GetOperand(input_defs[5]->Name()); + emscripten::val cos_cache = has_cos_cache ? model_builder.GetOperand(input_defs[7]->Name()) : emscripten::val::undefined(); + emscripten::val sin_cache = has_sin_cache ? 
model_builder.GetOperand(input_defs[8]->Name()) : emscripten::val::undefined(); - std::vector input_q_shape, input_past_k_shape; + std::vector input_q_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_q_shape, logger), "Cannot get query shape"); - ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape"); - NodeAttrHelper helper(node); - const int32_t local_window_size = helper.Get("local_window_size", -1); - const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0); - const uint32_t num_heads = helper.Get("num_heads", 0); + // Calculate hidden_size and head_size based on whether key/value are provided + uint32_t qkv_hidden_size; + uint32_t head_size; + if (has_key) { + // query shape is (batch_size, sequence_length, num_heads * head_size) + qkv_hidden_size = SafeInt(input_q_shape[2]); + head_size = SafeInt(qkv_hidden_size / num_heads); + } else { + // query contains packed QKV: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + // hidden_size = num_heads * head_size, so we derive: head_size = d / (num_heads + 2 * kv_num_heads) + uint32_t d = SafeInt(input_q_shape[2]); + head_size = d / (num_heads + 2 * kv_num_heads); + qkv_hidden_size = num_heads * head_size; + } + + // Get past_sequence_length from past_key if available, otherwise it's 0 (first token) + uint32_t past_sequence_length = 0; + if (has_past_key) { + std::vector input_past_k_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape"); + past_sequence_length = SafeInt(input_past_k_shape[2]); + } const uint32_t batch_size = SafeInt(input_q_shape[0]); const uint32_t qkv_sequence_length = SafeInt(input_q_shape[1]); - const uint32_t qkv_hidden_size = SafeInt(input_q_shape[2]); - const uint32_t head_size = SafeInt(qkv_hidden_size / num_heads); - const uint32_t past_sequence_length = SafeInt(input_past_k_shape[2]); + + emscripten::val position_ids = emscripten::val::undefined(); + bool use_position_ids_as_offset = false; + if (has_position_ids) { + position_ids = model_builder.GetOperand(input_defs[9]->Name()); + } else { + // If position_ids is not provided, we need to derive it from the context. + // We distinguish prefill vs decode by qkv_sequence_length (not has_past_key), because + // with pre-allocated KV cache (freeDimensionOverrides), has_past_key is always true. + // + // - Prefill (qkv_sequence_length > 1): positions start from 0 + // - Decode (qkv_sequence_length == 1): position = seqlens_k (the actual sequence position) + // + // Note: We cannot use past_sequence_length from the static shape because it represents the + // pre-allocated cache size (total_sequence_length), not the actual number of valid tokens. + if (qkv_sequence_length == 1) { + // During decode, use seqlens_k as the position offset for rotary embedding. + // seqlens_k has shape [batch_size], but we need [batch_size, 1] to properly broadcast + // with position_ids_range which has shape [1, sequence_length] in ApplyRotaryEmbedding. + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_/GQA/seqlens_k_reshape_for_position"); + + emscripten::val reshaped_seqlens_k = model_builder.GetBuilder().call( + "reshape", seqlens_k_input, emscripten::val::array(std::vector({batch_size, 1})), reshape_options); + + // seqlens_k is INT32, but position_ids_range in ApplyRotaryEmbedding may be INT64 + // if int64 is supported. We need to cast to match the expected type. 
+ if (model_builder.IsInt64Supported()) { + emscripten::val cast_options = emscripten::val::object(); + cast_options.set("label", node.Name() + "_/GQA/seqlens_k_cast_to_int64"); + position_ids = model_builder.GetBuilder().call( + "cast", reshaped_seqlens_k, emscripten::val("int64"), cast_options); + } else { + position_ids = reshaped_seqlens_k; + } + } else { + // During prefill, use 0 as the offset (positions will be 0, 1, 2, ..., sequence_length-1) + if (model_builder.IsInt64Supported()) { + position_ids = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_INT64, static_cast(0), {1}); + } else { + position_ids = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_INT32, static_cast(0), {1}); + } + } + use_position_ids_as_offset = true; + } + const uint32_t group_size = SafeInt(num_heads / kv_num_heads); const float scale_value = helper.Get("scale", 1 / sqrt(static_cast(head_size))); @@ -130,27 +216,90 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b int32_t q_type = 0; ORT_RETURN_IF_NOT(GetType(*input_defs[0], q_type, logger), "Could not get input data type."); - // Check whether inputs' data type is fp16, if so, we should cast them to fp32 to ensure the calculation precision. - if (q_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/GQA/preprocess/cast/query_input"); - query_input = model_builder.GetBuilder().call("cast", query_input, emscripten::val("float32"), - common_options); - - common_options.set("label", node.Name() + "_/GQA/preprocess/cast/key_input"); - key_input = - model_builder.GetBuilder().call("cast", key_input, emscripten::val("float32"), common_options); - - common_options.set("label", node.Name() + "_/GQA/preprocess/cast/value_input"); - value_input = model_builder.GetBuilder().call("cast", value_input, emscripten::val("float32"), - common_options); - - common_options.set("label", node.Name() + "_/GQA/preprocess/cast/past_key_input"); - past_key_input = model_builder.GetBuilder().call("cast", past_key_input, - emscripten::val("float32"), common_options); + // Split packed QKV if key and value are not provided separately + if (!has_key) { + // query contains packed QKV: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + const uint32_t kv_hidden_size = kv_num_heads * head_size; + const std::vector splits{qkv_hidden_size, kv_hidden_size, kv_hidden_size}; + emscripten::val split_options = emscripten::val::object(); + split_options.set("label", node.Name() + "_/GQA/split_packed_qkv"); + split_options.set("axis", 2); + emscripten::val split_result = model_builder.GetBuilder().call( + "split", query_input, emscripten::val::array(splits), split_options); + query_input = split_result[0]; + key_input = split_result[1]; + value_input = split_result[2]; + } - common_options.set("label", node.Name() + "_/GQA/preprocess/cast/past_value_input"); - past_value_input = model_builder.GetBuilder().call("cast", past_value_input, - emscripten::val("float32"), common_options); + // Apply rotary embedding if do_rotary is true + if (do_rotary && has_cos_cache && has_sin_cache) { + // Determine rotary_embedding_dim from cos_cache shape + std::vector cos_cache_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[7], cos_cache_shape, logger), "Cannot get cos_cache shape"); + const uint32_t rotary_embedding_dim = static_cast(cos_cache_shape[1] * 2); + + // Reshape query to (batch_size, sequence_length, num_heads, head_size) for rotary 
embedding + const std::vector query_reshape_for_rotary = {batch_size, qkv_sequence_length, num_heads, head_size}; + common_options.set("label", node.Name() + "_/GQA/query/reshape_for_rotary"); + emscripten::val reshaped_query_for_rotary = model_builder.GetBuilder().call( + "reshape", query_input, emscripten::val::array(query_reshape_for_rotary), common_options); + + // Apply rotary embedding to query + emscripten::val rotary_query_output; + ORT_RETURN_IF_ERROR(ApplyRotaryEmbedding( + model_builder, + node.Name() + "_query", + reshaped_query_for_rotary, + cos_cache, + sin_cache, + position_ids, + q_type, + batch_size, + qkv_sequence_length, + num_heads, + head_size, + rotary_embedding_dim, + rotary_interleaved, + true, + use_position_ids_as_offset, // position_ids_is_offset + rotary_query_output)); + + // Reshape back to (batch_size, sequence_length, hidden_size) + common_options.set("label", node.Name() + "_/GQA/query/reshape_after_rotary"); + query_input = model_builder.GetBuilder().call( + "reshape", rotary_query_output, emscripten::val::array(std::vector({batch_size, qkv_sequence_length, qkv_hidden_size})), common_options); + + // Reshape key to (batch_size, sequence_length, kv_num_heads, head_size) for rotary embedding + const std::vector key_reshape_for_rotary = {batch_size, qkv_sequence_length, kv_num_heads, head_size}; + common_options.set("label", node.Name() + "_/GQA/key/reshape_for_rotary"); + emscripten::val reshaped_key_for_rotary = model_builder.GetBuilder().call( + "reshape", key_input, emscripten::val::array(key_reshape_for_rotary), common_options); + + // Apply rotary embedding to key + emscripten::val rotary_key_output; + ORT_RETURN_IF_ERROR(ApplyRotaryEmbedding( + model_builder, + node.Name() + "_key", + reshaped_key_for_rotary, + cos_cache, + sin_cache, + position_ids, + q_type, + batch_size, + qkv_sequence_length, + kv_num_heads, + head_size, + rotary_embedding_dim, + rotary_interleaved, + true, + use_position_ids_as_offset, // position_ids_is_offset + rotary_key_output)); + + // Reshape back to (batch_size, sequence_length, kv_hidden_size) + const uint32_t kv_hidden_size = kv_num_heads * head_size; + common_options.set("label", node.Name() + "_/GQA/key/reshape_after_rotary"); + key_input = model_builder.GetBuilder().call( + "reshape", rotary_key_output, emscripten::val::array(std::vector({batch_size, qkv_sequence_length, kv_hidden_size})), common_options); } // Reshape and transpose the input "query" @@ -230,25 +379,51 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val scatter_indices = model_builder.GetBuilder().call( "reshape", pre_scatter_indices, emscripten::val::array(scatter_indices_shape), common_options); - // scatterND for present_key and present_value - common_options.set("label", node.Name() + "_/GQA/present_key/ScatterND"); - emscripten::val present_key = model_builder.GetBuilder().call( - "scatterND", past_key_input, scatter_indices, key_for_scatter, common_options); - - common_options.set("label", node.Name() + "_/GQA/present_value/ScatterND"); - emscripten::val present_value = model_builder.GetBuilder().call( - "scatterND", past_value_input, scatter_indices, value_for_scatter, common_options); + // scatterND for present_key and present_value, or use key/value directly if no past + emscripten::val present_key; + emscripten::val present_value; + if (has_past_key && has_past_value) { + common_options.set("label", node.Name() + "_/GQA/present_key/ScatterND"); + present_key = model_builder.GetBuilder().call( + 
"scatterND", past_key_input, scatter_indices, key_for_scatter, common_options); + + common_options.set("label", node.Name() + "_/GQA/present_value/ScatterND"); + present_value = model_builder.GetBuilder().call( + "scatterND", past_value_input, scatter_indices, value_for_scatter, common_options); + } else { + // No past_key/past_value, use key/value directly as present_key/present_value (first token case) + // Transpose key and value to BNSH format: (B, S, kv_N, H) -> (B, kv_N, S, H) + transpose_options.set("permutation", emscripten::val::array(std::vector({0, 2, 1, 3}))); + transpose_options.set("label", node.Name() + "_/GQA/key/transpose_to_bnsh"); + present_key = model_builder.GetBuilder().call("transpose", key_for_scatter, transpose_options); + + transpose_options.set("label", node.Name() + "_/GQA/value/transpose_to_bnsh"); + present_value = model_builder.GetBuilder().call("transpose", value_for_scatter, transpose_options); + } emscripten::val true_present_key; emscripten::val true_present_value; + // If no past_key, the sequence length is qkv_sequence_length, otherwise it is past_sequence_length. + // In prefill stage, sequence_length == total_sequence_length. + // In decoding stage, past_sequence_length == total_sequence_length. + // So we can use total_sequence_length (which is provided in input[6]) or derive it from the logic above. + uint32_t current_total_seq_len; + if (!has_past_key && !has_past_value) { + // Prefill: current_total_seq_len is simply qkv_sequence_length (which should == total_sequence_length) + current_total_seq_len = qkv_sequence_length; + } else { + // Decoding: current_total_seq_len is past_sequence_length + current_total_seq_len = past_sequence_length; + } + if (group_size != 1) { // Broadcast key and value for group query by reshape, expand and reshape. // present kv shape (B,kv_N,P,H) -> (B,kv_N,1,P,H) -> (B,kv_N,N/kv_N,P,H) -> (B,N,P,H) broadcasted kv shape - const std::vector group_broadcast_tensor_shape_1 = {batch_size, kv_num_heads, 1, past_sequence_length, + const std::vector group_broadcast_tensor_shape_1 = {batch_size, kv_num_heads, 1, current_total_seq_len, head_size}; const std::vector group_broadcast_tensor_shape_2 = {batch_size, kv_num_heads, group_size, - past_sequence_length, head_size}; - const std::vector group_broadcast_tensor_shape_3 = {batch_size, num_heads, past_sequence_length, + current_total_seq_len, head_size}; + const std::vector group_broadcast_tensor_shape_3 = {batch_size, num_heads, current_total_seq_len, head_size}; common_options.set("label", node.Name() + "_/GQA/true_present_key/reshape_1"); true_present_key = model_builder.GetBuilder().call( @@ -279,8 +454,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b transpose_options.set("label", node.Name() + "_/GQA/present_key/transpose"); true_present_key = model_builder.GetBuilder().call("transpose", true_present_key, transpose_options); - emscripten::val scale_constant = - model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, scale_value, {1}); + emscripten::val scale_constant = model_builder.CreateOrGetConstant(q_type, scale_value, {1}); /* Calculate attention_bias for masking softmax ones_array (shape=B,N,S,P) range_of_qkv_sequence_length_constant (0,1,2,...) 
(shape=S) @@ -298,7 +472,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val value_int_one_constant = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1}); - std::vector mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length}; + std::vector mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, current_total_seq_len}; common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand"); emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call( "expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options); @@ -310,7 +484,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val neq_left = model_builder.GetBuilder().call( "cumulativeSum", mask_shape_ones_shape_constant, gsl::narrow(3), cumsum_options); - std::vector reshape_pre_neq_right = {past_sequence_length, qkv_sequence_length}; + std::vector reshape_pre_neq_right = {current_total_seq_len, qkv_sequence_length}; std::vector pre_neq_right_data_range(qkv_sequence_length); std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1); @@ -373,15 +547,15 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b "logicalAnd", condition_1, condition_2, common_options); } - emscripten::val value_one_constant = - model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1}); + // For attended positions, use 0.0 (no change to attention scores) + // For masked positions, use a very large negative number (softmax → 0) + emscripten::val value_zero_constant_float = model_builder.CreateOrGetConstant(q_type, 0, {1}); // finfo_min: the minimum value of float32 - emscripten::val finfo_min_constant = model_builder.CreateOrGetConstant( - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, -3.4028234663852886e+38, {1}); + emscripten::val finfo_min_constant = model_builder.CreateOrGetConstant(q_type, -3.4028234663852886e+38, {1}); common_options.set("label", node.Name() + "_/GQA/attn_mask/where"); - emscripten::val attn_mask = model_builder.GetBuilder().call("where", condition, value_one_constant, + emscripten::val attn_mask = model_builder.GetBuilder().call("where", condition, value_zero_constant_float, finfo_min_constant, common_options); // Execute ScaledDotProductAttention @@ -389,20 +563,6 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b ScaledDotProductAttention(model_builder, node, logger, new_query, true_present_key, true_present_value, scale_constant, attn_mask, reshape_output_shape); - if (q_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/GQA/postprocess/cast/output"); - output = - model_builder.GetBuilder().call("cast", output, emscripten::val("float16"), common_options); - - common_options.set("label", node.Name() + "_/GQA/postprocess/cast/present_key"); - present_key = model_builder.GetBuilder().call("cast", present_key, emscripten::val("float16"), - common_options); - - common_options.set("label", node.Name() + "_/GQA/postprocess/cast/present_value"); - present_value = model_builder.GetBuilder().call("cast", present_value, emscripten::val("float16"), - common_options); - } - model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); model_builder.AddOperand(node.OutputDefs()[1]->Name(), std::move(present_key)); 
model_builder.AddOperand(node.OutputDefs()[2]->Name(), std::move(present_value)); @@ -419,6 +579,16 @@ bool GroupQueryAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vi const auto& op_type = node.OpType(); NodeAttrHelper helper(node); + const int64_t do_rotary = helper.Get("do_rotary", static_cast(0)); + + // When do_rotary is true, cos_cache and sin_cache must be provided + if (do_rotary) { + if (!TensorExists(input_defs, 7) || !TensorExists(input_defs, 8)) { + LOGS(logger, VERBOSE) << op_type << " requires cos_cache and sin_cache when do_rotary is true"; + return false; + } + } + const auto& total_sequence_length_name = input_defs[6]->Name(); const auto* total_sequence_length_initializer = graph_viewer.GetConstantInitializer(total_sequence_length_name); if (!total_sequence_length_initializer) { @@ -439,12 +609,17 @@ bool GroupQueryAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vi } const auto sequence_length = query_shape[1]; - std::vector past_key_shape; - if (!GetShape(*input_defs[3], past_key_shape, logger)) { - LOGS(logger, VERBOSE) << "Cannot get past_key shape."; - return false; + // Check if past_key exists to determine past_sequence_length + const bool has_past_key = TensorExists(input_defs, 3); + int64_t past_sequence_length = 0; + if (has_past_key) { + std::vector past_key_shape; + if (!GetShape(*input_defs[3], past_key_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get past_key shape."; + return false; + } + past_sequence_length = past_key_shape[2]; } - const auto past_sequence_length = past_key_shape[2]; // WebNN EP only supports past_sequence_length of past kv equals to present_sequence_length of present kv // According to CPU EP, present_sequence_length = max(past_sequence_length,total_sequence_length) @@ -455,7 +630,7 @@ bool GroupQueryAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vi return false; } } else { // For decoding stage, it requires past_sequence_length == total_sequence_length. 
- if (past_sequence_length != total_sequence_length.as()) { + if (has_past_key && past_sequence_length != total_sequence_length.as()) { LOGS(logger, VERBOSE) << op_type << " past_sequence_length != total_sequence_length."; return false; } @@ -475,68 +650,154 @@ bool GroupQueryAttentionOpBuilder::HasSupportedInputsImpl(const GraphViewer&, co const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); + NodeAttrHelper helper(node); - for (int i = 0; i < 9; i++) { - if (i < 7) { - if (!TensorExists(input_defs, i)) { - LOGS(logger, VERBOSE) << op_type << " requires input " << i; - return false; - } - } else { // cos_cache and sin_cache are not supported - if (TensorExists(input_defs, i)) { - LOGS(logger, VERBOSE) << op_type << " does not support input " << i; - return false; - } + const int64_t do_rotary = helper.Get("do_rotary", static_cast(0)); + + // Validate required inputs: query(0), seqlens_k(5), total_sequence_length(6) are always required + // key(1), value(2), past_key(3), past_value(4) are optional + // cos_cache(7), sin_cache(8) are required when do_rotary is true + // position_ids(9), attention_bias(10), head_sink(11) are optional + + // Check required inputs + if (!TensorExists(input_defs, 0)) { + LOGS(logger, VERBOSE) << op_type << " requires query input (index 0)"; + return false; + } + if (!TensorExists(input_defs, 5)) { + LOGS(logger, VERBOSE) << op_type << " requires seqlens_k input (index 5)"; + return false; + } + if (!TensorExists(input_defs, 6)) { + LOGS(logger, VERBOSE) << op_type << " requires total_sequence_length input (index 6)"; + return false; + } + + // Check key/value pair consistency + const bool has_key = TensorExists(input_defs, 1); + const bool has_value = TensorExists(input_defs, 2); + if (has_key != has_value) { + LOGS(logger, VERBOSE) << op_type << " key and value must both be present or both be absent"; + return false; + } + + // Check past_key/past_value pair consistency + const bool has_past_key = TensorExists(input_defs, 3); + const bool has_past_value = TensorExists(input_defs, 4); + if (has_past_key != has_past_value) { + LOGS(logger, VERBOSE) << op_type << " past_key and past_value must both be present or both be absent"; + return false; + } + + // Check do_rotary requirements + const bool has_cos_cache = TensorExists(input_defs, 7); + const bool has_sin_cache = TensorExists(input_defs, 8); + if (do_rotary) { + if (!has_cos_cache || !has_sin_cache) { + LOGS(logger, VERBOSE) << op_type << " requires cos_cache and sin_cache when do_rotary is true"; + return false; } } + // Get query type (required) int32_t q_type = 0; - int32_t k_type = 0; - int32_t v_type = 0; - int32_t past_k_type = 0; - int32_t past_v_type = 0; - int32_t seqlens_k_type = 0; - int32_t total_sequence_length_type = 0; - if (!GetType(*input_defs[0], q_type, logger) || !GetType(*input_defs[1], k_type, logger) || - !GetType(*input_defs[2], v_type, logger) || !GetType(*input_defs[3], past_k_type, logger) || - !GetType(*input_defs[4], past_v_type, logger) || !GetType(*input_defs[5], seqlens_k_type, logger) || - !GetType(*input_defs[6], total_sequence_length_type, logger)) { + if (!GetType(*input_defs[0], q_type, logger)) { return false; } - std::array input_types{q_type, k_type, v_type, past_k_type, past_v_type}; - if (!AreDataTypesSame(op_type, input_types, logger)) { + // Check optional key/value types + if (has_key) { + int32_t k_type = 0; + int32_t v_type = 0; + if (!GetType(*input_defs[1], k_type, logger) || 
!GetType(*input_defs[2], v_type, logger)) { + return false; + } + std::array qkv_types{q_type, k_type, v_type}; + if (!AreDataTypesSame(op_type, qkv_types, logger)) { + return false; + } + } + + // Check optional past_key/past_value types + if (has_past_key) { + int32_t past_k_type = 0; + int32_t past_v_type = 0; + if (!GetType(*input_defs[3], past_k_type, logger) || !GetType(*input_defs[4], past_v_type, logger)) { + return false; + } + std::array past_types{q_type, past_k_type, past_v_type}; + if (!AreDataTypesSame(op_type, past_types, logger)) { + return false; + } + } + + // Check seqlens_k and total_sequence_length types + int32_t seqlens_k_type = 0; + int32_t total_sequence_length_type = 0; + if (!GetType(*input_defs[5], seqlens_k_type, logger) || !GetType(*input_defs[6], total_sequence_length_type, logger)) { return false; } if (q_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && q_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + LOGS(logger, VERBOSE) << op_type << " query type must be float or float16"; return false; } - if (seqlens_k_type != ONNX_NAMESPACE::TensorProto_DataType_INT32 && + if (seqlens_k_type != ONNX_NAMESPACE::TensorProto_DataType_INT32 || total_sequence_length_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) { + LOGS(logger, VERBOSE) << op_type << " seqlens_k and total_sequence_length must be int32"; return false; } - std::vector input_q_shape, input_k_shape, input_v_shape, input_past_k_shape, input_past_v_shape; - if (!GetShape(*input_defs[0], input_q_shape, logger) || !GetShape(*input_defs[1], input_k_shape, logger) || - !GetShape(*input_defs[2], input_v_shape, logger) || !GetShape(*input_defs[3], input_past_k_shape, logger) || - !GetShape(*input_defs[4], input_past_v_shape, logger)) { + // Check cos_cache/sin_cache types when do_rotary is true + if (do_rotary && has_cos_cache && has_sin_cache) { + int32_t cos_cache_type = 0; + int32_t sin_cache_type = 0; + if (!GetType(*input_defs[7], cos_cache_type, logger) || !GetType(*input_defs[8], sin_cache_type, logger)) { + return false; + } + std::array cache_types{q_type, cos_cache_type, sin_cache_type}; + if (!AreDataTypesSame(op_type, cache_types, logger)) { + return false; + } + } + + // Check shapes + std::vector input_q_shape; + if (!GetShape(*input_defs[0], input_q_shape, logger)) { return false; } const auto q_rank = input_q_shape.size(); - const auto k_rank = input_k_shape.size(); - const auto v_rank = input_v_shape.size(); - const auto past_k_rank = input_past_k_shape.size(); - const auto past_v_rank = input_past_v_shape.size(); - if (q_rank != 3 || k_rank != 3 || v_rank != 3) { // The qkv shape should be BSW - LOGS(logger, VERBOSE) << op_type << " qkv shape is not BSW."; + if (q_rank != 3) { // The query shape should be BSW + LOGS(logger, VERBOSE) << op_type << " query shape is not BSW."; return false; } - if (past_k_rank != 4 || past_v_rank != 4) { // The past qkv shape should be BNSH - LOGS(logger, VERBOSE) << op_type << " past qkv shape is not BNSH."; - return false; + if (has_key) { + std::vector input_k_shape, input_v_shape; + if (!GetShape(*input_defs[1], input_k_shape, logger) || !GetShape(*input_defs[2], input_v_shape, logger)) { + return false; + } + const auto k_rank = input_k_shape.size(); + const auto v_rank = input_v_shape.size(); + if (k_rank != 3 || v_rank != 3) { // The kv shape should be BSW + LOGS(logger, VERBOSE) << op_type << " key/value shape is not BSW."; + return false; + } + } + + if (has_past_key) { + std::vector input_past_k_shape, input_past_v_shape; + if 
(!GetShape(*input_defs[3], input_past_k_shape, logger) || + !GetShape(*input_defs[4], input_past_v_shape, logger)) { + return false; + } + const auto past_k_rank = input_past_k_shape.size(); + const auto past_v_rank = input_past_v_shape.size(); + if (past_k_rank != 4 || past_v_rank != 4) { // The past qkv shape should be BNSH + LOGS(logger, VERBOSE) << op_type << " past qkv shape is not BNSH."; + return false; + } } return true; diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 4395c2854dcfb..c1f9cb57b2676 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -9,6 +9,7 @@ #include "core/providers/webnn/builders/op_builder_factory.h" #include "base_op_builder.h" +#include "attention_helper.h" // WebNN doesn't provide a dedicated op for RotaryEmbedding. Instead, we implement it by using a // combination of WebNN ops. The decomposed graph is referenced from DML EP at: @@ -92,7 +93,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build const bool position_ids_is_offset = has_position_ids && position_ids_shape.size() == 1; emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val position_ids; + emscripten::val position_ids = emscripten::val::undefined(); if (has_position_ids) { position_ids = model_builder.GetOperand(input_defs[position_ids_idx]->Name()); } @@ -138,7 +139,6 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build rotary_embedding_dim = head_size; } - const uint32_t half_rotary_embedding_dim = rotary_embedding_dim / 2; emscripten::val transpose_options = emscripten::val::object(); // Ensure the input is reshaped to: [batch_size, sequence_length, num_heads, head_size]. @@ -158,178 +158,25 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build "reshape", input, emscripten::val::array(new_shape), reshape_input_options); } - // Split the input to perform the rotary embedding only on a subregion of the tensor if needed. - // The split inputs will be joined back together at the end. - emscripten::val partial_input0 = input; - emscripten::val partial_input1 = emscripten::val::undefined(); - if (head_size != rotary_embedding_dim) { - const std::vector splits{rotary_embedding_dim, head_size - rotary_embedding_dim}; - emscripten::val split_input_options = emscripten::val::object(); - split_input_options.set("label", node_name + "_split_input"); - split_input_options.set("axis", 3); - emscripten::val split = wnn_builder.call( - "split", input, emscripten::val::array(splits), split_input_options); - partial_input0 = split[0]; - partial_input1 = split[1]; - } - - // Split the partial input0 data into 2 equal parts. - // Firstly reshape the partial input0. - const std::vector new_partial_input0_shape = - interleaved ? std::vector({batch_size, sequence_length, num_heads, half_rotary_embedding_dim, 2}) - : std::vector({batch_size, sequence_length, num_heads, 2, half_rotary_embedding_dim}); - emscripten::val reshape_partial_input0_options = emscripten::val::object(); - reshape_partial_input0_options.set("label", node_name + "_reshape_partial_input0"); - partial_input0 = wnn_builder.call( - "reshape", partial_input0, emscripten::val::array(new_partial_input0_shape), reshape_partial_input0_options); - // Split partial input0. 
- const int split_axis = interleaved ? 4 : 3; - emscripten::val split_partial_input0_options = emscripten::val::object(); - split_partial_input0_options.set("label", node_name + "_split_partial_input0"); - split_partial_input0_options.set("axis", split_axis); - emscripten::val split_partial_input0 = wnn_builder.call( - "split", partial_input0, 2, split_partial_input0_options); - - // Swap the two halves and join them together. - emscripten::val concat_partial_input0_options = emscripten::val::object(); - concat_partial_input0_options.set("label", node_name + "_concat_partial_input0"); - emscripten::val concated_partial_input0 = wnn_builder.call( - "concat", split_partial_input0.call("reverse"), split_axis, concat_partial_input0_options); - - if (position_ids_is_offset) { - // We generate a sequence from 0 to sequence_length and add the offset to it. - const std::vector position_ids_range_shape = {1, sequence_length}; - std::string typed_array_name = "BigInt64Array"; - int position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - const bool is_int64_supported = model_builder.IsInt64Supported(); - if (!is_int64_supported) { - // Int64 is not supported by current context, use int32 instead. - typed_array_name = "Int32Array"; - position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT32; - } - emscripten::val position_ids_range_buffer = emscripten::val::global(typed_array_name.c_str()).new_(sequence_length); - for (uint32_t i = 0; i < sequence_length; i++) { - position_ids_range_buffer.set(i, is_int64_supported ? emscripten::val::global("BigInt")(i) : emscripten::val(i)); - } - emscripten::val position_ids_range_desc = emscripten::val::object(); - position_ids_range_desc.set("shape", emscripten::val::array(position_ids_range_shape)); - position_ids_range_desc.set("dimensions", emscripten::val::array(position_ids_range_shape)); - ORT_RETURN_IF_NOT(SetWebnnDataType(position_ids_range_desc, position_ids_data_type), - "WebNN backend does not support data type: ", position_ids_data_type); - emscripten::val position_ids_range = wnn_builder.call( - "constant", position_ids_range_desc, position_ids_range_buffer); - // Add the offset to the sequence. - emscripten::val position_ids_add_range_options = emscripten::val::object(); - position_ids_add_range_options.set("label", node_name + "_position_ids_add_range"); - position_ids = wnn_builder.call( - "add", position_ids, position_ids_range, position_ids_add_range_options); - } - - // Gather the cosine/sine values based on the position_ids (if it presents). - emscripten::val gather_cos = cos_cache; - emscripten::val gather_sin = sin_cache; - if (has_position_ids) { - emscripten::val gather_cos_sin_options = emscripten::val::object(); - gather_cos_sin_options.set("label", node_name + "_gather_cos_sin"); - gather_cos_sin_options.set("axis", 0); - gather_cos = wnn_builder.call("gather", gather_cos, position_ids, gather_cos_sin_options); - gather_sin = wnn_builder.call("gather", gather_sin, position_ids, gather_cos_sin_options); - } - - // If it is full rotation, we need to slice the gathered cosine/sine - // to get the shape [batch_size, sequence_length, rotary_embedding_dim / 2]. 
- if (cos_cache_shape.back() != static_cast(half_rotary_embedding_dim)) { - emscripten::val slice_gather_cos_sin_options = emscripten::val::object(); - slice_gather_cos_sin_options.set("label", node_name + "_slice_gather_cos_sin"); - const std::vector starts{0, 0, 0}; - const std::vector sizes{batch_size, sequence_length, half_rotary_embedding_dim}; - gather_cos = wnn_builder.call("slice", gather_cos, emscripten::val::array(starts), - emscripten::val::array(sizes), slice_gather_cos_sin_options); - gather_sin = wnn_builder.call("slice", gather_sin, emscripten::val::array(starts), - emscripten::val::array(sizes), slice_gather_cos_sin_options); - } - - // Reshape and broadcast them to match the number of heads of the input data. - const std::vector reshaped_cos_sin_shape = - interleaved ? std::vector({batch_size, sequence_length, 1, half_rotary_embedding_dim, 1}) - : std::vector({batch_size, sequence_length, 1, 1, half_rotary_embedding_dim}); - emscripten::val reshape_gather_cos_sin_options = emscripten::val::object(); - reshape_gather_cos_sin_options.set("label", node_name + "_reshape_gather_cos_sin"); - gather_cos = wnn_builder.call( - "reshape", gather_cos, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_sin_options); - gather_sin = wnn_builder.call( - "reshape", gather_sin, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_sin_options); - - // Multiply the non-rotated data with the cosine and the rotated data with the sine. - emscripten::val mul_cos_options = emscripten::val::object(); - mul_cos_options.set("label", node_name + "_mul_cos"); - emscripten::val mul_cos = wnn_builder.call( - "mul", partial_input0, gather_cos, mul_cos_options); - emscripten::val mul_sin_options = emscripten::val::object(); - mul_sin_options.set("label", node_name + "_mul_sin"); - emscripten::val mul_sin = wnn_builder.call( - "mul", concated_partial_input0, gather_sin, mul_sin_options); - - // Create a vector that contains the sign values {-1, 1}. - emscripten::val sign_buffer = emscripten::val::undefined(); - const std::vector sign_shape = interleaved ? std::vector({1, 1, 1, 2}) - : std::vector({1, 1, 2, 1}); - emscripten::val sign_constant_desc = emscripten::val::object(); - sign_constant_desc.set("shape", emscripten::val::array(sign_shape)); - sign_constant_desc.set("dimensions", emscripten::val::array(sign_shape)); - ORT_RETURN_IF_NOT(SetWebnnDataType(sign_constant_desc, input_data_type), - "WebNN backend does not support data type: ", input_data_type); - if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - sign_buffer = emscripten::val::global("Float32Array").new_(2); - sign_buffer.set(0, -1.0f); - sign_buffer.set(1, 1.0f); - } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is available - use Float16Array. - sign_buffer = emscripten::val::global("Float16Array").new_(2); - sign_buffer.set(0, -1.0f); - sign_buffer.set(1, 1.0f); - } else { - // Float16Array is not available - use Uint16Array instead. 
- sign_buffer = emscripten::val::global("Uint16Array").new_(2); - sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); - sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type); - } - emscripten::val sign_constant = wnn_builder.call("constant", sign_constant_desc, sign_buffer); - - // Multiply the broadcasted sign values with the rotated input. - emscripten::val mul_sign_options = emscripten::val::object(); - mul_sign_options.set("label", node_name + "_mul_sign"); - mul_sin = wnn_builder.call("mul", mul_sin, sign_constant, mul_sign_options); - - // Reshape mul_cos and mul_sin to (batch_size, sequence_length, num_heads, rotary_embedding_dim). - const std::vector reshaped_mul_cos_sin_shape = - {batch_size, sequence_length, num_heads, rotary_embedding_dim}; - emscripten::val reshape_mul_cos_sin_options = emscripten::val::object(); - reshape_mul_cos_sin_options.set("label", node_name + "_reshape_mul_cos_sign"); - mul_cos = wnn_builder.call( - "reshape", mul_cos, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options); - mul_sin = wnn_builder.call( - "reshape", mul_sin, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options); - - // Add the multiplied cos and sin values together. - emscripten::val add_mul_cos_sin_options = emscripten::val::object(); - add_mul_cos_sin_options.set("label", node_name + "_add_mul_cos_sin"); - emscripten::val output = wnn_builder.call( - "add", mul_cos, mul_sin, add_mul_cos_sin_options); - - // Join the added values with the rest of the input. - if (head_size != rotary_embedding_dim) { - emscripten::val concat_back_input_options = emscripten::val::object(); - concat_back_input_options.set("label", node_name + "_concat_back_input"); - emscripten::val concat_inputs = emscripten::val::array(); - concat_inputs.call("push", output); - concat_inputs.call("push", partial_input1); - output = wnn_builder.call("concat", concat_inputs, 3, concat_back_input_options); - } + // Apply rotary embedding using the helper function + emscripten::val output; + ORT_RETURN_IF_ERROR(ApplyRotaryEmbedding( + model_builder, + node_name, + input, + cos_cache, + sin_cache, + position_ids, + input_data_type, + batch_size, + sequence_length, + num_heads, + head_size, + rotary_embedding_dim, + interleaved, + has_position_ids, + position_ids_is_offset, + output)); if (input_is_4d) { // The output is in 4D shape, we need to transpose it back to the original shape.
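For reference, a minimal plain-C++ sketch of the per-head computation that the decomposed WebNN graph in ApplyRotaryEmbedding expresses. It rotates one head vector of one token, given the cos/sin cache rows for that token's position; the function name and signature are illustrative only and are not part of this patch.

#include <cstddef>
#include <vector>

// Rotary embedding for a single head vector `x` (length head_size) of a single token.
// Only the first rotary_embedding_dim values are rotated; the rest pass through unchanged.
// cos_row / sin_row are the cache rows for this token's position (length rotary_embedding_dim / 2).
void RotateOneHead(std::vector<float>& x,
                   const std::vector<float>& cos_row,
                   const std::vector<float>& sin_row,
                   size_t rotary_embedding_dim,
                   bool interleaved) {
  const size_t half = rotary_embedding_dim / 2;
  std::vector<float> out(x);  // values beyond rotary_embedding_dim are kept as-is
  for (size_t i = 0; i < half; ++i) {
    // Pick the pair (x0, x1): adjacent values when interleaved, values `half` apart otherwise.
    const size_t i0 = interleaved ? 2 * i : i;
    const size_t i1 = interleaved ? 2 * i + 1 : i + half;
    const float x0 = x[i0];
    const float x1 = x[i1];
    out[i0] = x0 * cos_row[i] - x1 * sin_row[i];
    out[i1] = x1 * cos_row[i] + x0 * sin_row[i];
  }
  x = out;
}

In the GQA path above, this rotation is applied to every (batch, token, head) slice of query and key before the KV-cache scatter; during decode the token's position offset is taken from seqlens_k, and during prefill positions run from 0 to sequence_length - 1.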