Commit 5405515

[CPU] Add gate_up_swapped support for LLM MLP fusion
1 parent d818b9c commit 5405515

5 files changed (+112, -25 lines)
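For context (inferred from the test graphs added in this commit, not stated explicitly here): the fused subgraph is the usual gated MLP,

    y = W_down * ( act(W_gate * x) ⊙ (W_up * x) ),    act ∈ {Swish/SiLU, Gelu}

where ⊙ is element-wise multiplication. In the gate_up_combined variant, the gate and up projections come from a single MatMul whose output is divided in two by a VariadicSplit. The new gate_up_swapped flag records the case where the split's output[1] feeds the gate (activation) branch and output[0] feeds the up (Multiply) branch, so the fusion can handle such graphs as well.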

src/plugins/intel_cpu/src/nodes/llm_mlp.cpp

Lines changed: 7 additions & 1 deletion
@@ -381,7 +381,13 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
         OPENVINO_ASSERT(w_gate.stride_bytes(0) == w_up.stride_bytes(0));
         if (m_config.gate_up_combined) {
             N = w_gate.size(0) / 2;
-            gate_up.setup(w_gate.ptr_v(), w_up.ptr_v(N, 0), w_up.stride_bytes(0), N * 2, K, config);
+            if (m_config.gate_up_swapped) {
+                // When VariadicSplit output[1] connects to gate instead of up, swap the pointers
+                gate_up.setup(w_gate.ptr_v(N, 0), w_gate.ptr_v(), w_gate.stride_bytes(0), N * 2, K, config);
+            } else {
+                // Normal case: VariadicSplit output[1] connects to up
+                gate_up.setup(w_gate.ptr_v(), w_gate.ptr_v(N, 0), w_gate.stride_bytes(0), N * 2, K, config);
+            }
         } else {
             gate_up.setup(w_gate.ptr_v(), w_up.ptr_v(), w_up.stride_bytes(0), N * 2, K, config);
         }
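As a memory-layout illustration of why a pointer swap is sufficient here (a minimal sketch; GateUpView and make_gate_up_view are made-up names, not the executor's API): in combined mode the gate and up weights share one [2*N, K] buffer, so choosing which half plays the "gate" role is purely a choice of base pointer.

#include <cstddef>
#include <cstdint>

// Minimal sketch of the combined gate/up weight layout: one row-major [2*N, K]
// buffer where rows [0, N) and [N, 2*N) hold the two projections. Which half is
// "gate" depends on how the VariadicSplit outputs were wired in the source graph;
// swapping the base pointers flips the roles.
struct GateUpView {
    const uint8_t* gate;  // first row of the gate half
    const uint8_t* up;    // first row of the up half
    size_t stride_bytes;  // bytes per weight row (identical for both halves)
};

GateUpView make_gate_up_view(const uint8_t* combined, size_t stride_bytes, size_t N, bool swapped) {
    const uint8_t* lower = combined;                     // rows [0, N)
    const uint8_t* upper = combined + N * stride_bytes;  // rows [N, 2*N)
    // Normal wiring: gate occupies the lower half, up the upper half.
    // Swapped wiring (VariadicSplit output[1] -> gate): the roles flip.
    return swapped ? GateUpView{upper, lower, stride_bytes}
                   : GateUpView{lower, upper, stride_bytes};
}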

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ bool LLMMLPNode::visit_attributes(ov::AttributeVisitor& visitor) {
     visitor.on_attribute("hidden_size", m_config.hidden_size);
     visitor.on_attribute("up_size", m_config.up_size);
     visitor.on_attribute("gate_up_combined", m_config.gate_up_combined);
+    visitor.on_attribute("gate_up_swapped", m_config.gate_up_swapped);
     visitor.finish_structure();
     return true;
 }

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ class LLMMLPNode : public ov::op::Op {
        int hidden_size;
        int up_size;
        bool gate_up_combined;
+       bool gate_up_swapped;  // true when VariadicSplit output[1] connects to gate instead of up
    };

    // args:

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp

Lines changed: 18 additions & 0 deletions
@@ -122,6 +122,23 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
     matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
         auto root = m.get_match_root();
+
+        // Check VariadicSplit output connections in combined mode
+        bool gate_up_swapped = false;
+        if (pattern_map.count(gate_up_proj_split)) {
+            auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
+            auto input0 = mlp_gated_up_node->input_value(0);
+            auto input1 = mlp_gated_up_node->input_value(1);
+
+            // Check if VariadicSplit output[0] connects to Multiply (swapped case)
+            // Since pattern matching succeeded, we know one of the outputs connects to Multiply
+            if ((input0.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input0.get_index() == 0) ||
+                (input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 0)) {
+                gate_up_swapped = true;
+            }
+            // Otherwise, it's the normal case where output[1] connects to Multiply
+        }
+
         auto src = pattern_map.at(input);
         if (!src.get_element_type().is_real()) {
             // FakeQuantize, should skip fusion

@@ -224,6 +241,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         cfg.hidden_size = down_size;
         cfg.up_size = up_size;
         cfg.gate_up_combined = is_gate_up_combined;
+        cfg.gate_up_swapped = gate_up_swapped;

         if (pattern_map.count(mlp_silu_gate) > 0) {
             cfg.act = LLMMLPNode::ACT_FN::SILU;
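Distilled from the callback change above (a standalone sketch, not the pass's actual helper; is_gate_up_swapped is a made-up name): the matched Multiply has one input fed directly by the VariadicSplit (the up branch) and one fed through the activation (the gate branch), so it is enough to check whether the direct edge uses the split's output index 0.

#include <memory>
#include "openvino/core/node.hpp"

// Returns true when the Multiply input that is fed directly by the 2-way split
// comes from split->output(0), i.e. the gate/up roles are swapped relative to
// the usual output[0]=gate, output[1]=up wiring. Illustrative helper only.
bool is_gate_up_swapped(const std::shared_ptr<ov::Node>& multiply,
                        const std::shared_ptr<ov::Node>& split) {
    for (size_t i = 0; i < multiply->get_input_size(); ++i) {
        const auto& in = multiply->input_value(i);
        if (in.get_node() == split.get() && in.get_index() == 0)
            return true;  // direct split->Multiply edge uses output[0] -> swapped
    }
    return false;  // direct edge uses output[1] (or none) -> normal wiring
}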

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp

Lines changed: 85 additions & 24 deletions
@@ -6,13 +6,15 @@
 #include <vector>

 #include "common_test_utils/ov_tensor_utils.hpp"
-#include "openvino/runtime/exec_model_info.hpp"
-#include "shared_test_classes/base/ov_subgraph.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/gelu.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/swish.hpp"
+#include "openvino/op/variadic_split.hpp"
+#include "openvino/runtime/exec_model_info.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "transformations/rt_info/decompression.hpp"

 namespace ov {
 namespace test {
@@ -23,6 +25,7 @@ struct LLMMLPFusionParams {
     size_t up_size;
     std::string act_type;
     bool use_dynamic_quant;
+    bool use_swapped_outputs;  // true = create pattern with swapped VariadicSplit outputs (should still fuse)
 };

 class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>, public ov::test::SubgraphBaseTest {
@@ -39,6 +42,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
         result << "up_size=" << obj.param.up_size << "_";
         result << "act_type=" << obj.param.act_type << "_";
         result << "use_dynamic_quant=" << obj.param.use_dynamic_quant << "_";
+        result << "use_swapped_outputs=" << obj.param.use_swapped_outputs << "_";
         result << obj.index;
         return result.str();
     }
@@ -70,7 +74,8 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
             in_data.start_from = 0;
             in_data.range = 1;
             in_data.resolution = 128;
-            auto tensor_scale_per_oc = ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
+            auto tensor_scale_per_oc =
+                ov::test::utils::create_and_fill_tensor(ov::element::f32, ov::Shape{OC, 1}, in_data);
             auto scale_per_oc = std::make_shared<ov::op::v0::Constant>(tensor_scale_per_oc);

             auto weight_deq = std::make_shared<ov::op::v1::Multiply>(weight_const_f32, scale_per_oc);
@@ -85,38 +90,89 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
            return std::make_shared<ov::op::v0::Constant>(tensor);
        };
        if (param.use_dynamic_quant)
-           configuration.insert({ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});
-
-       auto gate_weight = create_const(param.up_size, param.down_size, 100);
-       auto up_weight = create_const(param.up_size, param.down_size, 100);
-       // down_proj has special cache blocking along K dimension requires lower weight resolution
-       auto down_weight = create_const(param.down_size, param.up_size, 16);
-
-       auto gate_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_weight, false, true);
-       auto up_proj = std::make_shared<ov::op::v0::MatMul>(src, up_weight, false, true);
+           configuration.insert(
+               {ov::hint::dynamic_quantization_group_size.name(), std::numeric_limits<uint64_t>::max()});

        std::shared_ptr<Node> gate_act;
-       if (param.act_type == "Swish")
-           gate_act = std::make_shared<ov::op::v4::Swish>(gate_proj);
-       if (param.act_type == "Gelu")
-           gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
+       ov::Output<ov::Node> up_output;

-       auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_proj);
+       if (param.use_swapped_outputs) {
+           // Create pattern with swapped VariadicSplit outputs to test gate_up_swapped support
+           ov::test::utils::InputGenerateData in_data;
+           in_data.start_from = -0.5;
+           in_data.range = 1.0;
+           in_data.resolution = 16;
+
+           // Combined gate_up weight in FP16 format
+           auto tensor_f16 = ov::test::utils::create_and_fill_tensor(ov::element::f16,
+                                                                      ov::Shape{param.up_size * 2, param.down_size},
+                                                                      in_data);
+           auto gate_up_weight_f16 = std::make_shared<ov::op::v0::Constant>(tensor_f16);
+           auto gate_up_weight_f32 = std::make_shared<ov::op::v0::Convert>(gate_up_weight_f16, ov::element::f32);
+           mark_as_decompression(gate_up_weight_f32);
+
+           auto gate_up_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_up_weight_f32, false, true);
+
+           auto split_lengths = std::make_shared<ov::op::v0::Constant>(
+               ov::element::i32,
+               ov::Shape{2},
+               std::vector<int32_t>{static_cast<int32_t>(param.up_size), static_cast<int32_t>(param.up_size)});
+           auto axis_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, -1);
+           auto gate_up_split = std::make_shared<ov::op::v1::VariadicSplit>(gate_up_proj, axis_const, split_lengths);
+
+           // Swap outputs to test gate_up_swapped support
+           auto gate_part = gate_up_split->output(1);  // activation on output[1]
+           if (param.act_type == "Swish")
+               gate_act = std::make_shared<ov::op::v4::Swish>(gate_part);
+           if (param.act_type == "Gelu")
+               gate_act = std::make_shared<ov::op::v7::Gelu>(gate_part);
+
+           auto up_part = gate_up_split->output(0);  // up branch from output[0] (swapped case)
+           up_output = up_part;
+       } else {
+           // Standard separate weights pattern
+           auto gate_weight = create_const(param.up_size, param.down_size, 100);
+           auto up_weight = create_const(param.up_size, param.down_size, 100);
+
+           auto gate_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_weight, false, true);
+           auto up_proj = std::make_shared<ov::op::v0::MatMul>(src, up_weight, false, true);
+
+           if (param.act_type == "Swish")
+               gate_act = std::make_shared<ov::op::v4::Swish>(gate_proj);
+           if (param.act_type == "Gelu")
+               gate_act = std::make_shared<ov::op::v7::Gelu>(gate_proj);
+
+           up_output = up_proj;
+       }
+
+       // Create compressed down projection weight
+       ov::test::utils::InputGenerateData down_data;
+       down_data.start_from = -0.5;
+       down_data.range = 1;
+       down_data.resolution = 16;
+       auto tensor_f16_down = ov::test::utils::create_and_fill_tensor(ov::element::f16,
+                                                                      ov::Shape{param.down_size, param.up_size},
+                                                                      down_data);
+       auto down_weight_f16 = std::make_shared<ov::op::v0::Constant>(tensor_f16_down);
+       auto down_weight = std::make_shared<ov::op::v0::Convert>(down_weight_f16, ov::element::f32);
+
+       auto gate_up = std::make_shared<ov::op::v1::Multiply>(gate_act, up_output);
        auto output = std::make_shared<ov::op::v0::MatMul>(gate_up, down_weight, false, true);

        function = std::make_shared<ov::Model>(ov::OutputVector{output}, ov::ParameterVector{src});
    }

    void check_results() {
        auto exec_model = compiledModel.get_runtime_model();
-
        int fused_node_found = 0;
        for (const auto& n : exec_model->get_ordered_ops()) {
            auto layer_type = n->get_rt_info().at(ov::exec_model_info::LAYER_TYPE).as<std::string>();
            if (layer_type == "LLMMLP")
                fused_node_found++;
        }
-       ASSERT_EQ(fused_node_found, 1);
+
+       // Both normal and swapped cases should fuse successfully
+       ASSERT_EQ(fused_node_found, 1) << "Fusion should occur with valid MLP patterns (both normal and swapped cases)";
    }
 };
@@ -129,13 +185,18 @@ TEST_P(LLMMLPFusionTest, CompareWithRefs) {

 namespace {

-static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4}, {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};
+static ov::test::InputShape ishape{ov::PartialShape{-1, -1, 4096 / 4},
+                                   {ov::Shape{1, 8, 4096 / 4}, ov::Shape{5, 37, 4096 / 4}}};

 const std::vector<LLMMLPFusionParams> mlp_params = {
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", false},
-    {ishape, 4096 / 4, 11008 / 4, "Gelu", true},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", false},
-    {ishape, 4096 / 4, 11008 / 4, "Swish", true},
+    // Standard separate weights cases (should all fuse successfully)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", true, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
+    {ishape, 4096 / 4, 11008 / 4, "Swish", true, false},
+
+    // Test case with swapped VariadicSplit outputs (should fuse with gate_up_swapped=true)
+    {ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
 };

 INSTANTIATE_TEST_SUITE_P(smoke_LLMMLPFusion,
