Skip to content

Commit 69082ac

Browse files
authored
Merge pull request #3194 from jsiirola/dispatcher-simplification
Support "default" dispatchers in `ExitNodeDispatcher`
2 parents cc33f35 + 440bf86 commit 69082ac

File tree

4 files changed

+71
-153
lines changed

4 files changed

+71
-153
lines changed

Diff for: pyomo/repn/linear.py

+23-55
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ def _handle_negation_ANY(visitor, node, arg):
200200

201201

202202
_exit_node_handlers[NegationExpression] = {
203+
None: _handle_negation_ANY,
203204
(_CONSTANT,): _handle_negation_constant,
204-
(_LINEAR,): _handle_negation_ANY,
205-
(_GENERAL,): _handle_negation_ANY,
206205
}
207206

208207
#
@@ -211,20 +210,18 @@ def _handle_negation_ANY(visitor, node, arg):
211210

212211

213212
def _handle_product_constant_constant(visitor, node, arg1, arg2):
214-
_, arg1 = arg1
215-
_, arg2 = arg2
216-
ans = arg1 * arg2
213+
ans = arg1[1] * arg2[1]
217214
if ans != ans:
218-
if not arg1 or not arg2:
215+
if not arg1[1] or not arg2[1]:
219216
deprecation_warning(
220-
f"Encountered {str(arg1)}*{str(arg2)} in expression tree. "
217+
f"Encountered {str(arg1[1])}*{str(arg2[1])} in expression tree. "
221218
"Mapping the NaN result to 0 for compatibility "
222219
"with the lp_v1 writer. In the future, this NaN "
223220
"will be preserved/emitted to comply with IEEE-754.",
224221
version='6.6.0',
225222
)
226-
return _, 0
227-
return _, arg1 * arg2
223+
return _CONSTANT, 0
224+
return _CONSTANT, ans
228225

229226

230227
def _handle_product_constant_ANY(visitor, node, arg1, arg2):
@@ -276,15 +273,12 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):
276273

277274

278275
_exit_node_handlers[ProductExpression] = {
276+
None: _handle_product_nonlinear,
279277
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
280278
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
281279
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
282280
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
283-
(_LINEAR, _LINEAR): _handle_product_nonlinear,
284-
(_LINEAR, _GENERAL): _handle_product_nonlinear,
285281
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
286-
(_GENERAL, _LINEAR): _handle_product_nonlinear,
287-
(_GENERAL, _GENERAL): _handle_product_nonlinear,
288282
}
289283
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]
290284

@@ -309,24 +303,18 @@ def _handle_division_nonlinear(visitor, node, arg1, arg2):
309303

310304

311305
_exit_node_handlers[DivisionExpression] = {
306+
None: _handle_division_nonlinear,
312307
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
313-
(_CONSTANT, _LINEAR): _handle_division_nonlinear,
314-
(_CONSTANT, _GENERAL): _handle_division_nonlinear,
315308
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
316-
(_LINEAR, _LINEAR): _handle_division_nonlinear,
317-
(_LINEAR, _GENERAL): _handle_division_nonlinear,
318309
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
319-
(_GENERAL, _LINEAR): _handle_division_nonlinear,
320-
(_GENERAL, _GENERAL): _handle_division_nonlinear,
321310
}
322311

323312
#
324313
# EXPONENTIATION handlers
325314
#
326315

327316

328-
def _handle_pow_constant_constant(visitor, node, *args):
329-
arg1, arg2 = args
317+
def _handle_pow_constant_constant(visitor, node, arg1, arg2):
330318
ans = apply_node_operation(node, (arg1[1], arg2[1]))
331319
if ans.__class__ in native_complex_types:
332320
ans = complex_number_error(ans, visitor, node)
@@ -358,15 +346,10 @@ def _handle_pow_nonlinear(visitor, node, arg1, arg2):
358346

359347

