
Commit b05b367

use new 'GATE_UP_TYPE' in mlp kernel
1 parent ed4e9ca

File tree: 5 files changed, +56 -25 lines

src/plugins/intel_cpu/src/nodes/llm_mlp.cpp

Lines changed: 5 additions & 6 deletions

@@ -379,9 +379,9 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
         auto K = w_gate.size(1);
         auto N = w_gate.size(0);
         OPENVINO_ASSERT(w_gate.stride_bytes(0) == w_up.stride_bytes(0));
-        if (m_config.gate_up_combined) {
+        if (m_config.gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
             N = w_gate.size(0) / 2;
-            if (m_config.gate_up_swapped) {
+            if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
                 // When VariadicSplit output[1] connects to gate instead of up, swap the pointers
                 gate_up.setup(w_gate.ptr_v(N, 0), w_gate.ptr_v(), w_gate.stride_bytes(0), N * 2, K, config);
             } else {
@@ -398,19 +398,18 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
         auto* w_scale_gate = pnode->getSrcMemoryAtPort(4)->getDataAs<float>();
         auto* w_scale_up = pnode->getSrcMemoryAtPort(5)->getDataAs<float>();
         auto* dst = m_w_scale_gateup.ptr<float>();
-        if (m_config.gate_up_combined) {
+        if (m_config.gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
             w_scale_up = w_scale_gate + N;
         }

-        // When gate_up_combined=true and gate_up_swapped=true, we need to swap the scales
+        // When gate_up_type is COMBINED_UP_GATE, we need to swap the scales
         // to match the swapped weight layout
         auto* scale_first = w_scale_gate;
         auto* scale_second = w_scale_up;
-        if (m_config.gate_up_combined && m_config.gate_up_swapped) {
+        if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
             scale_first = w_scale_up;
             scale_second = w_scale_gate;
         }
-
         for (size_t i = 0; i < N; i += 16) {
             memcpy(dst, scale_first + i, 16 * sizeof(float));
             dst += 16;
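The kernel now branches on gate_up_type instead of two booleans: any non-SEPARATE value means gate and up weights live in one [2*N, K] buffer, and COMBINED_UP_GATE additionally swaps which half is which. A minimal standalone sketch of that pointer arithmetic (the gate_up_ptrs helper and the float element type are illustrative; the plugin's real setup() also takes strides and a kernel config):

#include <cstddef>
#include <cstdint>
#include <utility>

enum class GATE_UP_TYPE : uint8_t { SEPARATE = 0, COMBINED_GATE_UP = 1, COMBINED_UP_GATE = 2 };

// Given a combined weight of shape [2*N, K] stored row-major, return {gate_ptr, up_ptr}.
// COMBINED_GATE_UP keeps gate in rows [0, N); COMBINED_UP_GATE stores up first,
// so the two pointers are simply swapped.
std::pair<const float*, const float*> gate_up_ptrs(const float* w,
                                                   std::size_t N,
                                                   std::size_t K,
                                                   GATE_UP_TYPE type) {
    const float* first = w;           // rows [0, N)
    const float* second = w + N * K;  // rows [N, 2*N)
    if (type == GATE_UP_TYPE::COMBINED_UP_GATE)
        return {second, first};  // up is stored first, so gate lives in the second half
    return {first, second};
}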

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp

Lines changed: 15 additions & 2 deletions

@@ -23,10 +23,24 @@ EnumNames<ov::intel_cpu::LLMMLPNode::ACT_FN>& EnumNames<ov::intel_cpu::LLMMLPNode::ACT_FN>::get() {
     return enum_names;
 }

+template <>
+EnumNames<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>& EnumNames<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>::get() {
+    static auto enum_names = EnumNames<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>(
+        "op::intel_cpu::LLMMLPNode::GATE_UP_TYPE",
+        {{"SEPARATE", ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE::SEPARATE},
+         {"COMBINED_GATE_UP", ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP},
+         {"COMBINED_UP_GATE", ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE}});
+    return enum_names;
+}
+
 std::ostream& operator<<(std::ostream& os, const ov::intel_cpu::LLMMLPNode::ACT_FN& type) {
     return os << as_string(type);
 }

+std::ostream& operator<<(std::ostream& os, const ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE& type) {
+    return os << as_string(type);
+}
+
 namespace intel_cpu {

 bool LLMMLPNode::visit_attributes(ov::AttributeVisitor& visitor) {
@@ -37,8 +51,7 @@ bool LLMMLPNode::visit_attributes(ov::AttributeVisitor& visitor) {
     visitor.on_attribute("down_quantized", m_config.down_quantized);
     visitor.on_attribute("hidden_size", m_config.hidden_size);
     visitor.on_attribute("up_size", m_config.up_size);
-    visitor.on_attribute("gate_up_combined", m_config.gate_up_combined);
-    visitor.on_attribute("gate_up_swapped", m_config.gate_up_swapped);
+    visitor.on_attribute("gate_up_type", m_config.gate_up_type);
     visitor.finish_structure();
     return true;
 }
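Registering the enum with EnumNames is what makes visitor.on_attribute("gate_up_type", m_config.gate_up_type) work: the AttributeAdapter declared in llm_mlp.hpp serializes the value through its string name. A rough usage sketch, assuming OpenVINO's as_string/as_enum helpers from openvino/core/enum_names.hpp:

#include <iostream>
// Assumed includes: "openvino/core/enum_names.hpp" (as_string/as_enum) and the
// plugin's llm_mlp.hpp for the node type.

void gate_up_type_roundtrip() {
    using GateUpType = ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE;
    GateUpType t = GateUpType::COMBINED_UP_GATE;
    std::cout << t << "\n";  // prints "COMBINED_UP_GATE" via the operator<< added above
    GateUpType parsed = ov::as_enum<GateUpType>("SEPARATE");  // string -> enum
    (void)parsed;
}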

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp

Lines changed: 18 additions & 2 deletions

@@ -25,14 +25,19 @@ class LLMMLPNode : public ov::op::Op {

     enum class ACT_FN : uint8_t { SILU = 0, GELU = 1 };

+    enum class GATE_UP_TYPE : uint8_t {
+        SEPARATE = 0,          // separate gate and up projections
+        COMBINED_GATE_UP = 1,  // combined weights, gate first (normal)
+        COMBINED_UP_GATE = 2   // combined weights, up first (swapped)
+    };
+
     struct Config {
         ACT_FN act;
         bool gate_up_quantized;
         bool down_quantized;
         int hidden_size;
         int up_size;
-        bool gate_up_combined;
-        bool gate_up_swapped;  // true when VariadicSplit output[1] connects to gate instead of up
+        GATE_UP_TYPE gate_up_type;
     };

     // args:
@@ -70,6 +75,17 @@ class AttributeAdapter<ov::intel_cpu::LLMMLPNode::ACT_FN>
     OPENVINO_RTTI("AttributeAdapter<ov::intel_cpu::LLMMLPNode::ACT_FN>");
 };

+template <>
+class AttributeAdapter<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>
+    : public EnumAttributeAdapterBase<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE> {
+public:
+    explicit AttributeAdapter(ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE& value)
+        : EnumAttributeAdapterBase<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>(value) {}
+
+    OPENVINO_RTTI("AttributeAdapter<ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE>");
+};
+
 std::ostream& operator<<(std::ostream& s, const ov::intel_cpu::LLMMLPNode::ACT_FN& type);
+std::ostream& operator<<(std::ostream& s, const ov::intel_cpu::LLMMLPNode::GATE_UP_TYPE& type);

 } // namespace ov
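The three-state enum replaces the old (gate_up_combined, gate_up_swapped) pair and removes the one meaningless combination (swapped but not combined). A standalone sketch of the mapping; from_legacy_flags is a hypothetical helper, not part of the commit:

#include <cstdint>

enum class GATE_UP_TYPE : uint8_t { SEPARATE = 0, COMBINED_GATE_UP = 1, COMBINED_UP_GATE = 2 };

// Hypothetical translation from the retired boolean pair to the new enum.
GATE_UP_TYPE from_legacy_flags(bool gate_up_combined, bool gate_up_swapped) {
    if (!gate_up_combined)
        return GATE_UP_TYPE::SEPARATE;  // 'swapped' was only meaningful when combined
    return gate_up_swapped ? GATE_UP_TYPE::COMBINED_UP_GATE
                           : GATE_UP_TYPE::COMBINED_GATE_UP;
}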

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fusion.cpp

Lines changed: 15 additions & 12 deletions

@@ -123,8 +123,8 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         const auto& pattern_map = m.get_pattern_value_map();
         auto root = m.get_match_root();

-        // Check VariadicSplit output connections in combined mode
-        bool gate_up_swapped = false;
+        // Determine gate_up_type based on pattern matching
+        LLMMLPNode::GATE_UP_TYPE gate_up_type = LLMMLPNode::GATE_UP_TYPE::SEPARATE;
         if (pattern_map.count(gate_up_proj_split)) {
             auto mlp_gated_up_node = pattern_map.at(mlp_gated_up).get_node_shared_ptr();
             auto input0 = mlp_gated_up_node->input_value(0);
@@ -134,9 +134,10 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
             // Since pattern matching succeeded, we know one of the outputs connects to Multiply
             if ((input0.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input0.get_index() == 0) ||
                 (input1.get_node() == pattern_map.at(gate_up_proj_split).get_node() && input1.get_index() == 0)) {
-                gate_up_swapped = true;
+                gate_up_type = LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE;  // swapped case
+            } else {
+                gate_up_type = LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP;  // normal combined case
             }
-            // Otherwise, it's the normal case where output[1] connects to Multiply
         }

         auto src = pattern_map.at(input);
@@ -151,17 +152,20 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         // down projection is harder to quantize w/o causing accuracy problem, so it may be un-quantized instead
         bool is_gate_up_quantized_int8 = false;
         bool is_down_proj_int8 = false;
-        bool is_gate_up_combined = false;
         if (pattern_map.count(gate_up_proj_weight_const_i8) > 0 && pattern_map.count(down_proj_weight_compressed) > 0) {
             // gate-up combined & quantized
             is_gate_up_quantized_int8 = true;
-            is_gate_up_combined = true;
+            gate_up_type = (gate_up_type == LLMMLPNode::GATE_UP_TYPE::SEPARATE)
+                               ? LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP
+                               : gate_up_type;
             gate_proj_w = pattern_map.at(gate_up_proj_weight_const_i8);
             up_proj_w = pattern_map.at(gate_up_proj_weight_const_i8);
             down_proj_w = pattern_map.at(down_proj_weight_compressed);
         } else if (pattern_map.count(gate_up_proj_weight) > 0 && pattern_map.count(down_proj_weight_compressed) > 0) {
             // gate-up combined
-            is_gate_up_combined = true;
+            gate_up_type = (gate_up_type == LLMMLPNode::GATE_UP_TYPE::SEPARATE)
+                               ? LLMMLPNode::GATE_UP_TYPE::COMBINED_GATE_UP
+                               : gate_up_type;
             gate_proj_w = pattern_map.at(gate_up_proj_weight);
             up_proj_w = pattern_map.at(gate_up_proj_weight);
             down_proj_w = pattern_map.at(down_proj_weight_compressed);
@@ -224,7 +228,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
             return false;
         }

-        auto up_size = is_gate_up_combined ? (up_shape[0] / 2) : (up_shape[0]);
+        auto up_size = (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) ? (up_shape[0] / 2) : (up_shape[0]);
         auto down_size = up_shape[1];
         if (down_shape[0] != down_size) {
             return false;
@@ -240,8 +244,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         cfg.down_quantized = is_down_proj_int8;
         cfg.hidden_size = down_size;
         cfg.up_size = up_size;
-        cfg.gate_up_combined = is_gate_up_combined;
-        cfg.gate_up_swapped = gate_up_swapped;
+        cfg.gate_up_type = gate_up_type;

         if (pattern_map.count(mlp_silu_gate) > 0) {
             cfg.act = LLMMLPNode::ACT_FN::SILU;
@@ -266,7 +269,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         new_args.push_back(up_proj_w);
         new_args.push_back(down_proj_w);
         if (is_gate_up_quantized_int8) {
-            if (is_gate_up_combined) {
+            if (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
                 new_args.push_back(pattern_map.at(gate_up_proj_weight_scales_per_OC));
                 new_args.push_back(pattern_map.at(gate_up_proj_weight_scales_per_OC));
             } else {
@@ -284,7 +287,7 @@ ov::intel_cpu::MLPFusionPass::MLPFusionPass() {
         ov::copy_runtime_info(
             {pattern_map.at(gate_act).get_node_shared_ptr(), pattern_map.at(down_proj).get_node_shared_ptr()},
             new_node);
-        if (is_gate_up_combined) {
+        if (gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
            ov::copy_runtime_info({pattern_map.at(gate_up_proj).get_node_shared_ptr()}, new_node);
        } else {
            ov::copy_runtime_info({pattern_map.at(mlp_gate_proj).get_node_shared_ptr(),
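Net effect in the pass: split connectivity is inspected first, and the combined-weight branches then promote a still-SEPARATE value to COMBINED_GATE_UP without clobbering an already detected COMBINED_UP_GATE. A condensed sketch with the pattern checks collapsed into booleans (names are illustrative, not the pass's actual variables):

#include <cstdint>

enum class GATE_UP_TYPE : uint8_t { SEPARATE = 0, COMBINED_GATE_UP = 1, COMBINED_UP_GATE = 2 };

GATE_UP_TYPE decide_gate_up_type(bool split_matched,        // VariadicSplit present in the match
                                 bool up_is_split_output0,  // Multiply fed directly by split output[0]
                                 bool combined_weight) {    // combined gate/up weight constant matched
    GATE_UP_TYPE t = GATE_UP_TYPE::SEPARATE;
    if (split_matched)  // step 1: connectivity of the VariadicSplit outputs
        t = up_is_split_output0 ? GATE_UP_TYPE::COMBINED_UP_GATE   // swapped layout
                                : GATE_UP_TYPE::COMBINED_GATE_UP;  // normal layout
    if (combined_weight && t == GATE_UP_TYPE::SEPARATE)
        t = GATE_UP_TYPE::COMBINED_GATE_UP;  // step 2: promote, but keep a detected swap
    return t;
}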

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp

Lines changed: 3 additions & 3 deletions

@@ -97,7 +97,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
     ov::Output<ov::Node> up_output;

     if (param.use_swapped_outputs) {
-        // Create pattern with swapped VariadicSplit outputs to test gate_up_swapped support
+        // Create pattern with swapped VariadicSplit outputs to test COMBINED_UP_GATE type
        ov::test::utils::InputGenerateData in_data;
        in_data.start_from = -0.5;
        in_data.range = 1.0;
@@ -120,7 +120,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
     auto axis_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, -1);
     auto gate_up_split = std::make_shared<ov::op::v1::VariadicSplit>(gate_up_proj, axis_const, split_lengths);

-    // Swap outputs to test gate_up_swapped support
+    // Swap outputs to test COMBINED_UP_GATE type
     auto gate_part = gate_up_split->output(1);  // activation on output[1]
     if (param.act_type == "Swish")
         gate_act = std::make_shared<ov::op::v4::Swish>(gate_part);
@@ -195,7 +195,7 @@ const std::vector<LLMMLPFusionParams> mlp_params = {
     {ishape, 4096 / 4, 11008 / 4, "Swish", false, false},
     {ishape, 4096 / 4, 11008 / 4, "Swish", true, false},

-    // Test case with swapped VariadicSplit outputs (should fuse with gate_up_swapped=true)
+    // Test case with swapped VariadicSplit outputs (should fuse with COMBINED_UP_GATE type)
     {ishape, 4096 / 4, 11008 / 4, "Gelu", false, true},
 };
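For reference, a simplified sketch of the swapped subgraph the new test case builds (weight creation and precision handling omitted; gate_up_proj and up_size stand in for the test's locals). The fusion pass should map this shape to COMBINED_UP_GATE:

#include <cstdint>
#include <memory>
#include <vector>
#include "openvino/core/node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/variadic_split.hpp"

std::shared_ptr<ov::Node> build_swapped_gate_up(const ov::Output<ov::Node>& gate_up_proj, int32_t up_size) {
    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, -1);
    auto lengths = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2},
                                                          std::vector<int32_t>{up_size, up_size});
    auto split = std::make_shared<ov::op::v1::VariadicSplit>(gate_up_proj, axis, lengths);
    auto gate_act = std::make_shared<ov::op::v4::Swish>(split->output(1));  // gate taken from output[1]
    return std::make_shared<ov::op::v1::Multiply>(gate_act, split->output(0));  // up from output[0]
}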
