Skip to content

Commit 1b5aa97

Browse files
authored
Merge pull request #3186 from Robbybp/scc-performance-2
Improve performance of `solve_strongly_connected_components` for models with named expressions
2 parents b06ddea + 4f1e20c commit 1b5aa97

File tree

3 files changed

+142
-29
lines changed

3 files changed

+142
-29
lines changed

Diff for: pyomo/contrib/incidence_analysis/scc_solver.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
IncidenceGraphInterface,
1919
_generate_variables_in_constraints,
2020
)
21+
from pyomo.contrib.incidence_analysis.config import IncidenceMethod
2122

2223

2324
_log = logging.getLogger(__name__)
2425

2526

2627
def generate_strongly_connected_components(
27-
constraints, variables=None, include_fixed=False
28+
constraints, variables=None, include_fixed=False, igraph=None
2829
):
2930
"""Yield in order ``_BlockData`` that each contain the variables and
3031
constraints of a single diagonal block in a block lower triangularization
@@ -41,9 +42,12 @@ def generate_strongly_connected_components(
4142
variables: List of Pyomo variable data objects
4243
Variables that may participate in strongly connected components.
4344
If not provided, all variables in the constraints will be used.
44-
include_fixed: Bool
45+
include_fixed: Bool, optional
4546
Indicates whether fixed variables will be included when
4647
identifying variables in constraints.
48+
igraph: IncidenceGraphInterface, optional
49+
Incidence graph containing (at least) the provided constraints
50+
and variables.
4751
4852
Yields
4953
------
@@ -55,11 +59,17 @@ def generate_strongly_connected_components(
5559
"""
5660
if variables is None:
5761
variables = list(
58-
_generate_variables_in_constraints(constraints, include_fixed=include_fixed)
62+
_generate_variables_in_constraints(
63+
constraints,
64+
include_fixed=include_fixed,
65+
method=IncidenceMethod.ampl_repn,
66+
)
5967
)
6068

