Skip to content

Support "default" dispatchers in ExitNodeDispatcher #3194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 23 additions & 55 deletions pyomo/repn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ def _handle_negation_ANY(visitor, node, arg):


_exit_node_handlers[NegationExpression] = {
None: _handle_negation_ANY,
(_CONSTANT,): _handle_negation_constant,
(_LINEAR,): _handle_negation_ANY,
(_GENERAL,): _handle_negation_ANY,
}

#
Expand All @@ -211,20 +210,18 @@ def _handle_negation_ANY(visitor, node, arg):


def _handle_product_constant_constant(visitor, node, arg1, arg2):
_, arg1 = arg1
_, arg2 = arg2
ans = arg1 * arg2
ans = arg1[1] * arg2[1]
if ans != ans:
if not arg1 or not arg2:
if not arg1[1] or not arg2[1]:
deprecation_warning(
f"Encountered {str(arg1)}*{str(arg2)} in expression tree. "
f"Encountered {str(arg1[1])}*{str(arg2[1])} in expression tree. "
"Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return _, 0
return _, arg1 * arg2
return _CONSTANT, 0
return _CONSTANT, ans


def _handle_product_constant_ANY(visitor, node, arg1, arg2):
Expand Down Expand Up @@ -276,15 +273,12 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[ProductExpression] = {
None: _handle_product_nonlinear,
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
(_LINEAR, _LINEAR): _handle_product_nonlinear,
(_LINEAR, _GENERAL): _handle_product_nonlinear,
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
(_GENERAL, _LINEAR): _handle_product_nonlinear,
(_GENERAL, _GENERAL): _handle_product_nonlinear,
}
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]

Expand All @@ -309,24 +303,18 @@ def _handle_division_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[DivisionExpression] = {
None: _handle_division_nonlinear,
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
(_CONSTANT, _LINEAR): _handle_division_nonlinear,
(_CONSTANT, _GENERAL): _handle_division_nonlinear,
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
(_LINEAR, _LINEAR): _handle_division_nonlinear,
(_LINEAR, _GENERAL): _handle_division_nonlinear,
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
(_GENERAL, _LINEAR): _handle_division_nonlinear,
(_GENERAL, _GENERAL): _handle_division_nonlinear,
}

#
# EXPONENTIATION handlers
#


def _handle_pow_constant_constant(visitor, node, *args):
arg1, arg2 = args
def _handle_pow_constant_constant(visitor, node, arg1, arg2):
ans = apply_node_operation(node, (arg1[1], arg2[1]))
if ans.__class__ in native_complex_types:
ans = complex_number_error(ans, visitor, node)
Expand Down Expand Up @@ -358,15 +346,10 @@ def _handle_pow_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[PowExpression] = {
None: _handle_pow_nonlinear,
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
(_CONSTANT, _LINEAR): _handle_pow_nonlinear,
(_CONSTANT, _GENERAL): _handle_pow_nonlinear,
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
(_LINEAR, _LINEAR): _handle_pow_nonlinear,
(_LINEAR, _GENERAL): _handle_pow_nonlinear,
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
(_GENERAL, _LINEAR): _handle_pow_nonlinear,
(_GENERAL, _GENERAL): _handle_pow_nonlinear,
}

#
Expand All @@ -389,9 +372,8 @@ def _handle_unary_nonlinear(visitor, node, arg):


_exit_node_handlers[UnaryFunctionExpression] = {
None: _handle_unary_nonlinear,
(_CONSTANT,): _handle_unary_constant,
(_LINEAR,): _handle_unary_nonlinear,
(_GENERAL,): _handle_unary_nonlinear,
}
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]

Expand All @@ -414,9 +396,8 @@ def _handle_named_ANY(visitor, node, arg1):


_exit_node_handlers[Expression] = {
None: _handle_named_ANY,
(_CONSTANT,): _handle_named_constant,
(_LINEAR,): _handle_named_ANY,
(_GENERAL,): _handle_named_ANY,
}

#
Expand Down Expand Up @@ -449,12 +430,7 @@ def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3):
return _GENERAL, ans


