Skip to content

Commit 7097b71

Browse files
committed
change constant_path logic to avoid Matmul nodes zeros value output issue on CPU
1 parent 7a242a5 commit 7097b71

File tree

4 files changed

+49
-3
lines changed

4 files changed

+49
-3
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>&
280280

281281
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
282282

283-
TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);
283+
TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output,
284+
const std::unordered_set<std::type_index>& break_node_types = {});
284285

285286
TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
286287

src/common/transformations/src/transformations/utils/utils.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,23 @@ bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int6
489489
return false;
490490
}
491491

492-
bool is_on_constant_path(const ov::Output<ov::Node>& output) {
492+
bool is_on_constant_path(const ov::Output<ov::Node>& output,
493+
const std::unordered_set<std::type_index>& break_node_types) {
493494
auto status = true;
494495
std::deque<ov::Node*> nodes_to_calculate = {output.get_node()};
495496

496497
while (status && !nodes_to_calculate.empty()) {
497498
auto current_node = nodes_to_calculate.front();
498499
nodes_to_calculate.pop_front();
499500

501+
// Check if the current node matches any type in break_node_types
502+
if (!break_node_types.empty()) {
503+
std::type_index current_type(typeid(*current_node));
504+
if (break_node_types.find(current_type) != break_node_types.end()) {
505+
return false;
506+
}
507+
}
508+
500509
if (current_node->get_input_size() == 0 && !ov::is_type<ov::op::v0::Constant>(current_node)) {
501510
status = false;
502511
} else {

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/convert_matmul_to_fc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "openvino/core/type/element_type.hpp"
1111
#include "openvino/op/convert.hpp"
1212
#include "openvino/op/matmul.hpp"
13+
#include "openvino/op/random_uniform.hpp"
1314
#include "openvino/op/transpose.hpp"
1415
#include "openvino/pass/pattern/op/wrap_type.hpp"
1516
#include "ov_ops/fully_connected.hpp"
@@ -19,7 +20,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
1920
MATCHER_SCOPE(ConvertMatMulToFC);
2021
auto activations_m = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
2122
auto weights_path = [](const ov::Output<ov::Node>& output) {
22-
return ov::op::util::is_on_constant_path(output);
23+
return ov::op::util::is_on_constant_path(output, {typeid(ov::op::v8::RandomUniform)});
2324
};
2425
auto weights_m = ov::pass::pattern::any_input(weights_path);
2526
auto matmul_m = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({activations_m, weights_m},

src/plugins/intel_cpu/tests/unit/transformations/convert_matmul_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,3 +537,38 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_compressed_u8_weights) {
537537
model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
538538
}
539539
}
540+
541+
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_WithRandomUniform) {
542+
{
543+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, -1});
544+
545+
auto random_uniform_shape = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 2});
546+
auto random_uniform_min = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {0.0});
547+
auto random_uniform_max = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {1.0});
548+
auto random_uniform = std::make_shared<ov::op::v8::RandomUniform>(random_uniform_shape,
549+
random_uniform_min,
550+
random_uniform_max,
551+
ov::element::f32);
552+
553+
auto matmul = std::make_shared<ov::opset1::MatMul>(input1, random_uniform, false, false);
554+
555+
model = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{input1});
556+
557+
manager.register_pass<ConvertMatMulToFC>();
558+
}
559+
{
560+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, -1});
561+
562+
auto random_uniform_shape = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 2});
563+
auto random_uniform_min = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {0.0});
564+
auto random_uniform_max = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {1.0});
565+
auto random_uniform = std::make_shared<ov::op::v8::RandomUniform>(random_uniform_shape,
566+
random_uniform_min,
567+
random_uniform_max,
568+
ov::element::f32);
569+
570+
auto matmul = std::make_shared<ov::opset1::MatMul>(input1, random_uniform, false, false);
571+
572+
model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{input1});
573+
}
574+
}

0 commit comments

Comments
 (0)