Skip to content

Commit c696430

Browse files
authored
Merge pull request #3056 from jsiirola/fix-presolve-named-expr
Remove presolve-eliminated variables from named expressions
2 parents ee95302 + bb5cee6 commit c696430

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

pyomo/repn/plugins/nl_writer.py

+9
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,15 @@ def _linear_presolve(self, comp_by_linear_var, lcon_by_linear_nnz, var_bounds):
16911691
if not self.config.linear_presolve:
16921692
return eliminated_cons, eliminated_vars
16931693

1694+
# We need to record all named expressions with linear components
1695+
# so that any eliminated variables are removed from them.
1696+
for expr, info, _ in self.subexpression_cache.values():
1697+
if not info.linear:
1698+
continue
1699+
expr_id = id(expr)
1700+
for _id in info.linear:
1701+
comp_by_linear_var[_id].append((expr_id, info))
1702+
16941703
fixed_vars = [
16951704
_id for _id, (lb, ub) in var_bounds.items() if lb == ub and lb is not None
16961705
]

pyomo/repn/tests/ampl/test_nlv2.py

+79-3
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,6 @@ def test_presolve_almost_lower_triangular_nonlinear(self):
14811481
# Note: bounds on x[1] are:
14821482
# min(22/3, 82/17, 23/4, -39/-6) == 4.823529411764706
14831483
# max(2/3, 62/17, 3/4, -19/-6) == 3.6470588235294117
1484-
print(OUT.getvalue())
14851484
self.assertEqual(
14861485
*nl_diff(
14871486
"""g3 1 1 0 # problem unknown
@@ -1558,6 +1557,83 @@ def test_presolve_lower_triangular_out_of_bounds(self):
15581557
nlinfo = nl_writer.NLWriter().write(m, OUT, linear_presolve=True)
15591558
self.assertEqual(LOG.getvalue(), "")
15601559

1560+
def test_presolve_named_expressions(self):
1561+
# Test from #3055
1562+
m = pyo.ConcreteModel()
1563+
m.x = pyo.Var([1, 2, 3], initialize=1, bounds=(0, 10))
1564+
m.subexpr = pyo.Expression(pyo.Integers)
1565+
m.subexpr[1] = m.x[1] + m.x[2]
1566+
m.eq = pyo.Constraint(pyo.Integers)
1567+
m.eq[1] = m.x[1] == 7
1568+
m.eq[2] = m.x[3] == 0.1 * m.subexpr[1] * m.x[2]
1569+
m.obj = pyo.Objective(expr=m.x[1] ** 2 + m.x[2] ** 2 + m.x[3] ** 3)
1570+
1571+
OUT = io.StringIO()
1572+
with LoggingIntercept() as LOG:
1573+
nlinfo = nl_writer.NLWriter().write(
1574+
m, OUT, symbolic_solver_labels=True, linear_presolve=True
1575+
)
1576+
self.assertEqual(LOG.getvalue(), "")
1577+
1578+
self.assertEqual(
1579+
nlinfo.eliminated_vars, [(m.x[1], nl_writer.AMPLRepn(7, {}, None))]
1580+
)
1581+
1582+
self.assertEqual(
1583+
*nl_diff(
1584+
"""g3 1 1 0 # problem unknown
1585+
2 1 1 0 1 # vars, constraints, objectives, ranges, eqns
1586+
1 1 0 0 0 0 # nonlinear constrs, objs; ccons: lin, nonlin, nd, nzlb
1587+
0 0 # network constraints: nonlinear, linear
1588+
1 2 1 # nonlinear vars in constraints, objectives, both
1589+
0 0 0 1 # linear network variables; functions; arith, flags
1590+
0 0 0 0 0 # discrete variables: binary, integer, nonlinear (b,c,o)
1591+
2 2 # nonzeros in Jacobian, obj. gradient
1592+
5 4 # max name lengths: constraints, variables
1593+
0 0 0 1 0 # common exprs: b,c,o,c1,o1
1594+
V2 1 1 #subexpr[1]
1595+
0 1
1596+
n7.0
1597+
C0 #eq[2]
1598+
o16 #-
1599+
o2 #*
1600+
o2 #*
1601+
n0.1
1602+
v2 #subexpr[1]
1603+
v0 #x[2]
1604+
O0 0 #obj
1605+
o54 # sumlist
1606+
3 # (n)
1607+
o5 #^
1608+
n7.0
1609+
n2
1610+
o5 #^
1611+
v0 #x[2]
1612+
n2
1613+
o5 #^
1614+
v1 #x[3]
1615+
n3
1616+
x2 # initial guess
1617+
0 1 #x[2]
1618+
1 1 #x[3]
1619+
r #1 ranges (rhs's)
1620+
4 0 #eq[2]
1621+
b #2 bounds (on variables)
1622+
0 0 10 #x[2]
1623+
0 0 10 #x[3]
1624+
k1 #intermediate Jacobian column lengths
1625+
1
1626+
J0 2 #eq[2]
1627+
0 0
1628+
1 1
1629+
G0 2 #obj
1630+
0 0
1631+
1 0
1632+
""",
1633+
OUT.getvalue(),
1634+
)
1635+
)
1636+
15611637
def test_scaling(self):
15621638
m = pyo.ConcreteModel()
15631639
m.x = pyo.Var(initialize=0)
@@ -1665,7 +1741,7 @@ def test_scaling(self):
16651741
self.assertEqual(LOG.getvalue(), "")
16661742

16671743
nl2 = OUT.getvalue()
1668-
print(nl2)
1744+
16691745
self.assertEqual(
16701746
*nl_diff(
16711747
"""g3 1 1 0 # problem unknown
@@ -1759,7 +1835,7 @@ def test_named_expressions(self):
17591835

17601836
OUT = io.StringIO()
17611837
nl_writer.NLWriter().write(m, OUT, symbolic_solver_labels=True)
1762-
print(OUT.getvalue())
1838+
17631839
self.assertEqual(
17641840
*nl_diff(
17651841
"""g3 1 1 0 # problem unknown

0 commit comments

Comments
 (0)