360348
_exit_node_handlers[PowExpression] = {
349+
None: _handle_pow_nonlinear,
361350
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
362-
(_CONSTANT, _LINEAR): _handle_pow_nonlinear,
363-
(_CONSTANT, _GENERAL): _handle_pow_nonlinear,
364351
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
365-
(_LINEAR, _LINEAR): _handle_pow_nonlinear,
366-
(_LINEAR, _GENERAL): _handle_pow_nonlinear,
367352
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
368-
(_GENERAL, _LINEAR): _handle_pow_nonlinear,
369-
(_GENERAL, _GENERAL): _handle_pow_nonlinear,
370353
}
371354

372355
#
@@ -389,9 +372,8 @@ def _handle_unary_nonlinear(visitor, node, arg):
389372

390373

391374
_exit_node_handlers[UnaryFunctionExpression] = {
375+
None: _handle_unary_nonlinear,
392376
(_CONSTANT,): _handle_unary_constant,
393-
(_LINEAR,): _handle_unary_nonlinear,
394-
(_GENERAL,): _handle_unary_nonlinear,
395377
}
396378
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]
397379

@@ -414,9 +396,8 @@ def _handle_named_ANY(visitor, node, arg1):
414396

415397

416398
_exit_node_handlers[Expression] = {
399+
None: _handle_named_ANY,
417400
(_CONSTANT,): _handle_named_constant,
418-
(_LINEAR,): _handle_named_ANY,
419-
(_GENERAL,): _handle_named_ANY,
420401
}
421402

422403
#
@@ -449,12 +430,7 @@ def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3):
449430
return _GENERAL, ans
450431

451432

452-
_exit_node_handlers[Expr_ifExpression] = {
453-
(i, j, k): _handle_expr_if_nonlinear
454-
for i in (_LINEAR, _GENERAL)
455-
for j in (_CONSTANT, _LINEAR, _GENERAL)
456-
for k in (_CONSTANT, _LINEAR, _GENERAL)
457-
}
433+
_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
458434
for j in (_CONSTANT, _LINEAR, _GENERAL):
459435
for k in (_CONSTANT, _LINEAR, _GENERAL):
460436
_exit_node_handlers[Expr_ifExpression][_CONSTANT, j, k] = _handle_expr_if_const
@@ -487,11 +463,9 @@ def _handle_equality_general(visitor, node, arg1, arg2):
487463

488464

489465
_exit_node_handlers[EqualityExpression] = {
490-
(i, j): _handle_equality_general
491-
for i in (_CONSTANT, _LINEAR, _GENERAL)
492-
for j in (_CONSTANT, _LINEAR, _GENERAL)
466+
None: _handle_equality_general,
467+
(_CONSTANT, _CONSTANT): _handle_equality_const,
493468
}
494-
_exit_node_handlers[EqualityExpression][_CONSTANT, _CONSTANT] = _handle_equality_const
495469

496470

497471
def _handle_inequality_const(visitor, node, arg1, arg2):
@@ -517,13 +491,9 @@ def _handle_inequality_general(visitor, node, arg1, arg2):
517491

518492

519493
_exit_node_handlers[InequalityExpression] = {
520-
(i, j): _handle_inequality_general
521-
for i in (_CONSTANT, _LINEAR, _GENERAL)
522-
for j in (_CONSTANT, _LINEAR, _GENERAL)
494+
None: _handle_inequality_general,
495+
(_CONSTANT, _CONSTANT): _handle_inequality_const,
523496
}
524-
_exit_node_handlers[InequalityExpression][
525-
_CONSTANT, _CONSTANT
526-
] = _handle_inequality_const
527497

528498

529499
def _handle_ranged_const(visitor, node, arg1, arg2, arg3):
@@ -554,14 +524,9 @@ def _handle_ranged_general(visitor, node, arg1, arg2, arg3):
554524

555525

556526
_exit_node_handlers[RangedExpression] = {
557-
(i, j, k): _handle_ranged_general
558-
for i in (_CONSTANT, _LINEAR, _GENERAL)
559-
for j in (_CONSTANT, _LINEAR, _GENERAL)
560-
for k in (_CONSTANT, _LINEAR, _GENERAL)
527+
None: _handle_ranged_general,
528+
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
561529
}
562-
_exit_node_handlers[RangedExpression][
563-
_CONSTANT, _CONSTANT, _CONSTANT
564-
] = _handle_ranged_const
565530

