Skip to content

Commit a426837

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Don't set replacement if lhs is in the free symbols of the rhs (pytorch#139250)
Fixes python test/dynamo/test_functions.py FunctionTests.test_is_integer when we turn off specialize float on eager: pytorch#138915 Pull Request resolved: pytorch#139250 Approved by: https://github.com/ezyang
1 parent 754b262 commit a426837

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/dynamo/test_unspec.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,19 @@ def fn(x):
601601
compl_fn = torch.compile(fn, dynamic=True, backend="eager")
602602
self.assertEqual(compl_fn(inputs), fn(inputs))
603603

604+
@torch._dynamo.config.patch(specialize_float=False)
605+
def test_symfloat_no_replacement(self):
606+
# See https://github.com/pytorch/pytorch/pull/139250 for more context
607+
# The high level idea is if we don't want to set a replacement where a
608+
# symbol is on both the right and left side, otherwise we'll end up
609+
# in an infinite self._find recursion.
610+
def fn(t, m):
611+
return 2 * t if m.is_integer() else t
612+
613+
t = torch.tensor([1])
614+
compl_fn = torch.compile(fn, dynamic=True, backend="eager")
615+
self.assertEqual(fn(t, 1.0), compl_fn(t, 1.0))
616+
604617
@torch._dynamo.config.patch(specialize_float=False)
605618
def test_unspec_roundtrip_float_input(self):
606619
def f(x, y):

torch/fx/experimental/symbolic_shapes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5591,10 +5591,12 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
55915591
Adds or updates a replacement for a symbol.
55925592
Use this instead of `self.replacements[a] = tgt`.
55935593
"""
5594-
55955594
if tgt == self.replacements.get(a, None):
55965595
return
55975596

5597+
if a in tgt.free_symbols:
5598+
return
5599+
55985600
# Precondition: a == tgt
55995601
assert isinstance(a, sympy.Symbol)
56005602

0 commit comments

Comments
 (0)