@@ -241,7 +241,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
241241 if (p_reduce_mean_input_node) {
242242 Node& reduce_mean_input_node = *graph.GetNode (p_reduce_mean_input_node->Index ());
243243 // If input to the 1st ReduceMean is a Cast, and the Cast has same consumer count as subCnt + 1
244- if (graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean_input_node, " Cast" , {9 , 13 , 19 }) &&
244+ if (graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean_input_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) &&
245245 reduce_mean_input_node.GetExecutionProviderType () == reduce_mean_node.GetExecutionProviderType () &&
246246 optimizer_utils::CheckOutputEdges (graph, reduce_mean_input_node, static_cast <size_t >(subCnt) + 1 )) {
247247 nodes_to_remove.insert (nodes_to_remove.begin (), reduce_mean_input_node);
@@ -254,7 +254,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
254254 const Node* p_cast1 = nullptr ;
255255 if (!p_sub_node_dup && sub_node.GetOutputEdgesCount () == 1 ) {
256256 Node& cast_node = *graph.GetNode (sub_node.OutputNodesBegin ()->Index ());
257- if (graph_utils::IsSupportedOptypeVersionAndDomain (cast_node, " Cast" , {9 , 13 , 19 }) &&
257+ if (graph_utils::IsSupportedOptypeVersionAndDomain (cast_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) &&
258258 cast_node.GetExecutionProviderType () == reduce_mean_node.GetExecutionProviderType () &&
259259 optimizer_utils::CheckOutputEdges (graph, cast_node, 2u ) && IsSupportedDataType (cast_node)) {
260260 p_cast1 = &cast_node;
@@ -353,7 +353,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
353353 const Node* p_cast2 = graph_utils::FirstParentByType (pow_node, " Cast" );
354354 if (p_cast2 != nullptr && p_cast2 != p_cast1) {
355355 Node& cast_node = *graph.GetNode (p_cast2->Index ());
356- if (!graph_utils::IsSupportedOptypeVersionAndDomain (cast_node, " Cast" , {9 , 13 , 19 }) ||
356+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (cast_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) ||
357357 cast_node.GetExecutionProviderType () != reduce_mean_node.GetExecutionProviderType () ||
358358 !optimizer_utils::CheckOutputEdges (graph, cast_node, 1 )) {
359359 continue ;
@@ -371,7 +371,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
371371 // can be removed. This is one possible place a Cast Op can exist, that is between Div and Mul nodes.
372372 // div --> mul or div --> cast --> mul
373373 Node* next_node = graph.GetNode (div_node.OutputNodesBegin ()->Index ());
374- if (graph_utils::IsSupportedOptypeVersionAndDomain (*next_node, " Cast" , {9 , 13 , 19 }) &&
374+ if (graph_utils::IsSupportedOptypeVersionAndDomain (*next_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) &&
375375 optimizer_utils::CheckOutputEdges (graph, *next_node, 1 )) {
376376 nodes_to_remove.push_back (*next_node);
377377 next_node = graph.GetNode (next_node->OutputNodesBegin ()->Index ());
@@ -637,7 +637,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
637637 if (is_gpu_ep && p_pow_input_node) {
638638 Node& pow_input_node = *graph.GetNode (p_pow_input_node->Index ());
639639 // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
640- if (graph_utils::IsSupportedOptypeVersionAndDomain (pow_input_node, " Cast" , {9 , 13 , 19 }) &&
640+ if (graph_utils::IsSupportedOptypeVersionAndDomain (pow_input_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) &&
641641 pow_input_node.GetExecutionProviderType () == pow_node.GetExecutionProviderType () &&
642642 optimizer_utils::CheckOutputEdges (graph, pow_input_node, 2 )) {
643643 nodes_to_remove.insert (nodes_to_remove.begin (), pow_input_node);
@@ -647,7 +647,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
647647
648648 // div --> mul or div --> cast --> mul
649649 Node* next_node = graph.GetNode (div_node.OutputNodesBegin ()->Index ());
650- if (graph_utils::IsSupportedOptypeVersionAndDomain (*next_node, " Cast" , {9 , 13 , 19 }) &&
650+ if (graph_utils::IsSupportedOptypeVersionAndDomain (*next_node, " Cast" , {9 , 13 , 19 , 21 , 23 , 24 , 25 }) &&
651651 optimizer_utils::CheckOutputEdges (graph, *next_node, 1 )) {
652652 if (!is_gpu_ep) continue ;
653653 nodes_to_remove.push_back (*next_node);
0 commit comments