Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "sess
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";

// Disables the NCHWc layout transformation (NchwcTransformer) throughout the graph.
// Note: in builds where contrib ops are disabled (DISABLE_CONTRIB_OPS), the NCHWc layout
// transformer is never registered at all, so this option has no effect there.
//
// Option values:
// - "0": NCHWc layout transformation is not disabled. [DEFAULT]
// - "1": NCHWc layout transformation is disabled.
static const char* const kOrtSessionOptionsDisableNchwcLayoutTransformation = "session.disable_nchwc_layout_transformation";

// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
// assigned (i.e., "fallback") to the CPU EP by default.
//
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/mlas/lib/snchwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,12 @@ Return Value:

MLAS_THREADED_ROUTINE* ThreadedRoutine;

// TODO(hasesh): Consider adding more implementations for Conv (for example, Im2Col + GEMM)
// to work with NCHWc data layout that provides better performance for large kernel sizes
// on some platforms.
// The current "direct" convolution implementations are bottlenecked by memory bandwidth for
// large kernel sizes on some platforms and lead to poor performance.
// See https://github.com/microsoft/onnxruntime/issues/26992 for more details.
if (WorkBlock.InputChannels >= MlasNchwcGetBlockSize()) {
if (WorkBlock.KernelShape[0] == 1 && WorkBlock.KernelShape[1] == 1 &&
WorkBlock.Padding[0] == 0 && WorkBlock.Padding[1] == 0 &&
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

case TransformerLevel::Level3: {
#ifndef DISABLE_CONTRIB_OPS
// Register the NCHWc layout transformer if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
// Register the NCHWc layout transformer if supported by the platform and if user didn't explicitly disable it.
const bool disable_nchwc = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableNchwcLayoutTransformation, "0") == "1";

if (MlasNchwcGetBlockSize() > 1 && !disable_nchwc) {
transformers.emplace_back(std::make_unique<NchwcTransformer>());
}

Expand Down
23 changes: 21 additions & 2 deletions onnxruntime/core/optimizer/nchwc_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class NchwcTransformerImpl {
public:
NchwcTransformerImpl(Graph& graph) noexcept : graph_(graph) {}
NchwcTransformerImpl(Graph& graph, const logging::Logger& logger) noexcept : graph_(graph), logger_(logger) {}

void Transform(Node& node);
void Finalize(bool& modified);
Expand Down Expand Up @@ -160,6 +160,8 @@
// NHWC to NCHW format.
Node* transpose_from_nhwc_node_{nullptr};
NodeArg* transpose_from_nhwc_output_arg_{nullptr};

const logging::Logger& logger_;
};

size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
Expand Down Expand Up @@ -335,6 +337,23 @@

const int64_t output_channels = conv_W_tensor_proto->dims(0);
const int64_t input_channels = conv_W_tensor_proto->dims(1);
const int64_t kH = conv_W_tensor_proto->dims(2);
const int64_t kW = conv_W_tensor_proto->dims(3);

// See https://github.com/microsoft/onnxruntime/issues/26992 for more details
// The value of 7 is chosen arbitrarily to catch large kernels.
// This may lead to false warnings to users who are able to run models with
// 7 x 7 or larger kernels efficiently with NCHWc layout.
// But this is better than not informing users at all.
// It may help users to make an informed decision to enable/disable NCHWc layout
// based on benchmarking on their target hardware.
if (kH >= 7 || kW >= 7) {
LOGS(logger_, WARNING) << "NCHWc Conv with large kernel (" << kH << "x" << kW
<< ") detected in node '" << node.Name()
<< "'. Please benchmark your target workload on the target hardware with "
<< "NCHWc layout optimizations enabled (default) and disabled (via session options). "
<< "On certain hardware, large kernel convolutions may perform worse with NCHWc layout.";
}

int64_t group_count;
const auto* group_attr = graph_utils::GetNodeAttribute(node, "group");
Expand All @@ -345,7 +364,7 @@
}

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_output_channels = (output_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 367 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

bool do_reorder_input = true;
bool reorder_filter_OIHWBo = false;
Expand Down Expand Up @@ -378,7 +397,7 @@
if ((input_channels % channel_alignment) != 0) {
return;
}
filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 400 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size
}
}

Expand Down Expand Up @@ -880,7 +899,7 @@
bn_B.sub(bn_mean);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);

Check warning on line 902 in onnxruntime/core/optimizer/nchwc_transformer.cc

View workflow job for this annotation

GitHub Actions / build_x86_release

'~': zero extending 'size_t' to 'int64_t' of greater size

InlinedVector<float> padded_buffer(gsl::narrow<size_t>(nchwc_channels));

Expand Down Expand Up @@ -1227,7 +1246,7 @@
}

Status NchwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
NchwcTransformerImpl impl(graph);
NchwcTransformerImpl impl(graph, logger);
GraphViewer graph_viewer(graph);

for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
Expand Down
43 changes: 43 additions & 0 deletions onnxruntime/test/optimizer/nchwc_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/session/environment.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/framework/tensorprotoutils.h"
#include "test/compare_ortvalue.h"
#include "test/test_environment.h"
Expand Down Expand Up @@ -226,6 +227,48 @@ void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& bu

#ifndef DISABLE_CONTRIB_OPS

// Verifies that setting kOrtSessionOptionsDisableNchwcLayoutTransformation to "1"
// prevents the Level3 NchwcTransformer from rewriting Conv into the
// com.microsoft.nchwc.* contrib ops.
TEST(NchwcOptimizerTests, DisableNchwcLayoutTransformationSessionOption) {
  // Skip the test if NCHWc is not supported by the platform; the transformer
  // would not run regardless of the session option in that case.
  if (MlasNchwcGetBlockSize() <= 1) {
    return;
  }

  std::unordered_map<std::string, int> domain_to_version;
  domain_to_version[kOnnxDomain] = 13;
  Model model("nchwc_disable", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
              domain_to_version, {}, DefaultLoggingManager().DefaultLogger());

  // Build a single Conv node that would normally be converted to nchwc.Conv:
  // 32 input channels (a multiple of the NCHWc block size) and a 3x3 kernel.
  NchwcTestHelper helper(model.MainGraph());
  {
    auto* input_arg = helper.MakeInput<float>({1, 32, 112, 112});
    auto* output_arg = helper.MakeOutput();
    auto& conv_node = helper.AddConvNode(input_arg, output_arg, {64, 32, 3, 3});
    conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
    conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
  }

  ASSERT_STATUS_OK(model.MainGraph().Resolve());

  std::string model_data;
  model.ToProto().SerializeToString(&model_data);

  SessionOptions session_options;
  session_options.graph_optimization_level = TransformerLevel::Level3;
  session_options.session_logid = "NchwcOptimizerDisableTests";
  // Explicitly disable the NCHWc layout transformation. Without this entry the
  // Level3 optimizer would rewrite Conv into com.microsoft.nchwc.Conv and the
  // expectations below would fail.
  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(
      kOrtSessionOptionsDisableNchwcLayoutTransformation, "1"));

  InferenceSessionWrapper session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast<int>(model_data.size())));
  ASSERT_STATUS_OK(session.Initialize());

  // With the transformation disabled, no nchwc contrib ops should appear and
  // the original Conv node should remain untouched.
  auto op_to_count = CountOpsInGraph(session.GetGraph());
  EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0);
  EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0);
  EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0);
  EXPECT_EQ(op_to_count["Conv"], 1);
}

TEST(NchwcOptimizerTests, ConvNchw) {
auto test_case = [&](const std::string& activation_op_type) {
auto build_test_case = [&](NchwcTestHelper& helper) {
Expand Down
Loading