_exit_node_handlers[Expr_ifExpression] = {
(i, j, k): _handle_expr_if_nonlinear
for i in (_LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
for k in (_CONSTANT, _LINEAR, _GENERAL)
}
_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
for j in (_CONSTANT, _LINEAR, _GENERAL):
for k in (_CONSTANT, _LINEAR, _GENERAL):
_exit_node_handlers[Expr_ifExpression][_CONSTANT, j, k] = _handle_expr_if_const
Expand Down Expand Up @@ -487,11 +463,9 @@ def _handle_equality_general(visitor, node, arg1, arg2):


_exit_node_handlers[EqualityExpression] = {
(i, j): _handle_equality_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_equality_general,
(_CONSTANT, _CONSTANT): _handle_equality_const,
}
_exit_node_handlers[EqualityExpression][_CONSTANT, _CONSTANT] = _handle_equality_const


def _handle_inequality_const(visitor, node, arg1, arg2):
Expand All @@ -517,13 +491,9 @@ def _handle_inequality_general(visitor, node, arg1, arg2):


_exit_node_handlers[InequalityExpression] = {
(i, j): _handle_inequality_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_inequality_general,
(_CONSTANT, _CONSTANT): _handle_inequality_const,
}
_exit_node_handlers[InequalityExpression][
_CONSTANT, _CONSTANT
] = _handle_inequality_const


def _handle_ranged_const(visitor, node, arg1, arg2, arg3):
Expand Down Expand Up @@ -554,14 +524,9 @@ def _handle_ranged_general(visitor, node, arg1, arg2, arg3):


_exit_node_handlers[RangedExpression] = {
(i, j, k): _handle_ranged_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
for k in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_ranged_general,
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
}
_exit_node_handlers[RangedExpression][
_CONSTANT, _CONSTANT, _CONSTANT
] = _handle_ranged_const


class LinearBeforeChildDispatcher(BeforeChildDispatcher):
Expand Down Expand Up @@ -754,7 +719,10 @@ def _initialize_exit_node_dispatcher(exit_handlers):
exit_dispatcher = {}
for cls, handlers in exit_handlers.items():
for args, fcn in handlers.items():
exit_dispatcher[(cls, *args)] = fcn
if args is None:
exit_dispatcher[cls] = fcn
else:
exit_dispatcher[(cls, *args)] = fcn
return exit_dispatcher


Expand Down
89 changes: 16 additions & 73 deletions pyomo/repn/quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,119 +277,62 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):

_exit_node_handlers[ProductExpression].update(
{
None: _handle_product_nonlinear,
(_CONSTANT, _QUADRATIC): linear._handle_product_constant_ANY,
(_LINEAR, _QUADRATIC): _handle_product_nonlinear,
(_QUADRATIC, _QUADRATIC): _handle_product_nonlinear,
(_GENERAL, _QUADRATIC): _handle_product_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_product_ANY_constant,
(_QUADRATIC, _LINEAR): _handle_product_nonlinear,
(_QUADRATIC, _GENERAL): _handle_product_nonlinear,
# Replace handler from the linear walker
(_LINEAR, _LINEAR): _handle_product_linear_linear,
(_GENERAL, _GENERAL): _handle_product_nonlinear,
(_GENERAL, _LINEAR): _handle_product_nonlinear,
(_LINEAR, _GENERAL): _handle_product_nonlinear,
}
)

#
# DIVISION
#
_exit_node_handlers[DivisionExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_division_nonlinear,
(_LINEAR, _QUADRATIC): linear._handle_division_nonlinear,
(_QUADRATIC, _QUADRATIC): linear._handle_division_nonlinear,
(_GENERAL, _QUADRATIC): linear._handle_division_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant,
(_QUADRATIC, _LINEAR): linear._handle_division_nonlinear,
(_QUADRATIC, _GENERAL): linear._handle_division_nonlinear,
}
{(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant}
)


