
Commit 5ead965

pianpwk authored and pytorchmergebot committed
[export] don't duck size for DIM.AUTO (pytorch#134486)
Summary: Apparently DIM.AUTO leads to duck sizing; I didn't catch this. This change does the least intrusive fix possible by using `torch._dynamo.maybe_mark_dynamic()` under the hood.

Test Plan: added test

Differential Revision: D61809344

Pull Request resolved: pytorch#134486
Approved by: https://github.com/avikchaudhuri
1 parent 30094be commit 5ead965
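
For context, a minimal sketch of the behavior this fix targets; the module name, shapes, and inputs below are illustrative, not taken from the PR:

import torch
from torch.export import export
from torch.export.dynamic_shapes import DIM


class Mod(torch.nn.Module):
    def forward(self, x, y):
        return x + 1, y + 1


# Both sample inputs happen to share size 8 on dim 0. Duck sizing would fold
# those two dims into a single symbol; with DIM.AUTO now calling
# torch._dynamo.maybe_mark_dynamic() under the hood, each dim keeps its own symbol.
inputs = (torch.randn(8, 4), torch.randn(8, 6))
shapes = {"x": (DIM.AUTO, DIM.STATIC), "y": (DIM.AUTO, DIM.STATIC)}
ep = export(Mod(), inputs, dynamic_shapes=shapes)

# Independent symbols allow differing dim-0 sizes at run time.
ep.module()(torch.randn(3, 4), torch.randn(5, 6))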

File tree

2 files changed: +27 -0 lines changed

test/export/test_export.py (+23)
@@ -2492,6 +2492,29 @@ def forward(self, x, y):
         gm(torch.randn(33, 4), torch.randn(32, 4))
         gm(torch.randn(128, 4), torch.randn(128, 4))
 
+    def test_dont_duck_size_for_auto_dynamic(self):
+        # for this use case, mark_dynamic() and AUTO should have same effect.
+        # check that same symbol gets allocated to both dims without raising constraint violation.
+        from torch.export.dynamic_shapes import DIM
+
+        AUTO, STATIC = DIM.AUTO, DIM.STATIC
+
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                # x: [s0, s1], y: [s0 + 1, 4]
+                assert y.shape[1] == 4
+                assert x.shape[0] == y.shape[0] - 1
+                return x * 2, y * 2
+
+        # duck sizing would make all static based on these sample inputs
+        inputs = (torch.randn(4, 4), torch.randn(5, 4))
+        shapes = {
+            "x": (AUTO, AUTO),
+            "y": (AUTO, AUTO),
+        }
+        ep = export(Foo(), inputs, dynamic_shapes=shapes)
+        ep.module()(torch.randn(6, 3), torch.randn(7, 4))
+
     @testing.expectedFailureRetraceability  # T183144629
     def test_map(self):
         class Module(torch.nn.Module):

torch/export/dynamic_shapes.py (+4)
@@ -835,6 +835,8 @@ def _marked_dynamic(tensor, i):
         if _marked_dynamic(tensor, i) or dim == DIM.AUTO:
             # don't have to specify anything if dynamic
             # None also works, since assume_static_by_default=False
+            if dim == DIM.AUTO:
+                torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
             continue
         elif isinstance(dim, _Dim):
             out[i] = dim
@@ -851,6 +853,8 @@ def _marked_dynamic(tensor, i):
         for i, val in enumerate(tensor.shape):
             dim = shape[i]
             if _marked_dynamic(tensor, i) or dim == DIM.AUTO:
+                if dim == DIM.AUTO:
+                    torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
                 out.append(None)
             elif isinstance(dim, _Dim):
                 out.append(dim)

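For reference, `torch._dynamo.maybe_mark_dynamic` can also be applied directly to an input tensor; a minimal sketch (the tensor below is illustrative):

import torch

t = torch.randn(8, 4)
# Soft hint that dim 0 should be treated as dynamic and get its own symbol
# rather than being duck-sized with other dims of the same size; unlike
# torch._dynamo.mark_dynamic, it does not error if the dim later specializes.
torch._dynamo.maybe_mark_dynamic(t, 0)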