Skip to content

Commit b3f8561

Browse files
committed
Apply suggestions from code review
1 parent 2911560 commit b3f8561

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/plugins/intel_cpu/src/nodes/llm_mlp.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,10 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
382382
if (m_config.gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
383383
N = w_gate.size(0) / 2;
384384
if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
385-
// When VariadicSplit output[1] connects to gate instead of up, swap the pointers
385+
// COMBINED_UP_GATE: VariadicSplit output[0] connects to up, output[1] connects to gate
386386
gate_up.setup(w_gate.ptr_v(N, 0), w_gate.ptr_v(), w_gate.stride_bytes(0), N * 2, K, config);
387387
} else {
388-
// Normal case: VariadicSplit output[1] connects to up
388+
// COMBINED_GATE_UP: VariadicSplit output[0] connects to gate, output[1] connects to up
389389
gate_up.setup(w_gate.ptr_v(), w_gate.ptr_v(N, 0), w_gate.stride_bytes(0), N * 2, K, config);
390390
}
391391
} else {
@@ -407,8 +407,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
407407
auto* scale_first = w_scale_gate;
408408
auto* scale_second = w_scale_up;
409409
if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
410-
scale_first = w_scale_up;
411-
scale_second = w_scale_gate;
410+
std::swap(scale_first, scale_second);
412411
}
413412
for (size_t i = 0; i < N; i += 16) {
414413
memcpy(dst, scale_first + i, 16 * sizeof(float));

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mlp_fusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class LLMMLPFusionTest : public testing::WithParamInterface<LLMMLPFusionParams>,
109109
in_data);
110110
auto gate_up_weight_f16 = std::make_shared<ov::op::v0::Constant>(tensor_f16);
111111
auto gate_up_weight_f32 = std::make_shared<ov::op::v0::Convert>(gate_up_weight_f16, ov::element::f32);
112+
// Mark as decompression to prevent constant folding optimization and avoid pattern mismatch
112113
mark_as_decompression(gate_up_weight_f32);
113114

114115
auto gate_up_proj = std::make_shared<ov::op::v0::MatMul>(src, gate_up_weight_f32, false, true);

0 commit comments

Comments
 (0)