Commit 6f738d6

lucylq authored and pytorchmergebot committed
Remove early exit in constant_pad_nd for export (pytorch#132679)
Summary:
Remove the early exit for padding when padding = [0, 0, 0, 0]. This prevents export from specializing when all padding = 0, allowing export when all padding >= 0. Specialization will still happen for negative padding.

This change will be used to export image preprocessing for multimodal models, where images of dynamic shape are padded. As images are of dynamic shape, we can't know ahead of time whether padding will be required; padding is guaranteed to be non-negative. Preprocess code: pytorch/torchtune#1242

Note: the alternative is to wrap the padding in a custom op, which isn't ideal given that the custom op would contain the same implementation as constant_pad_nd.

Test Plan: ci

Differential Revision: D60687727

Pull Request resolved: pytorch#132679
Approved by: https://github.com/ezyang
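The condition change at the heart of this commit can be sketched in isolation. The helper names below (early_exit_old, early_exit_new) are hypothetical, written only to illustrate how the all-zero padding case moves from the early-exit path to the general path:

```python
import builtins

def early_exit_old(pad):
    # Old condition: exits early when no pad is positive,
    # which includes the all-zero case.
    return builtins.all(p <= 0 for p in pad)

def early_exit_new(pad):
    # New condition: exits early only when every pad is strictly negative.
    # The all-zero case now takes the general path, so export does not
    # specialize on padding == 0.
    return builtins.all(p < 0 for p in pad)

# All-zero padding: old code exits early, new code falls through.
print(early_exit_old([0, 0, 0, 0]), early_exit_new([0, 0, 0, 0]))  # True False
# Strictly negative padding still exits early under both conditions.
print(early_exit_old([-1, -2]), early_exit_new([-1, -2]))  # True True
```

Because `torch.export` traces one concrete path through raw Python `if` statements, keeping all-zero padding on the same path as positive padding is what lets a single exported graph cover every `pad >= 0` input.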
1 parent 9a998d9 commit 6f738d6

File tree

1 file changed (+12, -4 lines)

torch/_refs/__init__.py

+12 -4
@@ -2882,8 +2882,16 @@ def constant_pad_nd(
         if pad[pad_idx + 1] < 0:
             c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
 
-    # if none of the pads are positive we can just return the result
-    if builtins.all(p <= 0 for p in pad):
+    # If all the pads are negative we can return the result.
+    # Avoid early exiting if all pads = 0 to prevent specialization on export.
+    # During export, raw if statements are specialized on the input, meaning
+    # that we lose a branch depending on the example input used to export.
+    # Here, this is either the case where all pads = 0, or the case where at
+    # least one pad > 0 and the rest are >= 0.
+    # Avoiding the early exit when all pads = 0 ensures we can export
+    # constant_pad_nd for cases when all pads >= 0.
+    # Note: if any pads are negative, this code specializes due to the if statements above.
+    if builtins.all(p < 0 for p in pad):
         return c_input.clone()
 
     new_shape = list(input_sizes[:l_diff])
@@ -2916,11 +2924,11 @@ def constant_pad_nd(
     c_output = output
     for i in range(l_diff, l_inp):
         pad_idx = 2 * (l_inp - i - 1)
-        if pad[pad_idx] > 0:
+        if pad[pad_idx] >= 0:
             c_output = c_output.narrow(
                 i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
             )
-        if pad[pad_idx + 1] > 0:
+        if pad[pad_idx + 1] >= 0:
             c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])
 
     prims.copy_to(c_output, c_input)
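Why the `>` to `>=` change is safe: narrowing a dimension by a pad of 0 is a no-op, so routing the all-zero case through the general allocate-narrow-copy path produces the same result as the removed early exit. A minimal 1-D pure-Python analogue of that general path (the function name constant_pad_1d is invented for this sketch and does not exist in PyTorch):

```python
def constant_pad_1d(xs, pad_left, pad_right, value=0):
    # Simplified 1-D analogue of the general path in constant_pad_nd:
    # allocate the padded output filled with `value`, then "narrow" to the
    # interior region and copy the input into it. Works uniformly for any
    # pad >= 0, including zero, so no early exit on all-zero padding is needed.
    out = [value] * (len(xs) + pad_left + pad_right)
    # Analogue of out.narrow(0, pad_left, len(out) - pad_left - pad_right)
    # followed by prims.copy_to: a pad of 0 narrows away nothing.
    out[pad_left:pad_left + len(xs)] = xs
    return out

print(constant_pad_1d([1, 2, 3], 2, 1))  # [0, 0, 1, 2, 3, 0]
print(constant_pad_1d([1, 2, 3], 0, 0))  # [1, 2, 3]
```

The zero-pad call returns an unchanged (fresh) list through the exact same code path as the positive-pad call, which mirrors why export no longer needs to specialize on whether padding is present.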
