Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions onnxruntime/core/optimizer/group_query_attention_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,48 @@
return rotary_node_1 == nullptr || rotary_node_2 == nullptr || q_node == nullptr || k_node == nullptr || v_node == nullptr;
}

static bool NodeArgExists(const NodeArg* node_arg) {
return node_arg != nullptr && node_arg->Exists();
}

static bool TryGetRotaryEmbeddingArgs(Node& rotary_node,
NodeArg*& cos_cache_arg,
NodeArg*& sin_cache_arg,
NodeArg*& position_ids_arg) {
if (rotary_node.OpType() != "RotaryEmbedding") {
return false;
}

auto& input_defs = rotary_node.MutableInputDefs();
if (rotary_node.Domain() == kMSDomain) {
// com.microsoft.RotaryEmbedding inputs:
// input, position_ids, cos_cache, sin_cache
if (input_defs.size() < 4 || !NodeArgExists(input_defs[2]) || !NodeArgExists(input_defs[3])) {
return false;
}
cos_cache_arg = input_defs[2];
sin_cache_arg = input_defs[3];
return true;
}

if (rotary_node.Domain() == kOnnxDomain) {
// ONNX RotaryEmbedding inputs:
// X, cos_cache, sin_cache, optional position_ids
// If position_ids is omitted, ONNX RotaryEmbedding uses 3D per-batch caches, which are
// incompatible with GroupQueryAttention's 2D rotary cache inputs.
if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) ||
!NodeArgExists(input_defs[3])) {
return false;
}
cos_cache_arg = input_defs[1];
sin_cache_arg = input_defs[2];

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TryGetRotaryEmbeddingArgs matches an ONNX RotaryEmbedding purely by op type, domain, and input presence, but never inspects the interleaved or rotary_embedding_dim attributes. GQA's do_rotary applies non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and is not set anywhere in this pass). So a node with interleaved=1, or partial rotary (rotary_embedding_dim > 0 with a narrower cos/sin cache), would be silently fused into a GQA that computes a different rotation — a silent numerical mismatch rather than a hard failure.

Consider rejecting the match when interleaved != 0 or rotary_embedding_dim != 0 (full rotary only), or propagate interleaved onto the fused GQA's rotary_interleaved attribute. The same gap exists for the com.microsoft.RotaryEmbedding branch above.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by propagating interleaved  to GQA’s rotary_interleaved and rejecting Q/K RotaryEmbedding mismatches.

position_ids_arg = input_defs[3];
return true;
}
Comment thread
xiaoyu-work marked this conversation as resolved.

return false;
}

