Skip to content

Commit 31cb3f5

Browse files
committed
[ET-VK] Modifying should_squeeze function in SqueezeUnsqueezeInputs to not squeeze if significant axis are all 1 and trailing axis are all > 1.
This diff modifies the `should_squeeze` function in `SqueezeUnsqueezeInputs` to not squeeze (return False) if significant axes are all 1 and trailing axes are all > 1. Differential Revision: [D75483587](https://our.internmc.facebook.com/intern/diff/D75483587/) ghstack-source-id: 286577899 Pull Request resolved: #11177
1 parent a8df947 commit 31cb3f5

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore
3232
return shape[1] == 1 and shape[0] > 1
3333
if len(shape) == 4:
3434
# No need to squeeze if all dims are 1 except the width dim
35-
if all(dim == 1 for dim in shape[:-1]):
35+
if shape[0] == shape[1] == shape[2] == 1:
36+
return False
37+
# No need to squeeze if batch and channel dims are 1 and height and width are > 1
38+
if shape[0] == shape[1] == 1 and shape[2] > 1 and shape[3] > 1:
39+
return False
40+
# No need to squeeze if batch dim is 1 and channel, height and width are > 1
41+
if shape[0] == 1 and shape[1] > 1 and shape[2] > 1 and shape[3] > 1:
3642
return False
3743
# Otherwise, check for squeezable dim
3844
return 1 in shape[:-1]

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@ def preprocess( # noqa: C901
150150
program = apply_passes(
151151
program,
152152
[
153-
RemoveRedundantOpsTransform(),
154153
AddmmToLinearTransform(),
155154
FuseQuantizedOpsTransform(program),
156155
SqueezeUnsqueezeInputs(),
157156
FuseViewCopyTransform(),
158157
ViewCopyToSqueezeUnsqueezePass(),
159158
FuseBatchNormWithConvPass(program),
160159
FuseClampPass(),
160+
RemoveRedundantOpsTransform(),
161161
],
162162
)
163163

0 commit comments

Comments
 (0)