Skip to content

Commit d329ee1

Browse files
authored
Merge pull request #3483 from jsiirola/relational-multiple-dispatch
Multiple dispatch for relational expression generation
2 parents 166119a + 1f5e1e7 commit d329ee1

12 files changed

+7795
-251
lines changed

pyomo/core/expr/__init__.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,9 @@
1919
)
2020

2121
#
22-
# FIXME: remove circular dependencies between relational_expr and numeric_expr
22+
# FIXME: remove circular dependencies between logical_expr and numeric_expr
2323
#
2424

25-
# Initialize relational expression functions
26-
numeric_expr._generate_relational_expression = (
27-
relational_expr._generate_relational_expression
28-
)
29-
3025
# Initialize logicalvalue functions
3126
boolean_value._generate_logical_proposition = logical_expr._generate_logical_proposition
3227

pyomo/core/expr/base.py

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12-
import enum
13-
1412
from pyomo.common.dependencies import attempt_import
1513
from pyomo.common.numeric_types import native_types
1614
from pyomo.common.modeling import NOTSET

pyomo/core/expr/boolean_value.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
from pyomo.common.deprecation import deprecated
1616
from pyomo.common.modeling import NOTSET
17+
from pyomo.common.numeric_types import native_types, native_logical_types
1718
from pyomo.core.expr.expr_common import _type_check_exception_arg
18-
from pyomo.core.expr.numvalue import native_types, native_logical_types
1919
from pyomo.core.expr.expr_common import _and, _or, _equiv, _inv, _xor, _impl
2020
from pyomo.core.pyomoobject import PyomoObject
2121

pyomo/core/expr/compare.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def handle_sequence(node: collections.abc.Sequence, pn: List):
7171
return list(node)
7272

7373

74+
def handle_inequality(node: collections.abc.Sequence, pn: List):
75+
pn.append((type(node), node.nargs(), node.strict))
76+
return node.args
77+
78+
7479
def _generic_expression_handler():
7580
return handle_expression
7681

@@ -83,7 +88,8 @@ def _generic_expression_handler():
8388
handler[NPV_ExternalFunctionExpression] = handle_external_function_expression
8489
handler[AbsExpression] = handle_unary_expression
8590
handler[NPV_AbsExpression] = handle_unary_expression
86-
handler[RangedExpression] = handle_expression
91+
handler[InequalityExpression] = handle_inequality
92+
handler[RangedExpression] = handle_inequality
8793
handler[list] = handle_sequence
8894

8995

pyomo/core/expr/expr_common.py

+114-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717

1818
TO_STRING_VERBOSE = False
1919

20-
_eq = 0
21-
_le = 1
22-
_lt = 2
23-
2420
# logical propositions
2521
_and = 0
2622
_or = 1
@@ -83,6 +79,120 @@ class ExpressionType(enums.Enum):
8379
LOGICAL = 2
8480

8581

82+
class NUMERIC_ARG_TYPE(enums.IntEnum):
83+
MUTABLE = -2
84+
ASNUMERIC = -1
85+
INVALID = 0
86+
NATIVE = 1
87+
NPV = 2
88+
PARAM = 3
89+
VAR = 4
90+
MONOMIAL = 5
91+
LINEAR = 6
92+
SUM = 7
93+
OTHER = 8
94+
95+
96+
class RELATIONAL_ARG_TYPE(enums.IntEnum, metaclass=enums.ExtendedEnumType):
97+
__base_enum__ = NUMERIC_ARG_TYPE
98+
99+
INEQUALITY = 100
100+
INVALID_RELATIONAL = 101
101+
102+
103+
def _invalid(*args):
104+
return NotImplemented
105+
106+
107+
def _recast_mutable(expr):
108+
expr.make_immutable()
109+
if expr._nargs > 1:
110+
return expr
111+
elif not expr._nargs:
112+
return 0
113+
else:
114+
return expr._args_[0]
115+
116+
117+
def _unary_op_dispatcher_type_mapping(dispatcher, updates, TYPES=NUMERIC_ARG_TYPE):
118+
#
119+
# Special case (wrapping) operators
120+
#
121+
def _asnumeric(a):
122+
a = a.as_numeric()
123+
return dispatcher[a.__class__](a)
124+
125+
def _mutable(a):
126+
a = _recast_mutable(a)
127+
return dispatcher[a.__class__](a)
128+
129+
mapping = {
130+
TYPES.ASNUMERIC: _asnumeric,
131+
TYPES.MUTABLE: _mutable,
132+
TYPES.INVALID: _invalid,
133+
}
134+
135+
mapping.update(updates)
136+
return mapping
137+
138+
139+
def _binary_op_dispatcher_type_mapping(dispatcher, updates, TYPES=NUMERIC_ARG_TYPE):
140+
#
141+
# Special case (wrapping) operators
142+
#
143+
def _any_asnumeric(a, b):
144+
b = b.as_numeric()
145+
return dispatcher[a.__class__, b.__class__](a, b)
146+
147+
def _asnumeric_any(a, b):
148+
a = a.as_numeric()
149+
return dispatcher[a.__class__, b.__class__](a, b)
150+
151+
def _asnumeric_asnumeric(a, b):
152+
a = a.as_numeric()
153+
b = b.as_numeric()
154+
return dispatcher[a.__class__, b.__class__](a, b)
155+
156+
def _any_mutable(a, b):
157+
b = _recast_mutable(b)
158+
return dispatcher[a.__class__, b.__class__](a, b)
159+
160+
def _mutable_any(a, b):
161+
a = _recast_mutable(a)
162+
return dispatcher[a.__class__, b.__class__](a, b)
163+
164+
def _mutable_mutable(a, b):
165+
if a is b:
166+
# Note: _recast_mutable is an in-place operation: make sure
167+
# that we don't call it twice on the same object.
168+
a = b = _recast_mutable(a)
169+
else:
170+
a = _recast_mutable(a)
171+
b = _recast_mutable(b)
172+
return dispatcher[a.__class__, b.__class__](a, b)
173+
174+
mapping = {}
175+
176+
# Because ASNUMERIC and MUTABLE re-call the dispatcher, we want to
177+
# resolve ASNUMERIC first, MUTABLE second, and INVALID last. That
178+
# means we will add them to the dispatcher dict in opposite order so
179+
# "higher priority" callbacks override lower priority ones.
180+
181+
mapping.update({(i, TYPES.INVALID): _invalid for i in TYPES})
182+
mapping.update({(TYPES.INVALID, i): _invalid for i in TYPES})
183+
184+
mapping.update({(i, TYPES.MUTABLE): _any_mutable for i in TYPES})
185+
mapping.update({(TYPES.MUTABLE, i): _mutable_any for i in TYPES})
186+
mapping[TYPES.MUTABLE, TYPES.MUTABLE] = _mutable_mutable
187+
188+
mapping.update({(i, TYPES.ASNUMERIC): _any_asnumeric for i in TYPES})
189+
mapping.update({(TYPES.ASNUMERIC, i): _asnumeric_any for i in TYPES})
190+
mapping[TYPES.ASNUMERIC, TYPES.ASNUMERIC] = _asnumeric_asnumeric
191+
192+
mapping.update(updates)
193+
return mapping
194+
195+
86196
@deprecated(
87197
"""The clone counter has been removed and will always return 0.
88198

0 commit comments

Comments
 (0)