static void FusePreGQANodes(Graph& graph, Node* q_node, Node* k_node, Node* v_node, Node* rotary_node_1, Node* rotary_node_2, Node* new_node, NodeArg& new_node_output_arg) {
graph_utils::MoveAllNodeInputEdges(graph, *q_node, *new_node);

Expand Down Expand Up @@ -318,6 +360,9 @@

NodeArg* cos_cache_arg = nullptr;
NodeArg* sin_cache_arg = nullptr;
NodeArg* position_ids_arg = nullptr;
bool position_ids_arg_set = false;
bool position_ids_arg_mismatch = false;
NodeArg* past_key_values_key_arg = node.MutableInputDefs()[3];
NodeArg* past_key_values_value_arg = node.MutableInputDefs()[4];
NodeArg* seqlens_k = node.MutableInputDefs()[5];
Expand All @@ -334,7 +379,13 @@
for (auto pre_gqa_node = node.InputNodesBegin(); pre_gqa_node != node.InputNodesEnd(); ++pre_gqa_node) {
Node& rotary_or_v_node = *graph.GetNode(pre_gqa_node->Index());

if (rotary_or_v_node.OpType() == "RotaryEmbedding") {
NodeArg* rotary_cos_cache_arg = nullptr;
NodeArg* rotary_sin_cache_arg = nullptr;
NodeArg* rotary_position_ids_arg = nullptr;
if (TryGetRotaryEmbeddingArgs(rotary_or_v_node,
rotary_cos_cache_arg,
rotary_sin_cache_arg,
rotary_position_ids_arg)) {
if (!rotary_node_1) {
rotary_node_1 = &rotary_or_v_node;
} else {
Expand All @@ -358,18 +409,27 @@
}

if (cos_cache_arg == nullptr) {
cos_cache_arg = rotary_or_v_node.MutableInputDefs()[2];
cos_cache_arg = rotary_cos_cache_arg;
}

if (sin_cache_arg == nullptr) {
sin_cache_arg = rotary_or_v_node.MutableInputDefs()[3];
sin_cache_arg = rotary_sin_cache_arg;
}

if (!position_ids_arg_set) {
position_ids_arg = rotary_position_ids_arg;
position_ids_arg_set = true;
} else if (position_ids_arg != rotary_position_ids_arg) {
position_ids_arg_mismatch = true;
}
} else if (rotary_or_v_node.OpType() == "MatMulNBits" || rotary_or_v_node.OpType() == "MatMul") {
v_node = &rotary_or_v_node;
}
}

if (CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node)) {
if (position_ids_arg_mismatch ||
CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node) ||
cos_cache_arg == nullptr || sin_cache_arg == nullptr) {
// Some of the required pre-GQA nodes required for fusion were not retrieved,
// this can be expected if the model has extra nodes in between MatMuls and rotary embeddings.
continue;
Expand Down Expand Up @@ -493,7 +553,7 @@
std::string empty_name;
auto& empty_node_arg = graph.GetOrCreateNodeArg(empty_name, nullptr);

const std::array gqa_input_defs{
std::vector<NodeArg*> gqa_input_defs{

Check warning on line 556 in onnxruntime/core/optimizer/group_query_attention_fusion.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/group_query_attention_fusion.cc:556: Add #include <vector> for vector<> [build/include_what_you_use] [4]
&matmul_or_nbits_output,
&empty_node_arg,
&empty_node_arg,
Expand All @@ -503,10 +563,16 @@
total_seq_len,
cos_cache_arg,
sin_cache_arg};
if (position_ids_arg != nullptr) {
gqa_input_defs.push_back(position_ids_arg);
}

auto& gqa_input_args = node.MutableInputArgsCount();
gqa_input_args[7] = 1;
gqa_input_args[8] = 1;
if (position_ids_arg != nullptr) {
gqa_input_args[9] = 1;
}

// Switch GQA input defs from unfused into the fused form.
auto& gqa_node_input_defs = node.MutableInputDefs();
Expand Down
143 changes: 143 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,127 @@ static void TestGQAFusion(const std::basic_string<ORTCHAR_T>& file_path, int mat
ASSERT_TRUE(op_to_count["com.microsoft.GroupQueryAttention"] == 1);
}

static void BuildOnnxRotaryEmbeddingGQAFusionGraph(ModelTestBuilder& builder, bool include_position_ids) {
constexpr int64_t batch_size = 1;
constexpr int64_t sequence_length = 2;
constexpr int64_t input_hidden_size = 8;
constexpr int64_t num_heads = 2;
constexpr int64_t kv_num_heads = 1;
constexpr int64_t head_size = 16;
constexpr int64_t q_hidden_size = num_heads * head_size;
constexpr int64_t kv_hidden_size = kv_num_heads * head_size;
constexpr int64_t max_sequence_length = 8;
constexpr int64_t half_rotary_dim = head_size / 2;

auto make_weight = [&builder](int64_t rows, int64_t cols, float value) {
return builder.MakeInitializer<MLFloat16>(
{rows, cols}, std::vector<MLFloat16>(static_cast<size_t>(rows * cols), MLFloat16(value)));
};

NodeArg* input = builder.MakeInput<MLFloat16>({{batch_size, sequence_length, input_hidden_size}});
NodeArg* q_weight = make_weight(input_hidden_size, q_hidden_size, 0.5f);
NodeArg* k_weight = make_weight(input_hidden_size, kv_hidden_size, 0.25f);
NodeArg* v_weight = make_weight(input_hidden_size, kv_hidden_size, 0.125f);

NodeArg* q_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});
NodeArg* k_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});
NodeArg* v_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});
builder.AddNode("MatMul", {input, q_weight}, {q_matmul_out});
builder.AddNode("MatMul", {input, k_weight}, {k_matmul_out});
builder.AddNode("MatMul", {input, v_weight}, {v_matmul_out});

