@@ -264,6 +264,14 @@ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "sess
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";

// Disables the NCHWc layout transformation throughout the graph.
// If contrib ops are disabled in the build, the NCHWc layout transformation is not available and this option has no effect.
//
// Option values:
// - "0": NCHWc layout transformation is not disabled. [DEFAULT]
// - "1": NCHWc layout transformation is disabled.
static const char* const kOrtSessionOptionsDisableNchwcLayoutTransformation = "session.disable_nchwc_layout_transformation";
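
A minimal sketch of how a caller could opt out of the transformation through the public C++ API. The key string must match the constant above; the app name and "model.onnx" path are placeholders, and the snippet is illustrative rather than part of this change:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "nchwc-example"};
  Ort::SessionOptions session_options;
  // "1" disables the NCHWc layout transformation for this session; the default "0" keeps it enabled.
  session_options.AddConfigEntry("session.disable_nchwc_layout_transformation", "1");
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};
  return 0;
}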

// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
// assigned (i.e., "fallback") to the CPU EP by default.
//
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -412,8 +412,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

case TransformerLevel::Level3: {
#ifndef DISABLE_CONTRIB_OPS
// Register the NCHWc layout transformer if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
// Register the NCHWc layout transformer if supported by the platform and if the user didn't explicitly disable it.
const bool disable_nchwc = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableNchwcLayoutTransformation, "0") == "1";

if (MlasNchwcGetBlockSize() > 1 && !disable_nchwc) {
transformers.emplace_back(std::make_unique<NchwcTransformer>());
}

16 changes: 14 additions & 2 deletions onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -13,7 +13,7 @@ namespace onnxruntime {

class NchwcTransformerImpl {
public:
NchwcTransformerImpl(Graph& graph) noexcept : graph_(graph) {}
NchwcTransformerImpl(Graph& graph, const logging::Logger& logger) noexcept : graph_(graph), logger_(logger) {}

void Transform(Node& node);
void Finalize(bool& modified);
@@ -160,6 +160,8 @@ class NchwcTransformerImpl {
// NHWC to NCHW format.
Node* transpose_from_nhwc_node_{nullptr};
NodeArg* transpose_from_nhwc_output_arg_{nullptr};

const logging::Logger& logger_;
};

size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
@@ -335,6 +337,16 @@ void NchwcTransformerImpl::TransformConv(Node& node) {

const int64_t output_channels = conv_W_tensor_proto->dims(0);
const int64_t input_channels = conv_W_tensor_proto->dims(1);
const int64_t kH = conv_W_tensor_proto->dims(2);
const int64_t kW = conv_W_tensor_proto->dims(3);

if (kH >= 7 || kW >= 7) {
LOGS(logger_, WARNING) << "NCHWc Conv with large kernel (" << kH << "x" << kW
<< ") detected in node '" << node.Name()
<< "'. Please benchmark your target workload on the target hardware with "
<< "NCHWc layout optimizations enabled (default) and disabled (via session options)."
<< "On certain hardware, large kernel convolutions may perform worse with NCHWc layout.";
}

int64_t group_count;
const auto* group_attr = graph_utils::GetNodeAttribute(node, "group");
@@ -1227,7 +1239,7 @@ void NchwcTransformerImpl::Finalize(bool& modified) {
}

Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
NchwcTransformerImpl impl(graph);
NchwcTransformerImpl impl(graph, logger);
GraphViewer graph_viewer(graph);

for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
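
Given the new warning in TransformConv above, the suggested check is to time the same workload with the NCHWc transformation enabled (the default) and disabled via the new session option, then compare. The sketch below is illustrative only and assumes the public C++ API; the model path is a placeholder and the timed region should wrap your real session.Run calls:

#include <onnxruntime_cxx_api.h>
#include <chrono>
#include <iostream>

// Builds a session with the NCHWc layout transformation either enabled ("0") or disabled ("1")
// and returns the wall-clock time of the (placeholder) workload.
static double TimeWorkload(Ort::Env& env, const char* disable_nchwc) {
  Ort::SessionOptions so;
  so.SetGraphOptimizationLevel(ORT_ENABLE_ALL);  // Level3, where the NCHWc transformer is registered
  so.AddConfigEntry("session.disable_nchwc_layout_transformation", disable_nchwc);
  Ort::Session session{env, ORT_TSTR("model.onnx"), so};  // placeholder model path

  const auto start = std::chrono::steady_clock::now();
  // ... call session.Run(...) here with representative inputs ...
  return std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count();
}

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "nchwc-benchmark"};
  std::cout << "NCHWc enabled:  " << TimeWorkload(env, "0") << " s\n";
  std::cout << "NCHWc disabled: " << TimeWorkload(env, "1") << " s\n";
  return 0;
}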