Skip to content

Commit 5eb325c

Browse files
committed
Optimize the time cost of tail 'Convert' nodes in the f16 precision mark-up transformation
1 parent 6b6eb4b commit 5eb325c

File tree

2 files changed

+252
-15
lines changed

2 files changed

+252
-15
lines changed

src/plugins/intel_cpu/src/graph.cpp

Lines changed: 171 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,20 +2004,18 @@ void Graph::EnforceInferencePrecision() {
20042004
CPU_DEBUG_CAP_ENABLE(EnforceInferPrcDebug inferPrecDebug);
20052005

20062006
const auto inferPrec = getConfig().inferencePrecision;
2007-
if (one_of(inferPrec, element::f32, element::dynamic, ov::element::f16, element::dynamic)) {
2007+
if (one_of(inferPrec, element::f32, element::dynamic)) {
20082008
return; // nothing to do, only precision reduction is currently allowed
20092009
}
2010-
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
20112010
if (inferPrec == ov::element::f16) {
2011+
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
20122012
return; // precision is configured by ov::pass::ConvertPrecision
2013-
}
20142013
#endif
2015-
std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
2016-
searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
2017-
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2018-
const auto& parent = node->getParentEdgeAt(i)->getParent();
2019-
if (inferPrec == ov::element::bf16) {
2020-
/* list of node types that must be forced to be executed in BF16 precision
2014+
std::function<void(const NodePtr&, std::unordered_map<NodePtr, bool>&)> searchForTailNodes;
2015+
searchForTailNodes = [&](const NodePtr& node, std::unordered_map<NodePtr, bool>& tailNodes) -> void {
2016+
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2017+
const auto& parent = node->getParentEdgeAt(i)->getParent();
2018+
/* list of node types that must be forced to be executed in F16 precision
20212019
* because of performance gains */
20222020
if (one_of(parent->getType(),
20232021
Type::Convolution, // conv nets
@@ -2029,19 +2027,177 @@ void Graph::EnforceInferencePrecision() {
20292027
Type::Interpolate, // super resolution nets
20302028
Type::PagedAttention, // page attention
20312029
Type::QKVProjection,
2032-
Type::LLMMLP)) {
2030+
Type::LLMMLP,
2031+
Type::Pooling)) {
20332032
continue; // stop at significant nodes
20342033
}
2035-
} else if (inferPrec == ov::element::f16) {
2036-
/* list of node types that must be forced to be executed in FP16 precision
2034+
const auto res = tailNodes.insert({parent, false});
2035+
if (res.second) { // node not visited yet
2036+
searchForTailNodes(parent, tailNodes);
2037+
}
2038+
}
2039+
};
2040+
// collect the tail nodes
2041+
std::unordered_map<NodePtr, bool> tailNodesMap;
2042+
std::unordered_set<ov::element::Type_t> outputPrecisions;
2043+
// starting from output nodes
2044+
for (const auto& entry : outputNodesMap) {
2045+
const auto& output = entry.second;
2046+
if (output->getOriginalInputPrecisionAtPort(0) == inferPrec) {
2047+
continue;
2048+
}
2049+
outputPrecisions.insert(output->getOriginalInputPrecisionAtPort(0));
2050+
searchForTailNodes(output, tailNodesMap);
2051+
}
2052+
if (outputPrecisions.empty()) {
2053+
return;
2054+
}
2055+
2056+
const std::vector<Type> kStartTypes = {Type::Eltwise, Type::MVN};
2057+
const std::vector<Type> kPathTypes = {Type::Reshape, Type::Concatenation, Type::Split};
2058+
std::function<bool(const NodePtr&)> suitableForTailOptimization;
2059+
suitableForTailOptimization = [&](const NodePtr& node) -> bool {
2060+
NodePtr cur = node;
2061+
std::unordered_set<NodePtr> visited;
2062+
while (cur) {
2063+
if (visited.count(cur))
2064+
break;
2065+
visited.insert(cur);
2066+
2067+
size_t parentNum = cur->getParentEdges().size();
2068+
if (parentNum == 0)
2069+
return false;
2070+
2071+
bool allParentSuitable = true;
2072+
for (size_t i = 0; i < parentNum; ++i) {
2073+
auto parent = cur->getParentEdgeAt(i)->getParent();
2074+
if (!parent)
2075+
return false;
2076+
if ((std::find(kStartTypes.begin(), kStartTypes.end(), parent->getType()) != kStartTypes.end()) &&
2077+
tailNodesMap.count(parent)) {
2078+
continue;
2079+
}
2080+
if ((std::find(kPathTypes.begin(), kPathTypes.end(), parent->getType()) != kPathTypes.end()) &&
2081+
tailNodesMap.count(parent)) {
2082+
if (!suitableForTailOptimization(parent)) {
2083+
allParentSuitable = false;
2084+
break;
2085+
}
2086+
} else {
2087+
return false;
2088+
}
2089+
}
2090+
return allParentSuitable;
2091+
}
2092+
return false;
2093+
};
2094+
std::function<void(const NodePtr&, const ov::element::Type_t&)> resetTailPrecision;
2095+
resetTailPrecision = [&](const NodePtr& node, const ov::element::Type_t& outputPrecision) -> void {
2096+
// traverse upwards until encountering the first kStartTypes
2097+
for (size_t i = 0; i < node->getParentEdges().size(); ++i) {
2098+
auto parent = node->getParentEdgeAt(i)->getParent();
2099+
if (!parent)
2100+
continue;
2101+
OPENVINO_ASSERT(tailNodesMap.count(parent),
2102+
"resetTailPrecision: node ",
2103+
parent->getName(),
2104+
" with type ",
2105+
NameFromType(parent->getType()),
2106+
" is not in suitableForTailOptimization set");
2107+
if (tailNodesMap[parent]) {
2108+
continue;
2109+
}
2110+
tailNodesMap[parent] = true;
2111+
if (std::find(kStartTypes.begin(), kStartTypes.end(), parent->getType()) != kStartTypes.end()) {
2112+
// restore the output precision of kStartTypes nodes to the original output precision; input precision remains unchanged
2113+
for (size_t j = 0; j < parent->getOriginalOutputsNumber(); ++j) {
2114+
parent->setOriginalOutputPrecisionAtPort(j, outputPrecision);
2115+
}
2116+
} else {
2117+
// recursively process upwards
2118+
// restore all input and output precisions of the current node to the original output precision
2119+
for (size_t j = 0; j < parent->getOriginalInputsNumber(); ++j) {
2120+
parent->setOriginalInputPrecisionAtPort(j, outputPrecision);
2121+
}
2122+
for (size_t j = 0; j < parent->getOriginalOutputsNumber(); ++j) {
2123+
parent->setOriginalOutputPrecisionAtPort(j, outputPrecision);
2124+
}
2125+
resetTailPrecision(parent, outputPrecision);
2126+
}
2127+
}
2128+
};
2129+
2130+
std::function<void(const NodePtr&)> tailNodesPrecisionOptimizeMain;
2131+
tailNodesPrecisionOptimizeMain = [&](const NodePtr& node) -> void {
2132+
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2133+
const auto& parent = node->getParentEdgeAt(i)->getParent();
2134+
if (!tailNodesMap.count(parent)) {
2135+
continue;
2136+
}
2137+
if (one_of(parent->getType(), Type::Input, Type::Output, Type::MemoryInput, Type::MemoryOutput)) {
2138+
continue;
2139+
}
2140+
if (parent->keepOrigPrecision()) {
2141+
continue;
2142+
}
2143+
if ((parent->getType() == Type::Convert) && (parent->getOriginalInputPrecisionAtPort(0) == inferPrec) &&
2144+
outputPrecisions.count(parent->getOriginalOutputPrecisionAtPort(0))) {
2145+
bool suitableCase = false;
2146+
auto outprecision = parent->getOriginalOutputPrecisionAtPort(0);
2147+
for (size_t i = 0; i < parent->getParentEdges().size(); ++i) {
2148+
auto p = parent->getParentEdgeAt(i)->getParent();
2149+
if (!p)
2150+
continue;
2151+
if (std::find(kPathTypes.begin(), kPathTypes.end(), p->getType()) != kPathTypes.end()) {
2152+
if (suitableForTailOptimization(p)) {
2153+
suitableCase = true;
2154+
continue;
2155+
}
2156+
} else if (std::find(kStartTypes.begin(), kStartTypes.end(), p->getType()) !=
2157+
kStartTypes.end()) {
2158+
suitableCase = true;
2159+
continue;
2160+
}
2161+
}
2162+
if (suitableCase) {
2163+
// suitable case for tail optimization
2164+
resetTailPrecision(parent, outprecision);
2165+
DropNode(parent);
2166+
}
2167+
continue;
2168+
}
2169+
tailNodesPrecisionOptimizeMain(parent);
2170+
}
2171+
};
2172+
// tail optimization main process
2173+
for (const auto& entry : outputNodesMap) {
2174+
const auto& output = entry.second;
2175+
if (output->getOriginalInputPrecisionAtPort(0) == inferPrec) {
2176+
continue;
2177+
}
2178+
tailNodesPrecisionOptimizeMain(output);
2179+
}
2180+
return;
2181+
}
2182+
2183+
std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
2184+
searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
2185+
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2186+
const auto& parent = node->getParentEdgeAt(i)->getParent();
2187+
if (inferPrec == ov::element::bf16) {
2188+
/* list of node types that must be forced to be executed in BF16 precision
20372189
* because of performance gains */
20382190
if (one_of(parent->getType(),
20392191
Type::Convolution, // conv nets
2040-
Type::Deconvolution, // deconv
20412192
Type::FullyConnected, // conv / bert nets
2193+
Type::RNNCell, // recurrent nets
2194+
Type::RNNSeq, // recurrent nets
20422195
Type::MatMul, // bert nets
2043-
Type::Pooling,
2044-
Type::MVN)) {
2196+
Type::ROIPooling, // object detection nets
2197+
Type::Interpolate, // super resolution nets
2198+
Type::PagedAttention, // page attention
2199+
Type::QKVProjection,
2200+
Type::LLMMLP)) {
20452201
continue; // stop at significant nodes
20462202
}
20472203
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/op/concat.hpp"
6+
#include "openvino/op/constant.hpp"
7+
#include "openvino/op/convert.hpp"
8+
#include "openvino/op/convolution.hpp"
9+
#include "openvino/op/multiply.hpp"
10+
#include "openvino/op/parameter.hpp"
11+
#include "openvino/op/result.hpp"
12+
#include "openvino/op/sigmoid.hpp"
13+
#include "shared_test_classes/base/ov_subgraph.hpp"
14+
#include "utils/cpu_test_utils.hpp"
15+
16+
using namespace CPUTestUtils;
17+
18+
namespace ov {
19+
namespace test {
20+
21+
class EnforceInferencePrecisionFP16TailTest : virtual public SubgraphBaseTest {
22+
public:
23+
static std::string getTestCaseName(testing::TestParamInfo<std::tuple<>> /*obj*/) {
24+
return "EnforceInferencePrecisionFP16TailTest";
25+
}
26+
27+
void SetUp() override {
28+
targetDevice = ov::test::utils::DEVICE_CPU;
29+
configuration = {{ov::hint::inference_precision.name(), ov::element::f16}};
30+
31+
std::vector<InputShape> inputShapes = {{{-1, 16, 16, 16}, {{1, 16, 16, 16}, {2, 16, 16, 16}}}};
32+
33+
init_input_shapes(inputShapes);
34+
35+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, inputDynamicShapes[0]);
36+
ov::Shape weights_shape = {16, 16, 1, 1}; // OIHW for 1x1 conv
37+
38+
auto weights = ov::op::v0::Constant::create(ov::element::f16, weights_shape, {1.0f});
39+
auto conv = std::make_shared<ov::op::v1::Convolution>(input,
40+
weights,
41+
ov::Strides{1, 1},
42+
ov::CoordinateDiff{0, 0},
43+
ov::CoordinateDiff{0, 0},
44+
ov::Strides{1, 1});
45+
conv->set_friendly_name("conv_node");
46+
auto mul_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 16, 16, 16}, {2.0f});
47+
auto mul = std::make_shared<ov::op::v1::Multiply>(conv, mul_const);
48+
49+
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(conv);
50+
51+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{mul, sigmoid}, 1);
52+
53+
auto convert_to_f32 = std::make_shared<ov::op::v0::Convert>(concat, ov::element::f32);
54+
55+
auto result = std::make_shared<ov::op::v0::Result>(convert_to_f32);
56+
57+
function = std::make_shared<ov::Model>(ov::ResultVector{result},
58+
ov::ParameterVector{input},
59+
"enforce_inference_precision_fp16_tail");
60+
}
61+
62+
void checkResults() {
63+
for (const auto& node : compiledModel.get_runtime_model()->get_ops()) {
64+
if (node->get_friendly_name() == "conv_node") {
65+
ASSERT_EQ(node->get_output_element_type(0), ElementType::f16);
66+
}
67+
}
68+
CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
69+
}
70+
};
71+
namespace {
72+
TEST_F(EnforceInferencePrecisionFP16TailTest, CompareWithRefs) {
73+
if (!ov::with_cpu_x86_avx512_core_amx_fp16())
74+
GTEST_SKIP() << "Skipping test, only fp16 runtime inference precision platform needed" << std::endl;
75+
run();
76+
serialize();
77+
checkResults();
78+
}
79+
} // namespace
80+
} // namespace test
81+
} // namespace ov

0 commit comments

Comments
 (0)