6169
assert len(variables) == len(constraints)
62-
igraph = IncidenceGraphInterface()
70+
if igraph is None:
71+
igraph = IncidenceGraphInterface()
72+
6373
var_blocks, con_blocks = igraph.block_triangularize(
6474
variables=variables, constraints=constraints
6575
)
@@ -73,7 +83,7 @@ def generate_strongly_connected_components(
7383

7484

7585
def solve_strongly_connected_components(
76-
block, solver=None, solve_kwds=None, calc_var_kwds=None
86+
block, *, solver=None, solve_kwds=None, use_calc_var=True, calc_var_kwds=None
7787
):
7888
"""Solve a square system of variables and equality constraints by
7989
solving strongly connected components individually.
@@ -98,6 +108,9 @@ def solve_strongly_connected_components(
98108
a solve method.
99109
solve_kwds: Dictionary
100110
Keyword arguments for the solver's solve method
111+
use_calc_var: Bool
112+
Whether to use ``calculate_variable_from_constraint`` for one-by-one
113+
square system solves
101114
calc_var_kwds: Dictionary
102115
Keyword arguments for calculate_variable_from_constraint
103116
@@ -112,23 +125,28 @@ def solve_strongly_connected_components(
112125
calc_var_kwds = {}
113126

114127
igraph = IncidenceGraphInterface(
115-
block, active=True, include_fixed=False, include_inequality=False
128+
block,
129+
active=True,
130+
include_fixed=False,
131+
include_inequality=False,
132+
method=IncidenceMethod.ampl_repn,
116133
)
117134
constraints = igraph.constraints
118135
variables = igraph.variables
119136

120137
res_list = []
121138
log_blocks = _log.isEnabledFor(logging.DEBUG)
122-
for scc, inputs in generate_strongly_connected_components(constraints, variables):
123-
with TemporarySubsystemManager(to_fix=inputs):
139+
for scc, inputs in generate_strongly_connected_components(
140+
constraints, variables, igraph=igraph
141+
):
142+
with TemporarySubsystemManager(to_fix=inputs, remove_bounds_on_fix=True):
124143
N = len(scc.vars)
125-
if N == 1:
144+
if N == 1 and use_calc_var:
126145
if log_blocks:
127146
_log.debug(f"Solving 1x1 block: {scc.cons[0].name}.")
128147
results = calculate_variable_from_constraint(
129148
scc.vars[0], scc.cons[0], **calc_var_kwds
130149
)
131-
res_list.append(results)
132150
else:
133151
if solver is None:
134152
var_names = [var.name for var in scc.vars.values()][:10]
@@ -142,5 +160,5 @@ def solve_strongly_connected_components(
142160
if log_blocks:
143161
_log.debug(f"Solving {N}x{N} block.")
144162
results = solver.solve(scc, **solve_kwds)
145-
res_list.append(results)
163+
res_list.append(results)
146164
return res_list

Diff for: pyomo/util/subsystems.py

+62-15
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,35 @@
1717

1818
from pyomo.core.base.constraint import Constraint
1919
from pyomo.core.base.expression import Expression
20+
from pyomo.core.base.objective import Objective
2021
from pyomo.core.base.external import ExternalFunction
2122
from pyomo.core.expr.visitor import StreamBasedExpressionVisitor
2223
from pyomo.core.expr.numeric_expr import ExternalFunctionExpression
23-
from pyomo.core.expr.numvalue import native_types
24+
from pyomo.core.expr.numvalue import native_types, NumericValue
2425

2526

2627
class _ExternalFunctionVisitor(StreamBasedExpressionVisitor):
28+
def __init__(self, descend_into_named_expressions=True):
29+
super().__init__()
30+
self._descend_into_named_expressions = descend_into_named_expressions
31+
self.named_expressions = []
32+
2733
def initializeWalker(self, expr):
2834
self._functions = []
2935
self._seen = set()
3036
return True, None
3137

38+
def beforeChild(self, parent, child, index):
39+
if child.__class__ in native_types:
40+
return False, None
41+
elif (
42+
not self._descend_into_named_expressions
43+
and child.is_named_expression_type()
44+
):
45+
self.named_expressions.append(child)
46+
return False, None
47+
return True, None
48+
3249
def exitNode(self, node, data):
3350
if type(node) is ExternalFunctionExpression:
3451
if id(node) not in self._seen:
@@ -38,26 +55,35 @@ def exitNode(self, node, data):
3855
def finalizeResult(self, result):
3956
return self._functions
4057

41-
def enterNode(self, node):
42-
pass
43-
44-
def acceptChildResult(self, node, data, child_result, child_idx):
45-
pass
46-
47-
def acceptChildResult(self, node, data, child_result, child_idx):
48-
if child_result.__class__ in native_types:
49-
return False, None
50-
return child_result.is_expression_type(), None
51-
5258

5359
def identify_external_functions(expr):
5460
yield from _ExternalFunctionVisitor().walk_expression(expr)
5561

5662

5763
def add_local_external_functions(block):
5864
ef_exprs = []
59-
for comp in block.component_data_objects((Constraint, Expression), active=True):
60-
ef_exprs.extend(identify_external_functions(comp.expr))
65+
named_expressions = []
66+
visitor = _ExternalFunctionVisitor(descend_into_named_expressions=False)
67+
for comp in block.component_data_objects(
68+
(Constraint, Expression, Objective), active=True
69+
):
70+
ef_exprs.extend(visitor.walk_expression(comp.expr))
71+
named_expr_set = ComponentSet(visitor.named_expressions)
72+
# List of unique named expressions
73+
named_expressions = list(named_expr_set)
74+
while named_expressions:
75+
expr = named_expressions.pop()
76+
# Clear named expression cache so we don't re-check named expressions
77+
# we've seen before.
78+
visitor.named_expressions.clear()
79+
ef_exprs.extend(visitor.walk_expression(expr))
80+
# Only add to the stack named expressions that we have
81+
# not encountered yet.
82+
for local_expr in visitor.named_expressions:
83+
if local_expr not in named_expr_set:
84+
named_expressions.append(local_expr)
85+
named_expr_set.add(local_expr)
86+
6187
unique_functions = []
6288
fcn_set = set()
6389
for expr in ef_exprs:
@@ -148,7 +174,14 @@ class TemporarySubsystemManager(object):
148174
149175
"""
150176

151-
def __init__(self, to_fix=None, to_deactivate=None, to_reset=None, to_unfix=None):
177+
def __init__(
178+
self,
179+
to_fix=None,
180+
to_deactivate=None,
181+
to_reset=None,
182+
to_unfix=None,
183+
remove_bounds_on_fix=False,
184+
):
152185
"""
153186
Arguments
154187
---------
@@ -168,6 +201,8 @@ def __init__(self, to_fix=None, to_deactivate=None, to_reset=None, to_unfix=None
168201
List of var data objects to be temporarily unfixed. These are
169202
restored to their original status on exit from this object's
170203
context manager.
204+
remove_bounds_on_fix: Bool
205+
Whether bounds should be removed temporarily for fixed variables
171206
172207
"""
173208
if to_fix is None:
@@ -194,6 +229,8 @@ def __init__(self, to_fix=None, to_deactivate=None, to_reset=None, to_unfix=None
194229
self._con_was_active = None
195230
self._comp_original_value = None
196231
self._var_was_unfixed = None
232+
self._remove_bounds_on_fix = remove_bounds_on_fix
233+
self._fixed_var_bounds = None
197234

198235
def __enter__(self):
199236
to_fix = self._vars_to_fix
@@ -203,8 +240,13 @@ def __enter__(self):
203240
self._var_was_fixed = [(var, var.fixed) for var in to_fix + to_unfix]
204241
self._con_was_active = [(con, con.active) for con in to_deactivate]
205242
self._comp_original_value = [(comp, comp.value) for comp in to_set]
243+
self._fixed_var_bounds = [(var.lb, var.ub) for var in to_fix]
206244

207245
for var in self._vars_to_fix:
246+
if self._remove_bounds_on_fix:
247+
# TODO: Potentially override var.domain as well?
248+
var.setlb(None)
249+
var.setub(None)
208250
var.fix()
209251

210252
for con in self._cons_to_deactivate:
@@ -223,6 +265,11 @@ def __exit__(self, ex_type, ex_val, ex_bt):
223265
var.fix()
224266
else:
225267
var.unfix()
268+
if self._remove_bounds_on_fix:
269+
for var, (lb, ub) in zip(self._vars_to_fix, self._fixed_var_bounds):
270+
var.setlb(lb)
271+
var.setub(ub)
272+
226273
for con, was_active in self._con_was_active:
227274
if was_active:
228275
con.activate()

Diff for: pyomo/util/tests/test_subsystems.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -292,17 +292,29 @@ def test_generate_dont_fix_inputs_with_fixed_var(self):
292292
self.assertFalse(m.v3.fixed)
293293
self.assertTrue(m.v4.fixed)
294294

295-
def _make_model_with_external_functions(self):
295+
def _make_model_with_external_functions(self, named_expressions=False):
296296
m = pyo.ConcreteModel()
297297
gsl = find_GSL()
298298
m.bessel = pyo.ExternalFunction(library=gsl, function="gsl_sf_bessel_J0")
299299
m.fermi = pyo.ExternalFunction(library=gsl, function="gsl_sf_fermi_dirac_m1")
300300
m.v1 = pyo.Var(initialize=1.0)
301301
m.v2 = pyo.Var(initialize=2.0)
302302
m.v3 = pyo.Var(initialize=3.0)
303+
if named_expressions:
304+
m.subexpr = pyo.Expression(pyo.PositiveIntegers)
305+
m.subexpr[1] = 2 * m.fermi(m.v1)
306+
m.subexpr[2] = m.bessel(m.v1) - m.bessel(m.v2)
307+
m.subexpr[3] = m.subexpr[2] + m.v3**2
308+
subexpr1 = m.subexpr[1]
309+
subexpr2 = m.subexpr[2]
310+
subexpr3 = m.subexpr[3]
311+
else:
312+
subexpr1 = 2 * m.fermi(m.v1)
313+
subexpr2 = m.bessel(m.v1) - m.bessel(m.v2)
314+
subexpr3 = subexpr2 + m.v3**2
303315
m.con1 = pyo.Constraint(expr=m.v1 == 0.5)
304-
m.con2 = pyo.Constraint(expr=2 * m.fermi(m.v1) + m.v2**2 - m.v3 == 1.0)
305-
m.con3 = pyo.Constraint(expr=m.bessel(m.v1) - m.bessel(m.v2) + m.v3**2 == 2.0)
316+
m.con2 = pyo.Constraint(expr=subexpr1 + m.v2**2 - m.v3 == 1.0)
317+
m.con3 = pyo.Constraint(expr=subexpr3 == 2.0)
306318
return m
307319

308320
@unittest.skipUnless(find_GSL(), "Could not find the AMPL GSL library")
@@ -329,6 +341,15 @@ def test_identify_external_functions(self):
329341
pred_fcn_data = {(gsl, "gsl_sf_bessel_J0"), (gsl, "gsl_sf_fermi_dirac_m1")}
330342
self.assertEqual(fcn_data, pred_fcn_data)
331343

344+
@unittest.skipUnless(find_GSL(), "Could not find the AMPL GSL library")
345+
def test_local_external_functions_with_named_expressions(self):
346+
m = self._make_model_with_external_functions(named_expressions=True)
347+
variables = list(m.component_data_objects(pyo.Var))
348+
constraints = list(m.component_data_objects(pyo.Constraint, active=True))
349+
b = create_subsystem_block(constraints, variables)
350+
self.assertTrue(isinstance(b._gsl_sf_bessel_J0, pyo.ExternalFunction))
351+
self.assertTrue(isinstance(b._gsl_sf_fermi_dirac_m1, pyo.ExternalFunction))
352+
332353
def _solve_ef_model_with_ipopt(self):
333354
m = self._make_model_with_external_functions()
334355
ipopt = pyo.SolverFactory("ipopt")
@@ -362,6 +383,33 @@ def test_with_external_function(self):
362383
self.assertAlmostEqual(m.v2.value, m_full.v2.value)
363384
self.assertAlmostEqual(m.v3.value, m_full.v3.value)
364385

386+
@unittest.skipUnless(find_GSL(), "Could not find the AMPL GSL library")
387+
@unittest.skipUnless(
388+
pyo.SolverFactory("ipopt").available(), "ipopt is not available"
389+
)
390+
def test_with_external_function_in_named_expression(self):
391+
m = self._make_model_with_external_functions(named_expressions=True)
392+
subsystem = ([m.con2, m.con3], [m.v2, m.v3])
393+
394+
m.v1.set_value(0.5)
395+
block = create_subsystem_block(*subsystem)
396+
ipopt = pyo.SolverFactory("ipopt")
397+
with TemporarySubsystemManager(to_fix=list(block.input_vars.values())):
398+
ipopt.solve(block)
399+
400+
# Correct values obtained by solving with Ipopt directly
401+
# in another script.
402+
self.assertEqual(m.v1.value, 0.5)
403+
self.assertFalse(m.v1.fixed)
404+
self.assertAlmostEqual(m.v2.value, 1.04816, delta=1e-5)
405+
self.assertAlmostEqual(m.v3.value, 1.34356, delta=1e-5)
406+
407+
# Result obtained by solving the full system
408+
m_full = self._solve_ef_model_with_ipopt()
409+
self.assertAlmostEqual(m.v1.value, m_full.v1.value)
410+
self.assertAlmostEqual(m.v2.value, m_full.v2.value)
411+
self.assertAlmostEqual(m.v3.value, m_full.v3.value)
412+
365413
@unittest.skipUnless(find_GSL(), "Could not find the AMPL GSL library")
366414
def test_external_function_with_potential_name_collision(self):
367415
m = self._make_model_with_external_functions()

0 commit comments

Comments
 (0)