Skip to content

Commit c93e34d

Browse files
Revert "bound sympy accuracy (pytorch#150383)"
This reverts commit 1bc2b2b. Reverted pytorch#150383 on behalf of https://github.com/laithsakka due to big regression ([comment](pytorch#150383 (comment)))
1 parent f443035 commit c93e34d

File tree

2 files changed

+0
-39
lines changed

2 files changed

+0
-39
lines changed

Diff for: test/export/test_export.py

-22
Original file line numberDiff line numberDiff line change
@@ -3105,28 +3105,6 @@ def forward(self, x, y):
31053105
"dy - 6 = 6" not in exc.args[0]
31063106
) # don't suggest fix for non-root dim
31073107

3108-
@testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1)
3109-
@testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0)
3110-
def test_bound_sympy_accuracy(self):
3111-
class Foo(torch.nn.Module):
3112-
def forward(self, x):
3113-
expr = x.shape[0] - (x.shape[0] % 8)
3114-
return torch.empty(expr)
3115-
3116-
ep = export(
3117-
Foo(),
3118-
(torch.randn(13),),
3119-
dynamic_shapes={"x": (Dim("dim", min=2),)},
3120-
)
3121-
3122-
(output,) = ep.graph.output_node().args[0]
3123-
sym_node = output.meta["val"].shape[0].node
3124-
vr = torch.utils._sympy.value_ranges.bound_sympy(
3125-
sym_node.expr,
3126-
sym_node.shape_env.var_to_range,
3127-
)
3128-
self.assertEqual(vr.lower, 0)
3129-
31303108
@unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
31313109
def test_keep_composite_ops_invalid(self):
31323110
class Foo(torch.nn.Module):

Diff for: torch/utils/_sympy/value_ranges.py

-17
Original file line numberDiff line numberDiff line change
@@ -1004,22 +1004,6 @@ def trunc(x):
10041004
return ValueRanges.increasing_map(x, TruncToFloat)
10051005

10061006

1007-
def _rewrite_for_value_range_analysis(expr: sympy.Expr):
1008-
"""
1009-
Sometimes accuracy of value range analysis can be improved
1010-
with simple rewriting rules.
1011-
"""
1012-
1013-
# Rewrite X - X%Y to (X//Y) * Y.
1014-
x, y = sympy.Wild("x"), sympy.Wild("y")
1015-
expr = expr.replace(
1016-
x - torch.utils._sympy.functions.Mod(x, y),
1017-
torch.utils._sympy.functions.FloorDiv(x, y) * y,
1018-
)
1019-
1020-
return expr
1021-
1022-
10231007
def bound_sympy(
10241008
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
10251009
) -> ValueRanges:
@@ -1063,7 +1047,6 @@ def missing_handler(s):
10631047
vr = ValueRanges.unknown()
10641048
return vr
10651049

1066-
expr = _rewrite_for_value_range_analysis(expr)
10671050
return sympy_interp(
10681051
SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler
10691052
)

0 commit comments

Comments
 (0)