@@ -2004,20 +2004,18 @@ void Graph::EnforceInferencePrecision() {
20042004 CPU_DEBUG_CAP_ENABLE (EnforceInferPrcDebug inferPrecDebug);
20052005
20062006 const auto inferPrec = getConfig ().inferencePrecision ;
2007- if (one_of (inferPrec, element::f32 , element::dynamic, ov::element:: f16 , element::dynamic )) {
2007+ if (one_of (inferPrec, element::f32 , element::dynamic)) {
20082008 return ; // nothing to do, only precision reduction is currently allowed
20092009 }
2010- #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
20112010 if (inferPrec == ov::element::f16 ) {
2011+ #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
20122012 return ; // precision of configured by ov::pass::ConvertPrecision
2013- }
20142013#endif
2015- std::function<void (const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
2016- searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
2017- for (size_t i = 0 ; i < node->getParentEdges ().size (); i++) {
2018- const auto & parent = node->getParentEdgeAt (i)->getParent ();
2019- if (inferPrec == ov::element::bf16 ) {
2020- /* list of node types that must be forced to be executed in BF16 precision
2014+ std::function<void (const NodePtr&, std::unordered_map<NodePtr, bool >&)> searchForTailNodes;
2015+ searchForTailNodes = [&](const NodePtr& node, std::unordered_map<NodePtr, bool >& tailNodes) -> void {
2016+ for (size_t i = 0 ; i < node->getParentEdges ().size (); i++) {
2017+ const auto & parent = node->getParentEdgeAt (i)->getParent ();
2018+ /* list of node types that must be forced to be executed in F16 precision
20212019 * because of performance gains */
20222020 if (one_of (parent->getType (),
20232021 Type::Convolution, // conv nets
@@ -2029,19 +2027,181 @@ void Graph::EnforceInferencePrecision() {
20292027 Type::Interpolate, // super resolution nets
20302028 Type::PagedAttention, // page attention
20312029 Type::QKVProjection,
2032- Type::LLMMLP)) {
2030+ Type::LLMMLP,
2031+ Type::Pooling)) {
20332032 continue ; // stop at significant nodes
20342033 }
2035- } else if (inferPrec == ov::element::f16 ) {
2036- /* list of node types that must be forced to be executed in FP16 precision
2034+ const auto res = tailNodes.insert ({parent, false });
2035+ if (res.second ) { // node not visited yet
2036+ searchForTailNodes (parent, tailNodes);
2037+ }
2038+ }
2039+ };
2040+ // collect the tail nodes
2041+ std::unordered_map<NodePtr, bool > tailNodesMap;
2042+ std::unordered_set<ov::element::Type_t> outputPrecisions;
2043+ // starting from output nodes
2044+ for (const auto & entry : outputNodesMap) {
2045+ const auto & output = entry.second ;
2046+ if (output->getOriginalInputPrecisionAtPort (0 ) == inferPrec) {
2047+ continue ;
2048+ }
2049+ outputPrecisions.insert (output->getOriginalInputPrecisionAtPort (0 ));
2050+ searchForTailNodes (output, tailNodesMap);
2051+ }
2052+ if (outputPrecisions.empty ()) {
2053+ return ;
2054+ }
2055+
2056+ const std::vector<Type> kStartTypes = {Type::Eltwise, Type::MVN};
2057+ const std::vector<Type> kPathTypes = {Type::Reshape, Type::Concatenation, Type::Split};
2058+ std::function<bool (const NodePtr&)> suitableForTailOptimization;
2059+ suitableForTailOptimization = [&](const NodePtr& node) -> bool {
2060+ const NodePtr& cur = node;
2061+ std::unordered_set<NodePtr> visited;
2062+ while (cur) {
2063+ if (visited.count (cur)) {
2064+ break ;
2065+ }
2066+ visited.insert (cur);
2067+
2068+ size_t parentNum = cur->getParentEdges ().size ();
2069+ if (parentNum == 0 ) {
2070+ return false ;
2071+ }
2072+ bool allParentSuitable = true ;
2073+ for (size_t i = 0 ; i < parentNum; ++i) {
2074+ auto parent = cur->getParentEdgeAt (i)->getParent ();
2075+ if (!parent) {
2076+ return false ;
2077+ }
2078+ if ((std::find (kStartTypes .begin (), kStartTypes .end (), parent->getType ()) != kStartTypes .end ()) &&
2079+ tailNodesMap.count (parent) != 0u ) {
2080+ continue ;
2081+ }
2082+ if ((std::find (kPathTypes .begin (), kPathTypes .end (), parent->getType ()) != kPathTypes .end ()) &&
2083+ tailNodesMap.count (parent) != 0u ) {
2084+ if (!suitableForTailOptimization (parent)) {
2085+ allParentSuitable = false ;
2086+ break ;
2087+ }
2088+ } else {
2089+ return false ;
2090+ }
2091+ }
2092+ return allParentSuitable;
2093+ }
2094+ return false ;
2095+ };
2096+ std::function<void (const NodePtr&, const ov::element::Type_t&)> resetTailPrecision;
2097+ resetTailPrecision = [&](const NodePtr& node, const ov::element::Type_t& outputPrecision) -> void {
2098+ // traverse upwards until encountering the first kStartTypes
2099+ for (size_t i = 0 ; i < node->getParentEdges ().size (); ++i) {
2100+ auto parent = node->getParentEdgeAt (i)->getParent ();
2101+ if (!parent) {
2102+ continue ;
2103+ }
2104+ OPENVINO_ASSERT (tailNodesMap.count (parent),
2105+ " resetTailPrecision: node " ,
2106+ parent->getName (),
2107+ " with type " ,
2108+ NameFromType (parent->getType ()),
2109+ " is not in suitableForTailOptimization set" );
2110+ if (tailNodesMap[parent]) {
2111+ continue ;
2112+ }
2113+ tailNodesMap[parent] = true ;
2114+ if (std::find (kStartTypes .begin (), kStartTypes .end (), parent->getType ()) != kStartTypes .end ()) {
2115+ // set the output precision of kStartTypes nodes to f32, input precision remains unchanged
2116+ for (size_t j = 0 ; j < parent->getOriginalOutputsNumber (); ++j) {
2117+ parent->setOriginalOutputPrecisionAtPort (j, outputPrecision);
2118+ }
2119+ } else {
2120+ // recursively process upwards
2121+ // set all input and output precisions of the current nodes to f32
2122+ for (size_t j = 0 ; j < parent->getOriginalInputsNumber (); ++j) {
2123+ parent->setOriginalInputPrecisionAtPort (j, outputPrecision);
2124+ }
2125+ for (size_t j = 0 ; j < parent->getOriginalOutputsNumber (); ++j) {
2126+ parent->setOriginalOutputPrecisionAtPort (j, outputPrecision);
2127+ }
2128+ resetTailPrecision (parent, outputPrecision);
2129+ }
2130+ }
2131+ };
2132+
2133+ std::function<void (const NodePtr&)> tailNodesPrecisionOptimizeMain;
2134+ tailNodesPrecisionOptimizeMain = [&](const NodePtr& node) -> void {
2135+ for (size_t i = 0 ; i < node->getParentEdges ().size (); i++) {
2136+ const auto & parent = node->getParentEdgeAt (i)->getParent ();
2137+ if (!tailNodesMap.count (parent)) {
2138+ continue ;
2139+ }
2140+ if (one_of (parent->getType (), Type::Input, Type::Output, Type::MemoryInput, Type::MemoryOutput)) {
2141+ continue ;
2142+ }
2143+ if (parent->keepOrigPrecision ()) {
2144+ continue ;
2145+ }
2146+ if ((parent->getType () == Type::Convert) && (parent->getOriginalInputPrecisionAtPort (0 ) == inferPrec) &&
2147+ outputPrecisions.count (parent->getOriginalOutputPrecisionAtPort (0 )) != 0u ) {
2148+ bool suitableCase = false ;
2149+ auto outprecision = parent->getOriginalOutputPrecisionAtPort (0 );
2150+ for (size_t i = 0 ; i < parent->getParentEdges ().size (); ++i) {
2151+ auto p = parent->getParentEdgeAt (i)->getParent ();
2152+ if (!p) {
2153+ continue ;
2154+ }
2155+ if (std::find (kPathTypes .begin (), kPathTypes .end (), p->getType ()) != kPathTypes .end ()) {
2156+ if (suitableForTailOptimization (p)) {
2157+ suitableCase = true ;
2158+ continue ;
2159+ }
2160+ } else if (std::find (kStartTypes .begin (), kStartTypes .end (), p->getType ()) !=
2161+ kStartTypes .end ()) {
2162+ suitableCase = true ;
2163+ continue ;
2164+ }
2165+ }
2166+ if (suitableCase) {
2167+ // suitable case for tail optimization
2168+ resetTailPrecision (parent, outprecision);
2169+ DropNode (parent);
2170+ }
2171+ continue ;
2172+ }
2173+ tailNodesPrecisionOptimizeMain (parent);
2174+ }
2175+ };
2176+ // tail optimization main process
2177+ for (const auto & entry : outputNodesMap) {
2178+ const auto & output = entry.second ;
2179+ if (output->getOriginalInputPrecisionAtPort (0 ) == inferPrec) {
2180+ continue ;
2181+ }
2182+ tailNodesPrecisionOptimizeMain (output);
2183+ }
2184+ return ;
2185+ }
2186+
2187+ std::function<void (const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
2188+ searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
2189+ for (size_t i = 0 ; i < node->getParentEdges ().size (); i++) {
2190+ const auto & parent = node->getParentEdgeAt (i)->getParent ();
2191+ if (inferPrec == ov::element::bf16 ) {
2192+ /* list of node types that must be forced to be executed in BF16 precision
20372193 * because of performance gains */
20382194 if (one_of (parent->getType (),
20392195 Type::Convolution, // conv nets
2040- Type::Deconvolution, // deconv
20412196 Type::FullyConnected, // conv / bert nets
2197+ Type::RNNCell, // recurrent nets
2198+ Type::RNNSeq, // recurrent nets
20422199 Type::MatMul, // bert nets
2043- Type::Pooling,
2044- Type::MVN)) {
2200+ Type::ROIPooling, // object detection nets
2201+ Type::Interpolate, // super resolution nets
2202+ Type::PagedAttention, // page attention
2203+ Type::QKVProjection,
2204+ Type::LLMMLP)) {
20452205 continue ; // stop at significant nodes
20462206 }
20472207 }
0 commit comments