const std::vector<int64_t> cache_shape = include_position_ids
? std::vector<int64_t>{max_sequence_length, half_rotary_dim}
: std::vector<int64_t>{batch_size, sequence_length, half_rotary_dim};
NodeArg* cos_cache = builder.MakeInput<MLFloat16>(cache_shape);
NodeArg* sin_cache = builder.MakeInput<MLFloat16>(cache_shape);
NodeArg* position_ids = include_position_ids
? builder.MakeInput<int64_t>({{batch_size, sequence_length}})
: nullptr;

NodeArg* q_rotary_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});
NodeArg* k_rotary_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});

std::vector<NodeArg*> q_rotary_inputs{q_matmul_out, cos_cache, sin_cache};
std::vector<NodeArg*> k_rotary_inputs{k_matmul_out, cos_cache, sin_cache};
if (position_ids != nullptr) {
q_rotary_inputs.push_back(position_ids);
k_rotary_inputs.push_back(position_ids);
}

Node& q_rotary = builder.AddNode("RotaryEmbedding", q_rotary_inputs, {q_rotary_out}, kOnnxDomain);
q_rotary.AddAttribute("num_heads", num_heads);
Node& k_rotary = builder.AddNode("RotaryEmbedding", k_rotary_inputs, {k_rotary_out}, kOnnxDomain);
k_rotary.AddAttribute("num_heads", kv_num_heads);

NodeArg* past_key =
builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
NodeArg* past_value =
builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
NodeArg* seqlens_k = builder.MakeInput<int32_t>({{batch_size}});
NodeArg* total_sequence_length = builder.MakeInput<int32_t>({{1}});
NodeArg* gqa_output =
builder.MakeOutput<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});

Node& gqa = builder.AddNode("GroupQueryAttention",
{q_rotary_out, k_rotary_out, v_matmul_out, past_key, past_value,
seqlens_k, total_sequence_length},
{gqa_output},
kMSDomain);
gqa.AddAttribute("num_heads", num_heads);
gqa.AddAttribute("kv_num_heads", kv_num_heads);
}

static Status CheckOnnxRotaryEmbeddingGQAFused(Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 0);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 1);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

TEST_RETURN_IF_NOT(node.InputDefs().size() == 10);
TEST_RETURN_IF_NOT(node.InputDefs()[7] != nullptr && node.InputDefs()[7]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[8] != nullptr && node.InputDefs()[8]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[9] != nullptr && node.InputDefs()[9]->Exists());

const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
TEST_RETURN_IF_NOT(do_rotary_attr != attrs.end());
TEST_RETURN_IF_NOT(do_rotary_attr->second.i() == 1);
}

return Status::OK();
}

static Status CheckOnnxRotaryEmbeddingGQANotFused(Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 2);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 3);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

TEST_RETURN_IF_NOT(node.InputDefs().size() == 7);
const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
TEST_RETURN_IF_NOT(do_rotary_attr == attrs.end() || do_rotary_attr->second.i() == 0);
}

return Status::OK();
}

static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count,
int skip_ln_count, int cast_count, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
Expand Down Expand Up @@ -796,6 +917,28 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) {
TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_different_head_sizes.onnx", 1, 0, logger_.get());
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQAFused));
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingNoPositionIdsTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, false);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQANotFused));
}

TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get());
Expand Down
Loading