Skip to content

Commit 312b0f8

Browse files
authored
Merge pull request #2493 from jsiirola/expr-inplace
Resolve expression loops with inplace operators on Expression objects
2 parents e73b024 + e9d3c7d commit 312b0f8

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

pyomo/core/base/expression.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,21 @@ def expr(self, expr):
195195

196196
def set_value(self, expr):
197197
"""Set the expression on this expression."""
198-
self._expr = as_numeric(expr) if (expr is not None) else None
198+
if expr is None:
199+
self._expr = None
200+
return
201+
expr = as_numeric(expr)
202+
# In-place operators will leave self as an argument. We need to
203+
# replace that with the current expression in order to avoid
204+
# loops in the expression tree.
205+
if expr.is_expression_type():
206+
_args = expr.args
207+
if any(arg is self for arg in _args):
208+
new_args = _args.__class__(
209+
arg.expr if arg is self else arg for arg in _args
210+
)
211+
expr = expr.create_node_with_local_data(new_args)
212+
self._expr = expr
199213

200214
def is_constant(self):
201215
"""A boolean indicating whether this expression is constant."""

pyomo/core/tests/unit/test_expression.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
import pyomo.common.unittest as unittest
1919

20-
from pyomo.environ import ConcreteModel, AbstractModel, Expression, Var, Set, Param, Objective, value, sum_product
20+
from pyomo.environ import (
21+
ConcreteModel, AbstractModel, Expression, Var, Set, Param, Objective,
22+
value, sum_product,
23+
)
2124
from pyomo.core.base.expression import _GeneralExpressionData
25+
from pyomo.core.expr.compare import compare_expressions
2226
from pyomo.common.tee import capture_output
2327

2428
class TestExpressionData(unittest.TestCase):
@@ -902,6 +906,13 @@ def test_iadd(self):
902906
expr += v
903907
self.assertEqual(e.expr, 1)
904908
self.assertEqual(expr(), 2)
909+
# Make sure that using in-place operators on named expressions
910+
# do not create loops inthe expression tree (test #1890)
911+
m.x = Var()
912+
m.y = Var()
913+
m.e.expr = m.x
914+
m.e += m.y
915+
self.assertTrue(compare_expressions(m.e.expr, m.x + m.y))
905916

906917
def test_isub(self):
907918
# make sure simple for loops that look like they
@@ -919,6 +930,13 @@ def test_isub(self):
919930
expr -= v
920931
self.assertEqual(e.expr, 1)
921932
self.assertEqual(expr(), -2)
933+
# Make sure that using in-place operators on named expressions
934+
# do not create loops inthe expression tree (test #1890)
935+
m.x = Var()
936+
m.y = Var()
937+
m.e.expr = m.x
938+
m.e -= m.y
939+
self.assertTrue(compare_expressions(m.e.expr, m.x - m.y))
922940

923941
def test_imul(self):
924942
# make sure simple for loops that look like they
@@ -936,6 +954,13 @@ def test_imul(self):
936954
expr *= v
937955
self.assertEqual(e.expr, 3)
938956
self.assertEqual(expr(), 6)
957+
# Make sure that using in-place operators on named expressions
958+
# do not create loops inthe expression tree (test #1890)
959+
m.x = Var()
960+
m.y = Var()
961+
m.e.expr = m.x
962+
m.e *= m.y
963+
self.assertTrue(compare_expressions(m.e.expr, m.x * m.y))
939964

940965
def test_idiv(self):
941966
# make sure simple for loops that look like they
@@ -968,6 +993,13 @@ def test_idiv(self):
968993
expr /= v
969994
self.assertEqual(e.expr, 3)
970995
self.assertEqual(expr(), 1.5)
996+
# Make sure that using in-place operators on named expressions
997+
# do not create loops inthe expression tree (test #1890)
998+
m.x = Var()
999+
m.y = Var()
1000+
m.e.expr = m.x
1001+
m.e /= m.y
1002+
self.assertTrue(compare_expressions(m.e.expr, m.x / m.y))
9711003

9721004
def test_ipow(self):
9731005
# make sure simple for loops that look like they
@@ -985,6 +1017,13 @@ def test_ipow(self):
9851017
expr **= v
9861018
self.assertEqual(e.expr, 3)
9871019
self.assertEqual(expr(), 9)
1020+
# Make sure that using in-place operators on named expressions
1021+
# do not create loops inthe expression tree (test #1890)
1022+
m.x = Var()
1023+
m.y = Var()
1024+
m.e.expr = m.x
1025+
m.e **= m.y
1026+
self.assertTrue(compare_expressions(m.e.expr, m.x ** m.y))
9881027

9891028
if __name__ == "__main__":
9901029
unittest.main()

0 commit comments

Comments
 (0)