566531

567532
class LinearBeforeChildDispatcher(BeforeChildDispatcher):
@@ -754,7 +719,10 @@ def _initialize_exit_node_dispatcher(exit_handlers):
754719
exit_dispatcher = {}
755720
for cls, handlers in exit_handlers.items():
756721
for args, fcn in handlers.items():
757-
exit_dispatcher[(cls, *args)] = fcn
722+
if args is None:
723+
exit_dispatcher[cls] = fcn
724+
else:
725+
exit_dispatcher[(cls, *args)] = fcn
758726
return exit_dispatcher
759727

760728

Diff for: pyomo/repn/quadratic.py

+16-73
Original file line numberDiff line numberDiff line change
@@ -277,119 +277,62 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):
277277

278278
_exit_node_handlers[ProductExpression].update(
279279
{
280+
None: _handle_product_nonlinear,
280281
(_CONSTANT, _QUADRATIC): linear._handle_product_constant_ANY,
281-
(_LINEAR, _QUADRATIC): _handle_product_nonlinear,
282-
(_QUADRATIC, _QUADRATIC): _handle_product_nonlinear,
283-
(_GENERAL, _QUADRATIC): _handle_product_nonlinear,
284282
(_QUADRATIC, _CONSTANT): linear._handle_product_ANY_constant,
285-
(_QUADRATIC, _LINEAR): _handle_product_nonlinear,
286-
(_QUADRATIC, _GENERAL): _handle_product_nonlinear,
287283
# Replace handler from the linear walker
288284
(_LINEAR, _LINEAR): _handle_product_linear_linear,
289-
(_GENERAL, _GENERAL): _handle_product_nonlinear,
290-
(_GENERAL, _LINEAR): _handle_product_nonlinear,
291-
(_LINEAR, _GENERAL): _handle_product_nonlinear,
292285
}
293286
)
294287

