Skip to content

Commit 42da14c

Browse files
committed
extend model tail nodes optimization method to distinguish support inplace and non-inplace cases
1 parent cc5b819 commit 42da14c

File tree

7 files changed

+554
-268
lines changed

7 files changed

+554
-268
lines changed

src/plugins/intel_cpu/src/graph.cpp

Lines changed: 1 addition & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,185 +2004,9 @@ void Graph::EnforceInferencePrecision() {
20042004
CPU_DEBUG_CAP_ENABLE(EnforceInferPrcDebug inferPrecDebug);
20052005

20062006
const auto inferPrec = getConfig().inferencePrecision;
2007-
if (one_of(inferPrec, element::f32, element::dynamic)) {
2007+
if (one_of(inferPrec, element::f32, element::f16, element::dynamic)) {
20082008
return; // nothing to do, only precision reduction is currently allowed
20092009
}
2010-
if (inferPrec == ov::element::f16) {
2011-
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
2012-
return; // precision is configured by ov::pass::ConvertPrecision
2013-
#endif
2014-
std::function<void(const NodePtr&, std::unordered_map<NodePtr, bool>&)> searchForTailNodes;
2015-
searchForTailNodes = [&](const NodePtr& node, std::unordered_map<NodePtr, bool>& tailNodes) -> void {
2016-
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2017-
const auto& parent = node->getParentEdgeAt(i)->getParent();
2018-
/* list of node types that must be forced to be executed in F16 precision
2019-
* because of performance gains */
2020-
if (one_of(parent->getType(),
2021-
Type::Convolution, // conv nets
2022-
Type::FullyConnected, // conv / bert nets
2023-
Type::RNNCell, // recurrent nets
2024-
Type::RNNSeq, // recurrent nets
2025-
Type::MatMul, // bert nets
2026-
Type::ROIPooling, // object detection nets
2027-
Type::Interpolate, // super resolution nets
2028-
Type::PagedAttention, // page attention
2029-
Type::QKVProjection,
2030-
Type::LLMMLP,
2031-
Type::Pooling)) {
2032-
continue; // stop at significant nodes
2033-
}
2034-
const auto res = tailNodes.insert({parent, false});
2035-
if (res.second) { // node not visited yet
2036-
searchForTailNodes(parent, tailNodes);
2037-
}
2038-
}
2039-
};
2040-
// collect the tail nodes
2041-
std::unordered_map<NodePtr, bool> tailNodesMap;
2042-
std::unordered_set<ov::element::Type_t> outputPrecisions;
2043-
// starting from output nodes
2044-
for (const auto& entry : outputNodesMap) {
2045-
const auto& output = entry.second;
2046-
if (output->getOriginalInputPrecisionAtPort(0) == inferPrec) {
2047-
continue;
2048-
}
2049-
outputPrecisions.insert(output->getOriginalInputPrecisionAtPort(0));
2050-
searchForTailNodes(output, tailNodesMap);
2051-
}
2052-
if (outputPrecisions.empty()) {
2053-
return;
2054-
}
2055-
2056-
const std::vector<Type> kStartTypes = {Type::Eltwise, Type::MVN};
2057-
const std::vector<Type> kPathTypes = {Type::Reshape, Type::Concatenation, Type::Split};
2058-
std::function<bool(const NodePtr&)> suitableForTailOptimization;
2059-
suitableForTailOptimization = [&](const NodePtr& node) -> bool {
2060-
const NodePtr& cur = node;
2061-
std::unordered_set<NodePtr> visited;
2062-
while (cur) {
2063-
if (visited.count(cur)) {
2064-
break;
2065-
}
2066-
visited.insert(cur);
2067-
2068-
size_t parentNum = cur->getParentEdges().size();
2069-
if (parentNum == 0) {
2070-
return false;
2071-
}
2072-
bool allParentSuitable = true;
2073-
for (size_t i = 0; i < parentNum; ++i) {
2074-
auto parent = cur->getParentEdgeAt(i)->getParent();
2075-
if (!parent) {
2076-
return false;
2077-
}
2078-
if ((std::find(kStartTypes.begin(), kStartTypes.end(), parent->getType()) != kStartTypes.end()) &&
2079-
tailNodesMap.count(parent) != 0u) {
2080-
continue;
2081-
}
2082-
if ((std::find(kPathTypes.begin(), kPathTypes.end(), parent->getType()) != kPathTypes.end()) &&
2083-
tailNodesMap.count(parent) != 0u) {
2084-
if (!suitableForTailOptimization(parent)) {
2085-
allParentSuitable = false;
2086-
break;
2087-
}
2088-
} else {
2089-
return false;
2090-
}
2091-
}
2092-
return allParentSuitable;
2093-
}
2094-
return false;
2095-
};
2096-
std::function<void(const NodePtr&, const ov::element::Type_t&)> resetTailPrecision;
2097-
resetTailPrecision = [&](const NodePtr& node, const ov::element::Type_t& outputPrecision) -> void {
2098-
// traverse upwards until encountering the first kStartTypes
2099-
for (size_t i = 0; i < node->getParentEdges().size(); ++i) {
2100-
auto parent = node->getParentEdgeAt(i)->getParent();
2101-
if (!parent) {
2102-
continue;
2103-
}
2104-
OPENVINO_ASSERT(tailNodesMap.count(parent),
2105-
"resetTailPrecision: node ",
2106-
parent->getName(),
2107-
" with type ",
2108-
NameFromType(parent->getType()),
2109-
" is not in suitableForTailOptimization set");
2110-
if (tailNodesMap[parent]) {
2111-
continue;
2112-
}
2113-
tailNodesMap[parent] = true;
2114-
if (std::find(kStartTypes.begin(), kStartTypes.end(), parent->getType()) != kStartTypes.end()) {
2115-
// set the output precision of kStartTypes nodes to f32, input precision remains unchanged
2116-
for (size_t j = 0; j < parent->getOriginalOutputsNumber(); ++j) {
2117-
parent->setOriginalOutputPrecisionAtPort(j, outputPrecision);
2118-
}
2119-
} else {
2120-
// recursively process upwards
2121-
// set all input and output precisions of the current nodes to f32
2122-
for (size_t j = 0; j < parent->getOriginalInputsNumber(); ++j) {
2123-
parent->setOriginalInputPrecisionAtPort(j, outputPrecision);
2124-
}
2125-
for (size_t j = 0; j < parent->getOriginalOutputsNumber(); ++j) {
2126-
parent->setOriginalOutputPrecisionAtPort(j, outputPrecision);
2127-
}
2128-
resetTailPrecision(parent, outputPrecision);
2129-
}
2130-
}
2131-
};
2132-
2133-
std::function<void(const NodePtr&)> tailNodesPrecisionOptimizeMain;
2134-
tailNodesPrecisionOptimizeMain = [&](const NodePtr& node) -> void {
2135-
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
2136-
const auto& parent = node->getParentEdgeAt(i)->getParent();
2137-
if (!tailNodesMap.count(parent)) {
2138-
continue;
2139-
}
2140-
if (one_of(parent->getType(), Type::Input, Type::Output, Type::MemoryInput, Type::MemoryOutput)) {
2141-
continue;
2142-
}
2143-
if (parent->keepOrigPrecision()) {
2144-
continue;
2145-
}
2146-
if ((parent->getType() == Type::Convert) && (parent->getOriginalInputPrecisionAtPort(0) == inferPrec) &&
2147-
outputPrecisions.count(parent->getOriginalOutputPrecisionAtPort(0)) != 0u) {
2148-
bool suitableCase = false;
2149-
auto outprecision = parent->getOriginalOutputPrecisionAtPort(0);
2150-
for (size_t i = 0; i < parent->getParentEdges().size(); ++i) {
2151-
auto p = parent->getParentEdgeAt(i)->getParent();
2152-
if (!p) {
2153-
continue;
2154-
}
2155-
if (std::find(kPathTypes.begin(), kPathTypes.end(), p->getType()) != kPathTypes.end()) {
2156-
if (suitableForTailOptimization(p)) {
2157-
suitableCase = true;
2158-
continue;
2159-
}
2160-
} else if (std::find(kStartTypes.begin(), kStartTypes.end(), p->getType()) !=
2161-
kStartTypes.end()) {
2162-
suitableCase = true;
2163-
continue;
2164-
}
2165-
}
2166-
if (suitableCase) {
2167-
// suitable case for tail optimization
2168-
resetTailPrecision(parent, outprecision);
2169-
DropNode(parent);
2170-
}
2171-
continue;
2172-
}
2173-
tailNodesPrecisionOptimizeMain(parent);
2174-
}
2175-
};
2176-
// tail optimization main process
2177-
for (const auto& entry : outputNodesMap) {
2178-
const auto& output = entry.second;
2179-
if (output->getOriginalInputPrecisionAtPort(0) == inferPrec) {
2180-
continue;
2181-
}
2182-
tailNodesPrecisionOptimizeMain(output);
2183-
}
2184-
return;
2185-
}
21862010

21872011
std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
21882012
searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {

0 commit comments

Comments
 (0)