2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/attention_fusion_helper.h
@@ -1447,7 +1447,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::
return false;
}

- if (graph_utils::IsSupportedOptypeVersionAndDomain(*k_concat, "Transpose", {1, 13}, kOnnxDomain)) {
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(*k_concat, "Transpose", {1, 13, 21}, kOnnxDomain)) {
transpose_optimized_pattern = true;
DEBUG_LOG("Using transpose optimized pattern");
opt_k_transpose = k_concat;
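
Aside: nearly every change in this PR extends a since-version allow-list like the one above, so that the fusions keep matching nodes from models exported with newer ONNX opsets. A minimal, self-contained sketch of what such a guard checks — op type, operator since-version against the list, and domain. NodeLike and the function body are assumptions for illustration, not onnxruntime's actual implementation.

#include <initializer_list>
#include <string>

// Stand-in for onnxruntime's Node, defined here only so the sketch compiles on its own.
struct NodeLike {
  std::string op_type;
  std::string domain;
  int since_version;
};

// Sketch of the check a helper like graph_utils::IsSupportedOptypeVersionAndDomain performs:
// a node matches only if its operator since-version is one the fusion was written against,
// so adding 21/22/23/... to these lists is what lets the fusion keep firing for newer opsets.
bool IsSupportedOptypeVersionAndDomainSketch(const NodeLike& node,
                                             const std::string& op_type,
                                             std::initializer_list<int> since_versions,
                                             const std::string& domain) {
  if (node.op_type != op_type || node.domain != domain) {
    return false;
  }
  for (int v : since_versions) {
    if (v == node.since_version) {
      return true;
    }
  }
  return false;
}
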
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/bias_dropout_fusion.cc
@@ -144,7 +144,7 @@ Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_leve
}

const Node& next_node = (*next_node_itr);
- if ((!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12, 13}, kOnnxDomain) &&
+ if ((!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12, 13, 22}, kOnnxDomain) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BitmaskDropout", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
@@ -225,7 +225,7 @@ Status UpStreamReshapeGraphTransformer::RemoveOriginalReshapeNode(

std::optional<ReshapeInfo> UpStreamReshapeGraphTransformer::IsSupportedForUpstream(
Graph& graph, Node& node, const logging::Logger& logger) const {
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {1, 5, 13, 14}, kOnnxDomain)) {
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {1, 5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain)) {
return std::nullopt;
}

4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -107,7 +107,7 @@ class ConvActivationSelector : public NodeSelector {
return std::nullopt;
} else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) {
if (!is_supported_non_cuda_ep_activation(*next_node) &&
- !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) {
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6, 22})) {
return std::nullopt;
}
} else {
@@ -212,7 +212,7 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) {
const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain);
auto selector = std::make_unique<selectors::ConvActivationSelector>();

- registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {1, 11}}, {msDomainConv, {1}}},
+ registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11, 22}}, {msInternalNHWCDomainConv, {1, 11, 22}}, {msDomainConv, {1}}},
std::move(selector), std::move(action));
#else
registry.RegisterAction(name, std::move(action));
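
Aside: the second argument to RegisterSelectorAndAction above is a map from op name (optionally domain-qualified via OpVersionsMapKey) to the since-versions the selector/action pair applies to; this PR adds 22 to the Conv entries. A self-contained sketch of how such a map could be consulted — the types below are stand-ins, not ORT's actual OpVersionsMap:

#include <map>
#include <string>
#include <vector>

// Stand-in for the op -> supported since-versions map passed to RegisterSelectorAndAction.
using OpVersionsMapSketch = std::map<std::string, std::vector<int>>;

// Returns true if the registered selector/action pair covers this op at this since-version.
bool SelectorAppliesTo(const OpVersionsMapSketch& op_versions,
                       const std::string& op_key, int since_version) {
  const auto it = op_versions.find(op_key);
  if (it == op_versions.end()) {
    return false;
  }
  for (int v : it->second) {
    if (v == since_version) {
      return true;
    }
  }
  return false;
}

// Usage, mirroring the {"Conv", {1, 11, 22}} entry added in this PR:
//   OpVersionsMapSketch conv_fusion{{"Conv", {1, 11, 22}}};
//   SelectorAppliesTo(conv_fusion, "Conv", 22);  // true only once 22 is in the list
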
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_add_act_fusion.cc
@@ -113,7 +113,7 @@ class ConvAddActivationSelector : public NodeSelector {
return true;
}

- if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "HardSigmoid", {6})) {
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "HardSigmoid", {6, 22})) {
return true;
}
return false;
@@ -288,7 +288,7 @@ void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) {
auto action = std::make_unique<actions::FuseConvAddActivationAction>();
auto selector = std::make_unique<selectors::ConvAddActivationSelector>();
std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain);
- registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}},
+ registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11, 22}}, {msDomainNhwcFusedConv, {1, 11, 22}}},
std::move(selector), std::move(action));
}

2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/conv_add_fusion.cc
@@ -107,7 +107,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
}

bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11, 22}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/conv_bn_fusion.cc
@@ -145,7 +145,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
}

bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11, 22}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/conv_mul_fusion.cc
@@ -113,7 +113,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
}

bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11, 22}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/dropout_elimination.cc
@@ -22,7 +22,7 @@ Status EliminateDropout::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule
bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
// We currently support elimination for Dropout operator v1, v6, v7, v10 and v12.
// REVIEW(mzs): v10 implementation does not exist.
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12, 13})) {
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {1, 6, 7, 10, 12, 13, 22})) {
return false;
}

@@ -32,7 +32,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, co
// 2. ratio input is not a graph input, so it cannot be overridden

// support opset 12 and above for ort training
- if (graph_utils::MatchesOpSinceVersion(node, {12, 13}) && node.InputDefs().size() > 1) {
+ if (graph_utils::MatchesOpSinceVersion(node, {12, 13, 22}) && node.InputDefs().size() > 1) {
if (graph_utils::IsGraphInput(graph, node.InputDefs()[1])) {
return false;
}
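
Aside: since opset 12, Dropout carries its ratio as an optional second input, so the elimination above is only safe when that input cannot be overridden at run time — which is what the MatchesOpSinceVersion/IsGraphInput guard enforces. A self-contained sketch of that condition (stand-in types, not ORT's):

#include <cstddef>
#include <string>
#include <vector>

// Stand-in for a Dropout node, defined only so this sketch compiles on its own.
struct DropoutNodeSketch {
  int since_version;                    // e.g. 12, 13, 22
  std::vector<std::string> input_defs;  // {"X"} or {"X", "ratio"}
  bool ratio_is_graph_input;            // true if "ratio" is fed from outside the graph
};

bool RatioCanBeOverridden(const DropoutNodeSketch& node) {
  const bool has_ratio_input = node.since_version >= 12 && node.input_defs.size() > 1;
  // If the ratio arrives as a graph input, a caller could set it to a non-zero value
  // per Run(), so Dropout cannot be assumed to be an identity and must not be removed.
  return has_ratio_input && node.ratio_is_graph_input;
}
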
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/fast_gelu_fusion.cc
@@ -151,7 +151,7 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node,
if (p_cast1_node != nullptr) {
Node& cast1_node = *graph.GetNode(p_cast1_node->Index());
// this is fused Cast node, so expect 2 output edges
- if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast1_node, "Cast", {9, 13, 19}) &&
+ if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast1_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
CheckNode(graph, cast1_node, pow1_node.GetExecutionProviderType(), false)) ||
cast1_node.GetOutputEdgesCount() != 2) {
return matchResult;
@@ -262,7 +262,7 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (p_cast3_node == nullptr) continue;

Node& cast3_node = *graph.GetNode(p_cast3_node->Index());
- if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast3_node, "Cast", {9, 13, 19}) &&
+ if (!(graph_utils::IsSupportedOptypeVersionAndDomain(cast3_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
CheckNode(graph, cast3_node, node.GetExecutionProviderType(), true))) {
continue;
}
14 changes: 7 additions & 7 deletions onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -21,20 +21,20 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, const std::string& op_t
// If the op has multiple versions, here we require it must have a single implementation that can work across all the
// versions. Because in the fusion, we discarded the op version information.
bool IsFusableActivation(const Node& node) {
- return IsSupportedOptypeVersionAndDomain(node, "Elu", {6}, kOnnxDomain) ||
- IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6}, kOnnxDomain) ||
- IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6}, kOnnxDomain) ||
+ return IsSupportedOptypeVersionAndDomain(node, "Elu", {6, 22}, kOnnxDomain) ||
+ IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 22}, kOnnxDomain) ||
+ IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", {6, 16}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}, kOnnxDomain) ||
- IsSupportedOptypeVersionAndDomain(node, "Selu", {6}, kOnnxDomain) ||
+ IsSupportedOptypeVersionAndDomain(node, "Selu", {6, 22}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}, kOnnxDomain) ||
- IsSupportedOptypeVersionAndDomain(node, "Softplus", {1}, kOnnxDomain) ||
- IsSupportedOptypeVersionAndDomain(node, "Softsign", {1}, kOnnxDomain) ||
+ IsSupportedOptypeVersionAndDomain(node, "Softplus", {1, 22}, kOnnxDomain) ||
+ IsSupportedOptypeVersionAndDomain(node, "Softsign", {1, 22}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}, kOnnxDomain) ||
#ifndef DISABLE_CONTRIB_OPS
IsSupportedOptypeVersionAndDomain(node, "ScaledTanh", {1}, kOnnxDomain) ||
IsSupportedOptypeVersionAndDomain(node, "ParametricSoftplus", {1}, kOnnxDomain) ||
#endif
- IsSupportedOptypeVersionAndDomain(node, "ThresholdedRelu", {1, 10}, kOnnxDomain);
+ IsSupportedOptypeVersionAndDomain(node, "ThresholdedRelu", {1, 10, 22}, kOnnxDomain);
}
} // namespace

4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -104,7 +104,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,

// Fusion can be applied if there is a transpose at either of the inputs
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
- if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13, 21, 23, 24, 25}) &&
!graph.NodeProducesGraphOutput(*node_it) &&
// Make sure the two nodes do not span execution providers.
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
@@ -128,7 +128,7 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,

const auto next_node_it = node.OutputNodesBegin();
if (next_node_it != node.OutputNodesEnd() &&
- graph_utils::IsSupportedOptypeVersionAndDomain(*next_node_it, "Transpose", {1, 13}) &&
+ graph_utils::IsSupportedOptypeVersionAndDomain(*next_node_it, "Transpose", {1, 13, 21, 23, 24, 25}) &&
next_node_it->GetInputEdgesCount() == 1 &&
// Make sure the two nodes do not span execution providers.
next_node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
6 changes: 3 additions & 3 deletions onnxruntime/core/optimizer/isinf_reducesum_fusion.cc
@@ -33,7 +33,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l

ORT_RETURN_IF_ERROR(Recurse(isinf_node, modified, graph_level, logger));

- if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(isinf_node, "IsInf", {10, 20}) ||
isinf_node.GetOutputEdgesCount() != 1 ||
graph.NodeProducesGraphOutput(isinf_node)) {
continue;
@@ -45,7 +45,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
// This Cast can be skipped as we are replacing the subgraph with IsAllFinite, which supports FP16
auto cast1_node_iter = isinf_node.InputNodesBegin();
if (cast1_node_iter != isinf_node.InputNodesEnd() &&
- graph_utils::IsSupportedOptypeVersionAndDomain(*cast1_node_iter, "Cast", {9, 13, 19}) &&
+ graph_utils::IsSupportedOptypeVersionAndDomain(*cast1_node_iter, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
cast1_node_iter->GetOutputEdgesCount() == 1) {
// check input type of cast node
Node& cast1_node = *graph.GetNode(cast1_node_iter->Index());
@@ -65,7 +65,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
}

Node& cast2_node = *graph.GetNode(cast2_node_itr->Index());
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13, 19}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast2_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) ||
cast2_node.GetOutputEdgesCount() != 1 ||
graph.NodeProducesGraphOutput(cast2_node)) {
continue;
12 changes: 6 additions & 6 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -241,7 +241,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (p_reduce_mean_input_node) {
Node& reduce_mean_input_node = *graph.GetNode(p_reduce_mean_input_node->Index());
// If input to the 1st ReduceMean is a Cast, and the Cast has same consumer count as subCnt + 1
- if (graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_input_node, "Cast", {9, 13, 19}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
reduce_mean_input_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
optimizer_utils::CheckOutputEdges(graph, reduce_mean_input_node, static_cast<size_t>(subCnt) + 1)) {
nodes_to_remove.insert(nodes_to_remove.begin(), reduce_mean_input_node);
@@ -254,7 +254,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const Node* p_cast1 = nullptr;
if (!p_sub_node_dup && sub_node.GetOutputEdgesCount() == 1) {
Node& cast_node = *graph.GetNode(sub_node.OutputNodesBegin()->Index());
- if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
cast_node.GetExecutionProviderType() == reduce_mean_node.GetExecutionProviderType() &&
optimizer_utils::CheckOutputEdges(graph, cast_node, 2u) && IsSupportedDataType(cast_node)) {
p_cast1 = &cast_node;
@@ -353,7 +353,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const Node* p_cast2 = graph_utils::FirstParentByType(pow_node, "Cast");
if (p_cast2 != nullptr && p_cast2 != p_cast1) {
Node& cast_node = *graph.GetNode(p_cast2->Index());
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) ||
cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) {
continue;
@@ -371,7 +371,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// can be removed. This is one possible place a Cast Op can exist, that is between Div and Mul nodes.
// div --> mul or div --> cast --> mul
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
- if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
nodes_to_remove.push_back(*next_node);
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());
@@ -637,7 +637,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
if (is_gpu_ep && p_pow_input_node) {
Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
// If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
- if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13, 19}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
pow_input_node.GetExecutionProviderType() == pow_node.GetExecutionProviderType() &&
optimizer_utils::CheckOutputEdges(graph, pow_input_node, 2)) {
nodes_to_remove.insert(nodes_to_remove.begin(), pow_input_node);
@@ -647,7 +647,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr

// div --> mul or div --> cast --> mul
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
- if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) &&
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
if (!is_gpu_ep) continue;
nodes_to_remove.push_back(*next_node);
6 changes: 3 additions & 3 deletions onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -10,8 +10,8 @@ namespace onnxruntime {

namespace {
const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
- {"Reshape", {1, 5, 13, 14, 19}},
- {"Transpose", {1, 13}}};
+ {"Reshape", {1, 5, 13, 14, 19, 21, 23, 24, 25}},
+ {"Transpose", {1, 13, 21, 23, 24, 25}}};
const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
} // namespace

@@ -244,4 +244,4 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime
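
Aside: ignorable_nodes above appears to list the ops (with their allowed since-versions) that may sit between the MatMul and the BatchNormalization target without blocking the fusion. A hedged, self-contained sketch of how such lists can be walked past — the traversal below is an assumption for illustration; only the op/version lists come from this file:

#include <string>
#include <utility>
#include <vector>

// Stand-in node type, defined only so this sketch compiles on its own.
struct SimpleNode {
  std::string op_type;
  int since_version;
  const SimpleNode* single_consumer;  // nullptr unless the node has exactly one consumer
};

using OpVersions = std::pair<std::string, std::vector<int>>;

bool MatchesOpVersions(const SimpleNode& n, const OpVersions& ov) {
  if (n.op_type != ov.first) {
    return false;
  }
  for (int v : ov.second) {
    if (v == n.since_version) {
      return true;
    }
  }
  return false;
}

// Follow single-consumer edges, skipping any node whose op/since-version appears in
// `ignorable`, and return the first node that is not ignorable (e.g. the candidate
// BatchNormalization). The traversal details are assumed, not taken from this PR.
const SimpleNode* SkipIgnorableNodes(const SimpleNode* node,
                                     const std::vector<OpVersions>& ignorable) {
  while (node != nullptr) {
    bool skipped = false;
    for (const auto& ov : ignorable) {
      if (MatchesOpVersions(*node, ov)) {
        node = node->single_consumer;
        skipped = true;
        break;
      }
    }
    if (!skipped) {
      break;
    }
  }
  return node;
}
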
18 changes: 9 additions & 9 deletions onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -345,7 +345,7 @@
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_output_channels = (output_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
[GitHub Actions / build_x86_release warning on line 348: '~': zero extending 'size_t' to 'int64_t' of greater size]

bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
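
Aside on the build warning flagged above, which is on a line this PR does not appear to modify: the expression rounds a channel count up to the next multiple of the NCHWc block size, but on an x86 build ~(nchwc_block_size - 1) is computed in 32-bit size_t and then zero-extended, so the mask cannot clear bits above bit 31. A self-contained sketch of the arithmetic, with one possible way to keep the mask 64-bit wide (RoundUpToBlock is a hypothetical helper, not ORT code or a change made here):

#include <cstddef>
#include <cstdint>

// Rounds `channels` up to the next multiple of `block_size` with the same
// (x + b - 1) & ~(b - 1) idiom used above. Widening the block size to int64_t
// before applying ~ keeps the mask 64 bits wide, which is one way the
// zero-extension warning could be avoided.
int64_t RoundUpToBlock(int64_t channels, size_t block_size) {
  const int64_t b = static_cast<int64_t>(block_size);
  return (channels + b - 1) & ~(b - 1);
}

// e.g. RoundUpToBlock(100, 16) == 112, matching the nchwc_output_channels computation.
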
@@ -378,7 +378,7 @@
if ((input_channels % channel_alignment) != 0) {
return;
}
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
[GitHub Actions / build_x86_release warning on line 381: '~': zero extending 'size_t' to 'int64_t' of greater size]
}
}

@@ -880,7 +880,7 @@
bn_B.sub(bn_mean);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
[GitHub Actions / build_x86_release warning on line 883: '~': zero extending 'size_t' to 'int64_t' of greater size]

InlinedVector<float> padded_buffer(gsl::narrow<size_t>(nchwc_channels));

@@ -1149,15 +1149,15 @@
}

void NchwcTransformerImpl::Transform(Node& node) {
- if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21, 23, 24, 25})) {
TrackTransposeFromNhwc(node);
}

- if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11, 22}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) {
TransformConv(node);
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10, 11, 12}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "AveragePool", {1, 7, 10, 11})) {
+ } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10, 11, 12, 22}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "AveragePool", {1, 7, 10, 11, 19, 22})) {
TransformPool(node);
} else if (node.GetInputEdgesCount() == 0 && node.InputDefs().size() != 0) {
// The following transforms only run when the input edge count has already
@@ -1176,15 +1176,15 @@
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) {
TransformActivation(node);
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14})) {
+ } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) {
TransformBatchNormalization(node);
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) {
+ } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21, 23, 24, 25})) {
TransformTransposeToNhwc(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) {
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13, 18, 19})) {
TransformResize(node);
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "GlobalMaxPool", {1}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "GlobalAveragePool", {1})) {
+ } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "GlobalMaxPool", {1, 22}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "GlobalAveragePool", {1, 22})) {
// Convert these pooling types only if the input is already in NCHWc format.
TransformPool(node);
}