295288
#
296289
# DIVISION
297290
#
298291
_exit_node_handlers[DivisionExpression].update(
299-
{
300-
(_CONSTANT, _QUADRATIC): linear._handle_division_nonlinear,
301-
(_LINEAR, _QUADRATIC): linear._handle_division_nonlinear,
302-
(_QUADRATIC, _QUADRATIC): linear._handle_division_nonlinear,
303-
(_GENERAL, _QUADRATIC): linear._handle_division_nonlinear,
304-
(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant,
305-
(_QUADRATIC, _LINEAR): linear._handle_division_nonlinear,
306-
(_QUADRATIC, _GENERAL): linear._handle_division_nonlinear,
307-
}
292+
{(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant}
308293
)
309294

310295

311296
#
312297
# EXPONENTIATION
313298
#
314299
_exit_node_handlers[PowExpression].update(
315-
{
316-
(_CONSTANT, _QUADRATIC): linear._handle_pow_nonlinear,
317-
(_LINEAR, _QUADRATIC): linear._handle_pow_nonlinear,
318-
(_QUADRATIC, _QUADRATIC): linear._handle_pow_nonlinear,
319-
(_GENERAL, _QUADRATIC): linear._handle_pow_nonlinear,
320-
(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant,
321-
(_QUADRATIC, _LINEAR): linear._handle_pow_nonlinear,
322-
(_QUADRATIC, _GENERAL): linear._handle_pow_nonlinear,
323-
}
300+
{(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant}
324301
)
325302

326303
#
327304
# ABS and UNARY handlers
328305
#
329-
_exit_node_handlers[AbsExpression][(_QUADRATIC,)] = linear._handle_unary_nonlinear
330-
_exit_node_handlers[UnaryFunctionExpression][
331-
(_QUADRATIC,)
332-
] = linear._handle_unary_nonlinear
306+
# (no changes needed)
333307

334308
#
335309
# NAMED EXPRESSION handlers
336310
#
337-
_exit_node_handlers[Expression][(_QUADRATIC,)] = linear._handle_named_ANY
311+
# (no changes needed)
338312

339313
#
340314
# EXPR_IF handlers
341315
#
342316
# Note: it is easier to just recreate the entire data structure, rather
343317
# than update it
344-
_exit_node_handlers[Expr_ifExpression] = {
345-
(i, j, k): linear._handle_expr_if_nonlinear
346-
for i in (_LINEAR, _QUADRATIC, _GENERAL)
347-
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
348-
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
349-
}
350-
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
351-
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
352-
_exit_node_handlers[Expr_ifExpression][
353-
_CONSTANT, j, k
354-
] = linear._handle_expr_if_const
355-
356-
#
357-
# RELATIONAL handlers
358-
#
359-
_exit_node_handlers[EqualityExpression].update(
318+
_exit_node_handlers[Expr_ifExpression].update(
360319
{
361-
(_CONSTANT, _QUADRATIC): linear._handle_equality_general,
362-
(_LINEAR, _QUADRATIC): linear._handle_equality_general,
363-
(_QUADRATIC, _QUADRATIC): linear._handle_equality_general,
364-
(_GENERAL, _QUADRATIC): linear._handle_equality_general,
365-
(_QUADRATIC, _CONSTANT): linear._handle_equality_general,
366-
(_QUADRATIC, _LINEAR): linear._handle_equality_general,
367-
(_QUADRATIC, _GENERAL): linear._handle_equality_general,
320+
(_CONSTANT, i, _QUADRATIC): linear._handle_expr_if_const
321+
for i in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
368322
}
369323
)
370-
_exit_node_handlers[InequalityExpression].update(
324+
_exit_node_handlers[Expr_ifExpression].update(
371325
{
372-
(_CONSTANT, _QUADRATIC): linear._handle_inequality_general,
373-
(_LINEAR, _QUADRATIC): linear._handle_inequality_general,
374-
(_QUADRATIC, _QUADRATIC): linear._handle_inequality_general,
375-
(_GENERAL, _QUADRATIC): linear._handle_inequality_general,
376-
(_QUADRATIC, _CONSTANT): linear._handle_inequality_general,
377-
(_QUADRATIC, _LINEAR): linear._handle_inequality_general,
378-
(_QUADRATIC, _GENERAL): linear._handle_inequality_general,
379-
}
380-
)
381-
_exit_node_handlers[RangedExpression].update(
382-
{
383-
(_CONSTANT, _QUADRATIC): linear._handle_ranged_general,
384-
(_LINEAR, _QUADRATIC): linear._handle_ranged_general,
385-
(_QUADRATIC, _QUADRATIC): linear._handle_ranged_general,
386-
(_GENERAL, _QUADRATIC): linear._handle_ranged_general,
387-
(_QUADRATIC, _CONSTANT): linear._handle_ranged_general,
388-
(_QUADRATIC, _LINEAR): linear._handle_ranged_general,
389-
(_QUADRATIC, _GENERAL): linear._handle_ranged_general,
326+
(_CONSTANT, _QUADRATIC, i): linear._handle_expr_if_const
327+
for i in (_CONSTANT, _LINEAR, _GENERAL)
390328
}
391329
)
392330

331+
#
332+
# RELATIONAL handlers
333+
#
334+
# (no changes needed)
335+
393336

394337
class QuadraticRepnVisitor(linear.LinearRepnVisitor):
395338
Result = QuadraticRepn

Diff for: pyomo/repn/tests/test_util.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -718,16 +718,14 @@ class UnknownExpression(NumericExpression):
718718
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
719719
):
720720
end[node.__class__](None, node, *node.args)
721-
self.assertEqual(len(end), 9)
722-
self.assertIn(UnknownExpression, end)
721+
self.assertEqual(len(end), 8)
723722

724723
node = UnknownExpression((6, 7))
725724
with self.assertRaisesRegex(
726725
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
727726
):
728727
end[node.__class__, 6, 7](None, node, *node.args)
729-
self.assertEqual(len(end), 10)
730-
self.assertIn((UnknownExpression, 6, 7), end)
728+
self.assertEqual(len(end), 8)
731729

732730
def test_BeforeChildDispatcher_registration(self):
733731
class BeforeChildDispatcherTester(BeforeChildDispatcher):

0 commit comments

Comments
 (0)