[MLAS] Add an NHWC implementation of convolution to avoid transposes #26834
```diff
@@ -26,5 +26,13 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
         .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     FusedConvFloat);

+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    NhwcFusedConv,
+    1,
+    float,
+    KernelDefBuilder()
```
Suggested change:

```diff
-    KernelDefBuilder()
+    KernelDefBuilder()
+        // Allow the optional "sum" input (index 3) to be reused as the output buffer (index 0),
+        // consistent with the FusedConv kernel registration.
+        .MayInplace(3, 0)
```
Copilot AI · Jan 26, 2026
The NhwcFusedConv kernel registration (lines 29-35) is not conditionally compiled with USE_KLEIDIAI guards, but the PR description states this is a "KleidiAi specific implementation" that is "only used with KleidiAi (for now)". This is inconsistent with the conditional registration approach used in cpu_contrib_kernels.cc where the declaration is guarded by #ifdef USE_KLEIDIAI. Either this registration should also be conditionally compiled with USE_KLEIDIAI, or if the kernel is meant to work without KleidiAI (using the fallback path), the guards in cpu_contrib_kernels.cc should be removed.
```diff
@@ -211,7 +211,15 @@ class FuseConvAddActivationAction : public ReplaceWithNew {

  private:
   std::string OpType(const RuntimeState& runtimeState) const override {
-    return (runtimeState.selected_nodes.Target().OpType() == "Conv") ? "FusedConv" : "NhwcFusedConv";
+    const auto& target = runtimeState.selected_nodes.Target();
+    const auto* channels_last_attr = graph_utils::GetNodeAttribute(target, "channels_last");
+    const bool channels_last = channels_last_attr != nullptr && channels_last_attr->i() != 0;
+
+    if (target.OpType() == "Conv") {
+      return channels_last ? "NhwcFusedConv" : "FusedConv";
+    }
+
+    return "NhwcFusedConv";
```
Comment on lines +217 to +222

Suggested change:

```diff
-    if (target.OpType() == "Conv") {
-      return channels_last ? "NhwcFusedConv" : "FusedConv";
-    }
-    return "NhwcFusedConv";
+    const std::string& op_type = target.OpType();
+    // If channels_last is set, use NHWC fused convolution regardless of original op type.
+    if (channels_last) {
+      return "NhwcFusedConv";
+    }
+    // Without channels_last, convert Conv to FusedConv, and leave other op types unchanged.
+    if (op_type == "Conv") {
+      return "FusedConv";
+    }
+    return op_type;
```
The registration of NhwcFusedConv kernel is unconditional in cpu_contrib_kernels.cc (line 308), but the kernel declaration is conditionally compiled with USE_KLEIDIAI guards in the same file (lines 21-23). This creates an inconsistency: when USE_KLEIDIAI is not defined, the declaration is absent but the registration still attempts to register the kernel, which will likely cause a compilation error. The registration on line 308 should also be wrapped with #ifdef USE_KLEIDIAI guards to match the conditional declaration.