-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[MLAS] Add an implementation an NHWC implementation of convolution to avoid transposes #26834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d4943e4
1606a1c
f80cc39
6045333
2dd199e
eb026d1
4df9cea
b133782
0c2d1cd
25c0be7
bee0892
a64af7c
bc1ada6
0482150
f9606cd
63d9c55
457513b
b836bd3
d305b8f
891dad5
7acbfcf
878dff6
0a04afc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,5 +26,13 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( | |
| .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), | ||
| FusedConvFloat); | ||
|
|
||
| ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( | ||
| NhwcFusedConv, | ||
| 1, | ||
| float, | ||
| KernelDefBuilder() | ||
| .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), | ||
| FusedConvFloat); | ||
|
Comment on lines
+29
to
+35
|
||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -211,7 +211,15 @@ class FuseConvAddActivationAction : public ReplaceWithNew { | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||
| std::string OpType(const RuntimeState& runtimeState) const override { | ||||||||||||||||||||||||||||||||||||||||
| return (runtimeState.selected_nodes.Target().OpType() == "Conv") ? "FusedConv" : "NhwcFusedConv"; | ||||||||||||||||||||||||||||||||||||||||
| const auto& target = runtimeState.selected_nodes.Target(); | ||||||||||||||||||||||||||||||||||||||||
| const auto* channels_last_attr = graph_utils::GetNodeAttribute(target, "channels_last"); | ||||||||||||||||||||||||||||||||||||||||
| const bool channels_last = channels_last_attr != nullptr && channels_last_attr->i() != 0; | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if (target.OpType() == "Conv") { | ||||||||||||||||||||||||||||||||||||||||
| return channels_last ? "NhwcFusedConv" : "FusedConv"; | ||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return "NhwcFusedConv"; | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+217
to
+222
|
||||||||||||||||||||||||||||||||||||||||
| if (target.OpType() == "Conv") { | |
| return channels_last ? "NhwcFusedConv" : "FusedConv"; | |
| } | |
| return "NhwcFusedConv"; | |
| const std::string& op_type = target.OpType(); | |
| // If channels_last is set, use NHWC fused convolution regardless of original op type. | |
| if (channels_last) { | |
| return "NhwcFusedConv"; | |
| } | |
| // Without channels_last, convert Conv to FusedConv, and leave other op types unchanged. | |
| if (op_type == "Conv") { | |
| return "FusedConv"; | |
| } | |
| return op_type; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NhwcFusedConv kernel registration is missing the MayInplace hint that is present in the FusedConv registration. The FusedConv kernel uses .MayInplace(3, 0) to allow the optional "sum" input (index 3) to be reused as the output buffer (index 0) for efficiency. However, NhwcFusedConv does not include this hint. This means that even though the code in conv.cc handles the Sum input for channels_last mode, the allocation planner cannot optimize memory usage by reusing the Sum buffer for the output when using NhwcFusedConv. Consider adding .MayInplace(3, 0) to the NhwcFusedConv kernel builder to maintain consistency and enable the same memory optimization.