#
# EXPONENTIATION
#
_exit_node_handlers[PowExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_pow_nonlinear,
(_LINEAR, _QUADRATIC): linear._handle_pow_nonlinear,
(_QUADRATIC, _QUADRATIC): linear._handle_pow_nonlinear,
(_GENERAL, _QUADRATIC): linear._handle_pow_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant,
(_QUADRATIC, _LINEAR): linear._handle_pow_nonlinear,
(_QUADRATIC, _GENERAL): linear._handle_pow_nonlinear,
}
{(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant}
)

#
# ABS and UNARY handlers
#
_exit_node_handlers[AbsExpression][(_QUADRATIC,)] = linear._handle_unary_nonlinear
_exit_node_handlers[UnaryFunctionExpression][
(_QUADRATIC,)
] = linear._handle_unary_nonlinear
# (no changes needed)

#
# NAMED EXPRESSION handlers
#
_exit_node_handlers[Expression][(_QUADRATIC,)] = linear._handle_named_ANY
# (no changes needed)

#
# EXPR_IF handlers
#
# Note: it is easier to just recreate the entire data structure, rather
# than update it
_exit_node_handlers[Expr_ifExpression] = {
(i, j, k): linear._handle_expr_if_nonlinear
for i in (_LINEAR, _QUADRATIC, _GENERAL)
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
}
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
_exit_node_handlers[Expr_ifExpression][
_CONSTANT, j, k
] = linear._handle_expr_if_const

#
# RELATIONAL handlers
#
_exit_node_handlers[EqualityExpression].update(
_exit_node_handlers[Expr_ifExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_equality_general,
(_LINEAR, _QUADRATIC): linear._handle_equality_general,
(_QUADRATIC, _QUADRATIC): linear._handle_equality_general,
(_GENERAL, _QUADRATIC): linear._handle_equality_general,
(_QUADRATIC, _CONSTANT): linear._handle_equality_general,
(_QUADRATIC, _LINEAR): linear._handle_equality_general,
(_QUADRATIC, _GENERAL): linear._handle_equality_general,
(_CONSTANT, i, _QUADRATIC): linear._handle_expr_if_const
for i in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
}
)
_exit_node_handlers[InequalityExpression].update(
_exit_node_handlers[Expr_ifExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_inequality_general,
(_LINEAR, _QUADRATIC): linear._handle_inequality_general,
(_QUADRATIC, _QUADRATIC): linear._handle_inequality_general,
(_GENERAL, _QUADRATIC): linear._handle_inequality_general,
(_QUADRATIC, _CONSTANT): linear._handle_inequality_general,
(_QUADRATIC, _LINEAR): linear._handle_inequality_general,
(_QUADRATIC, _GENERAL): linear._handle_inequality_general,
}
)
_exit_node_handlers[RangedExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_ranged_general,
(_LINEAR, _QUADRATIC): linear._handle_ranged_general,
(_QUADRATIC, _QUADRATIC): linear._handle_ranged_general,
(_GENERAL, _QUADRATIC): linear._handle_ranged_general,
(_QUADRATIC, _CONSTANT): linear._handle_ranged_general,
(_QUADRATIC, _LINEAR): linear._handle_ranged_general,
(_QUADRATIC, _GENERAL): linear._handle_ranged_general,
(_CONSTANT, _QUADRATIC, i): linear._handle_expr_if_const
for i in (_CONSTANT, _LINEAR, _GENERAL)
}
)

#
# RELATIONAL handlers
#
# (no changes needed)


class QuadraticRepnVisitor(linear.LinearRepnVisitor):
Result = QuadraticRepn
Expand Down
6 changes: 2 additions & 4 deletions pyomo/repn/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,16 +718,14 @@ class UnknownExpression(NumericExpression):
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 9)
self.assertIn(UnknownExpression, end)
self.assertEqual(len(end), 8)

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__, 6, 7](None, node, *node.args)
self.assertEqual(len(end), 10)
self.assertIn((UnknownExpression, 6, 7), end)
self.assertEqual(len(end), 8)

def test_BeforeChildDispatcher_registration(self):
class BeforeChildDispatcherTester(BeforeChildDispatcher):
Expand Down
Loading