Skip to content

Commit 73ca2e8

Browse files
committed
update opset version
1 parent f64cd94 commit 73ca2e8

File tree

11 files changed

+34
-34
lines changed

11 files changed

+34
-34
lines changed

onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ Status UpStreamReshapeGraphTransformer::RemoveOriginalReshapeNode(
225225

226226
std::optional<ReshapeInfo> UpStreamReshapeGraphTransformer::IsSupportedForUpstream(
227227
Graph& graph, Node& node, const logging::Logger& logger) const {
228-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {1, 5, 13, 14}, kOnnxDomain)) {
228+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {1, 5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain)) {
229229
return std::nullopt;
230230
}
231231

onnxruntime/core/optimizer/fast_gelu_fusion.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
151151
if (p_cast1_node != nullptr) {
152152
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
153153
// this is fused Cast node, so expect 2 output edges
154-
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast1_node, "Cast", {9, 13, 19}) &&
154+
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast1_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
155155
CheckNode(graph, cast1_node, pow1_node.GetExecutionProviderType(), false)) ||
156156
cast1_node.GetOutputEdgesCount() != 2) {
157157
return matchResult;
@@ -262,7 +262,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
262262
if (p_cast3_node == nullptr) continue;
263263

264264
Node& cast3_node = *graph.GetNode(p_cast3_node->Index());
265-
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast3_node, "Cast", {9, 13, 19}) &&
265+
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast3_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
266266
CheckNode(graph, cast3_node, node.GetExecutionProviderType(), true))) {
267267
continue;
268268
}

onnxruntime/core/optimizer/gemm_transpose_fusion.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
104104

105105
// Fusion can be applied if there is a transpose at either of the inputs
106106
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
107-
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13, 21}) &&
107+
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13, 21, 23, 24, 25}) &&
108108
!graph.NodeProducesGraphOutput(*node_it) &&
109109
// Make sure the two nodes do not span execution providers.
110110
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
@@ -128,7 +128,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
128128

129129
const auto next_node_it = node.OutputNodesBegin();
130130
if (next_node_it != node.OutputNodesEnd() &&
131-
graph_utils::IsSupportedOptypeVersionAndDomain(*next_node_it, "Transpose", {1, 13, 21}) &&
131+
graph_utils::IsSupportedOptypeVersionAndDomain(*next_node_it, "Transpose", {1, 13, 21, 23, 24, 25}) &&
132132
next_node_it->GetInputEdgesCount() == 1 &&
133133
// Make sure the two nodes do not span execution providers.
134134
next_node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {

onnxruntime/core/optimizer/isinf_reducesum_fusion.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
3333

3434
ORT_RETURN_IF_ERROR(Recurse(isinf_node, modified, graph_level, logger));
3535

36-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10}) ||
36+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10, 20}) ||
3737
isinf_node.GetOutputEdgesCount() != 1 ||
3838
graph.NodeProducesGraphOutput(isinf_node)) {
3939
continue;
@@ -45,7 +45,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
4545
// This Cast can be skipped as we are replacing the subgraph with IsAllFinite, which supports FP16
4646
auto cast1_node_iter = isinf_node.InputNodesBegin();
4747
if (cast1_node_iter != isinf_node.InputNodesEnd() &&
48-
graph_utils::IsSupportedOptypeVersionAndDomain(*cast1_node_iter, "Cast", {9, 13, 19}) &&
48+
graph_utils::IsSupportedOptypeVersionAndDomain(*cast1_node_iter, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
4949
cast1_node_iter->GetOutputEdgesCount() == 1) {
5050
// check input type of cast node
5151
Node& cast1_node = *graph.GetNode(cast1_node_iter->Index());
@@ -65,7 +65,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
6565
}
6666

6767
Node& cast2_node = *graph.GetNode(cast2_node_itr->Index());
68-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13, 19}) ||
68+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) ||
6969
cast2_node.GetOutputEdgesCount() != 1 ||
7070
graph.NodeProducesGraphOutput(cast2_node)) {
7171
continue;

onnxruntime/core/optimizer/layer_norm_fusion.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
241241
if (p_reduce_mean_input_node) {
242242
Node& reduce_mean_input_node = *graph.GetNode(p_reduce_mean_input_node->Index());
243243
// If input to the 1st ReduceMean is a Cast, and the Cast has same consumer count as subCnt + 1
244-
if (graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_input_node, "Cast", {9, 13, 19}) &&
244+
if (graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
245245
reduce_mean_input_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
246246
optimizer_utils::CheckOutputEdges(graph, reduce_mean_input_node, static_cast<size_t>(subCnt) + 1)) {
247247
nodes_to_remove.insert(nodes_to_remove.begin(), reduce_mean_input_node);
@@ -254,7 +254,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
254254
const Node* p_cast1 = nullptr;
255255
if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) {
256256
Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index());
257-
if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19}) &&
257+
if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
258258
cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
259259
optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) {
260260
p_cast1 = &cast_node;
@@ -353,7 +353,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
353353
const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast");
354354
if (p_cast2 != nullptr && p_cast2 != p_cast1) {
355355
Node& cast_node = *graph.GetNode(p_cast2->Index());
356-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19}) ||
356+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) ||
357357
cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
358358
!optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) {
359359
continue;
@@ -371,7 +371,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
371371
// can be removed. This is one possible place a Cast Op can exist, that is between Div and Mul nodes.
372372
// div --> mul or div --> cast --> mul
373373
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
374-
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
374+
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
375375
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
376376
nodes_to_remove.push_back(*next_node);
377377
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());
@@ -637,7 +637,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
637637
if (is_gpu_ep && p_pow_input_node) {
638638
Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
639639
// If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
640-
if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13, 19}) &&
640+
if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
641641
pow_input_node.GetExecutionProviderType() == pow_node.GetExecutionProviderType() &&
642642
optimizer_utils::CheckOutputEdges(graph, pow_input_node, 2)) {
643643
nodes_to_remove.insert(nodes_to_remove.begin(), pow_input_node);
@@ -647,7 +647,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
647647

648648
// div --> mul or div --> cast --> mul
649649
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
650-
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
650+
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
651651
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
652652
if (!is_gpu_ep) continue;
653653
nodes_to_remove.push_back(*next_node);

onnxruntime/core/optimizer/matmul_bn_fusion.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ namespace onnxruntime {
1010

1111
namespace {
1212
const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
13-
{"Reshape", {1, 5, 13, 14, 19}},
14-
{"Transpose", {1, 13}}};
13+
{"Reshape", {1, 5, 13, 14, 19, 21, 23, 24, 25}},
14+
{"Transpose", {1, 13, 21, 23, 24, 25}}};
1515
const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
1616
} // namespace
1717

