Skip to content

Commit 67ed238

Browse files
committed
change constant_path logic to avoid Matmul nodes zeros value output issue on CPU
1 parent 2bdb25a commit 67ed238

File tree

4 files changed

+51
-3
lines changed

4 files changed

+51
-3
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>&
280280

281281
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
282282

283-
TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);
283+
TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output,
284+
const std::unordered_set<std::type_index>& break_node_types = {});
284285

285286
TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
286287

src/common/transformations/src/transformations/utils/utils.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,23 @@ bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int6
489489
return false;
490490
}
491491

492-
bool is_on_constant_path(const ov::Output<ov::Node>& output) {
492+
bool is_on_constant_path(const ov::Output<ov::Node>& output,
493+
const std::unordered_set<std::type_index>& break_node_types) {
493494
auto status = true;
494495
std::deque<ov::Node*> nodes_to_calculate = {output.get_node()};
495496

496497
while (status && !nodes_to_calculate.empty()) {
497498
auto current_node = nodes_to_calculate.front();
498499
nodes_to_calculate.pop_front();
499500

501+
// Check if the current node matches any type in break_node_types
502+
if (!break_node_types.empty()) {
503+
std::type_index current_type(typeid(*current_node));
504+
if (break_node_types.find(current_type) != break_node_types.end()) {
505+
return false;
506+
}
507+
}
508+
500509
if (current_node->get_input_size() == 0 && !ov::is_type<ov::op::v0::Constant>(current_node)) {
501510
status = false;
502511
} else {

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/convert_matmul_to_fc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "openvino/core/type/element_type.hpp"
1111
#include "openvino/op/convert.hpp"
1212
#include "openvino/op/matmul.hpp"
13+
#include "openvino/op/random_uniform.hpp"
1314
#include "openvino/op/transpose.hpp"
1415
#include "openvino/pass/pattern/op/wrap_type.hpp"
1516
#include "ov_ops/fully_connected.hpp"
@@ -19,7 +20,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
1920
MATCHER_SCOPE(ConvertMatMulToFC);
2021
auto activations_m = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
2122
auto weights_path = [](const ov::Output<ov::Node>& output) {
22-
return ov::op::util::is_on_constant_path(output);
23+
return ov::op::util::is_on_constant_path(output, {typeid(ov::op::v8::RandomUniform)});
2324
};
2425
auto weights_m = ov::pass::pattern::any_input(weights_path);
2526
auto matmul_m = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({activations_m, weights_m},

src/plugins/intel_cpu/tests/unit/transformations/convert_matmul_test.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "openvino/opsets/opset1_decl.hpp"
1010
#include "openvino/opsets/opset3_decl.hpp"
1111
#include "openvino/opsets/opset7_decl.hpp"
12+
#include "openvino/opsets/opset8_decl.hpp"
1213
#include <openvino/pass/manager.hpp>
1314
#include <ov_ops/type_relaxed.hpp>
1415
#include <transformations/cpu_opset/common/pass/convert_matmul_to_fc.hpp>
@@ -25,6 +26,7 @@
2526
#include "openvino/op/shape_of.hpp"
2627
#include "openvino/op/subtract.hpp"
2728
#include "openvino/op/transpose.hpp"
29+
#include "openvino/op/random_uniform.hpp"
2830

2931
using namespace testing;
3032
using namespace ov::intel_cpu;
@@ -543,3 +545,38 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_compressed_u8_weights) {
543545
model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
544546
}
545547
}
548+
549+
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_WithRandomUniform) {
550+
{
551+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, -1});
552+
553+
auto random_uniform_shape = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 2});
554+
auto random_uniform_min = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {0.0});
555+
auto random_uniform_max = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {1.0});
556+
auto random_uniform = std::make_shared<ov::op::v8::RandomUniform>(random_uniform_shape,
557+
random_uniform_min,
558+
random_uniform_max,
559+
ov::element::f32);
560+
561+
auto matmul = std::make_shared<ov::opset1::MatMul>(input1, random_uniform, false, false);
562+
563+
model = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{input1});
564+
565+
manager.register_pass<ConvertMatMulToFC>();
566+
}
567+
{
568+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, -1});
569+
570+
auto random_uniform_shape = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 2});
571+
auto random_uniform_min = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {0.0});
572+
auto random_uniform_max = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1}, {1.0});
573+
auto random_uniform = std::make_shared<ov::op::v8::RandomUniform>(random_uniform_shape,
574+
random_uniform_min,
575+
random_uniform_max,
576+
ov::element::f32);
577+
578+
auto matmul = std::make_shared<ov::opset1::MatMul>(input1, random_uniform, false, false);
579+
580+
model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{input1});
581+
}
582+
}

0 commit comments

Comments
 (0)