@@ -244,4 +244,4 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
244244
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
245245
return Status::OK();
246246
}
247-
} // namespace onnxruntime
247+
} // namespace onnxruntime

onnxruntime/core/optimizer/nchwc_transformer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) {
11491149
}
11501150

11511151
void NchwcTransformerImpl::Transform(Node& node) {
1152-
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21})) {
1152+
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21, 23, 24, 25})) {
11531153
TrackTransposeFromNhwc(node);
11541154
}
11551155

@@ -1178,7 +1178,7 @@ void NchwcTransformerImpl::Transform(Node& node) {
11781178
TransformActivation(node);
11791179
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) {
11801180
TransformBatchNormalization(node);
1181-
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21})) {
1181+
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21, 23, 24, 25})) {
11821182
TransformTransposeToNhwc(node);
11831183
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
11841184
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13, 18, 19})) {

onnxruntime/core/optimizer/not_where_fusion.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Condition -> Where ->
3939
v0----|
4040
*/
4141
bool NotWhereFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
42-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Where", {9})) {
42+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Where", {9, 16})) {
4343
return false;
4444
}
4545

@@ -54,7 +54,7 @@ bool NotWhereFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
5454
if (p_not_node->GetOutputEdgesCount() > 1) {
5555
// all consumers of not must be where
5656
for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) {
57-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*it, "Where", {9})) {
57+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*it, "Where", {9, 16})) {
5858
return false;
5959
}
6060
}

onnxruntime/core/optimizer/pad_fusion.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_v
9090
*/
9191
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
9292
// if Pad has input axis, don't fuse it.
93-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
93+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19, 21, 23, 24, 25}) ||
9494
node.GetOutputEdgesCount() != 1 ||
9595
node.InputDefs().size() > 3) {
9696
return false;
@@ -130,7 +130,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
130130
}
131131

132132
const Node& child_node = *node.OutputNodesBegin();
133-
if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13})) {
133+
if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13, 19, 21, 23, 24, 25})) {
134134
if (child_node.GetOutputEdgesCount() != 1) {
135135
return false;
136136
}

onnxruntime/core/optimizer/pre_shape_node_elimination.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ bool PreShapeNodeElimination::SatisfyCondition(const Graph& graph, const Node& n
4848

4949
for (const Node* next_node : output_nodes) {
5050
// Check if the next node is not of type "Shape"
51-
if (!next_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Shape", {13, 15, 19}, kOnnxDomain)) {
51+
if (!next_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Shape", {13, 15, 19, 21, 23, 24, 25}, kOnnxDomain)) {
5252
return false;
5353
}
5454
}

0 commit comments

Comments
 (0)