Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple dispatch for relational expression generation #3483

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b7226e8
Move ARG_TYPE and base dispatcher utilities to expr_common
jsiirola Feb 20, 2025
8a93fbd
Use multiple dispatch for relational expression generation
jsiirola Feb 20, 2025
ad5985c
Update tests to reflect more consistent relational arg processing
jsiirola Feb 20, 2025
a76abac
Update imports to import from actual source
jsiirola Feb 20, 2025
a7c0908
Merge branch 'main' into relational-multiple-dispatch
jsiirola Feb 21, 2025
05297ed
Merge branch 'main' into relational-multiple-dispatch
jsiirola Feb 25, 2025
f78bc94
Merge branch 'main' into relational-multiple-dispatch
jsiirola Mar 2, 2025
e858b8f
Merge branch 'main' into relational-multiple-dispatch
emma58 Mar 3, 2025
7779930
Merge branch 'main' into relational-multiple-dispatch
jsiirola Mar 8, 2025
8d1e7a5
Merge branch 'main' into relational-multiple-dispatch
jsiirola Mar 13, 2025
a5524e5
Merge branch 'main' into relational-multiple-dispatch
jsiirola Mar 25, 2025
61c433f
Addressing review comments
jsiirola Mar 25, 2025
9efb375
Refactor to support reuse for testing relational expressions
jsiirola Mar 25, 2025
20d6067
Rename to maintain consistency
jsiirola Mar 25, 2025
f7fc053
Merge branch 'main' into relational-multiple-dispatch
jsiirola Apr 1, 2025
99d1be3
Merge branch 'main' into relational-multiple-dispatch
jsiirola Apr 2, 2025
2f5ac1f
NFC: address reviewer comment
jsiirola Apr 2, 2025
c977e5b
Merge remote-tracking branch 'me/numeric_expr-test-driver' into relat…
jsiirola Apr 3, 2025
c2d60e2
Reverse the priority for resolving generic asnumeric/mutable/invalid …
jsiirola Apr 3, 2025
fea56b1
Update docs / fix grammar typo
jsiirola Apr 3, 2025
69f5f4f
Support testing exception message in the dispatch tester
jsiirola Apr 3, 2025
f1bc18f
Add dispatcher tests for equality expressions
jsiirola Apr 3, 2025
2aa4fcf
Add testing for <= dispatcher
jsiirola Apr 3, 2025
353c9a9
Comparing expressions should check Inequality strict-ness
jsiirola Apr 3, 2025
7b636d1
Add strict inequality dispatcher tests
jsiirola Apr 3, 2025
7eec611
Update tests to track change in inequality representation
jsiirola Apr 3, 2025
e1509c2
NFC: apply black
jsiirola Apr 3, 2025
ac5d88d
NFC: fix typo
jsiirola Apr 3, 2025
1f5e1e7
Merge branch 'main' into relational-multiple-dispatch
blnicho Apr 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions pyomo/core/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@
)

#
# FIXME: remove circular dependencies between relational_expr and numeric_expr
# FIXME: remove circular dependencies between logical_expr and numeric_expr
#

# Initialize relational expression functions
numeric_expr._generate_relational_expression = (
relational_expr._generate_relational_expression
)

# Initialize logicalvalue functions
boolean_value._generate_logical_proposition = logical_expr._generate_logical_proposition

Expand Down
2 changes: 0 additions & 2 deletions pyomo/core/expr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

import enum

from pyomo.common.dependencies import attempt_import
from pyomo.common.numeric_types import native_types
from pyomo.core.pyomoobject import PyomoObject
Expand Down
2 changes: 1 addition & 1 deletion pyomo/core/expr/boolean_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging

from pyomo.common.deprecation import deprecated
from pyomo.core.expr.numvalue import native_types, native_logical_types
from pyomo.common.numeric_types import native_types, native_logical_types
from pyomo.core.expr.expr_common import _and, _or, _equiv, _inv, _xor, _impl
from pyomo.core.pyomoobject import PyomoObject

Expand Down
110 changes: 106 additions & 4 deletions pyomo/core/expr/expr_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

TO_STRING_VERBOSE = False

_eq = 0
_le = 1
_lt = 2

# logical propositions
_and = 0
_or = 1
Expand Down Expand Up @@ -82,6 +78,112 @@ class ExpressionType(enums.Enum):
LOGICAL = 2


class NUMERIC_ARG_TYPE(enums.IntEnum):
MUTABLE = -2
ASNUMERIC = -1
INVALID = 0
NATIVE = 1
NPV = 2
PARAM = 3
VAR = 4
MONOMIAL = 5
LINEAR = 6
SUM = 7
OTHER = 8


class RELATIONAL_ARG_TYPE(enums.IntEnum, metaclass=enums.ExtendedEnumType):
__base_enum__ = NUMERIC_ARG_TYPE

INEQUALITY = 100
INVALID_RELATIONAL = 101
Comment on lines +96 to +100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have equality in it? Why doesn't it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the list of argument types for relational expressions. As an Equality expression is not allowed to be the argument for another numeric / relational expression, we will never need to do special dispatch for it (it will fall back on getting mapped to INVALID).



def _invalid(*args):
return NotImplemented


def _recast_mutable(expr):
expr.make_immutable()
if expr._nargs > 1:
return expr
elif not expr._nargs:
return 0
else:
return expr._args_[0]


def _unary_op_dispatcher_type_mapping(dispatcher, updates, TYPES=NUMERIC_ARG_TYPE):
#
# Special case (wrapping) operators
#
def _asnumeric(a):
a = a.as_numeric()
return dispatcher[a.__class__](a)

def _mutable(a):
a = _recast_mutable(a)
return dispatcher[a.__class__](a)

mapping = {
TYPES.ASNUMERIC: _asnumeric,
TYPES.MUTABLE: _mutable,
TYPES.INVALID: _invalid,
}

mapping.update(updates)
return mapping


def _binary_op_dispatcher_type_mapping(dispatcher, updates, TYPES=NUMERIC_ARG_TYPE):
#
# Special case (wrapping) operators
#
def _any_asnumeric(a, b):
b = b.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _asnumeric_any(a, b):
a = a.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _asnumeric_asnumeric(a, b):
a = a.as_numeric()
b = b.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _any_mutable(a, b):
b = _recast_mutable(b)
return dispatcher[a.__class__, b.__class__](a, b)

def _mutable_any(a, b):
a = _recast_mutable(a)
return dispatcher[a.__class__, b.__class__](a, b)

def _mutable_mutable(a, b):
if a is b:
a = b = _recast_mutable(a)
else:
a = _recast_mutable(a)
b = _recast_mutable(b)
return dispatcher[a.__class__, b.__class__](a, b)

mapping = {}
mapping.update({(i, TYPES.ASNUMERIC): _any_asnumeric for i in TYPES})
mapping.update({(TYPES.ASNUMERIC, i): _asnumeric_any for i in TYPES})
mapping[TYPES.ASNUMERIC, TYPES.ASNUMERIC] = _asnumeric_asnumeric

mapping.update({(i, TYPES.MUTABLE): _any_mutable for i in TYPES})
mapping.update({(TYPES.MUTABLE, i): _mutable_any for i in TYPES})
mapping[TYPES.MUTABLE, TYPES.MUTABLE] = _mutable_mutable

mapping.update({(i, TYPES.INVALID): _invalid for i in TYPES})
mapping.update({(TYPES.INVALID, i): _invalid for i in TYPES})

mapping.update(updates)
return mapping


@deprecated(
"""The clone counter has been removed and will always return 0.

Expand Down
152 changes: 38 additions & 114 deletions pyomo/core/expr/numeric_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

logger = logging.getLogger('pyomo.core')

from pyomo.common import enums
from pyomo.common.dependencies import attempt_import
from pyomo.common.deprecation import deprecated, relocated_module_attribute
from pyomo.common.errors import PyomoException, DeveloperError
Expand All @@ -33,14 +32,24 @@
from pyomo.core.expr.expr_common import (
OperatorAssociativity,
ExpressionType,
_lt,
_le,
_eq,
_unary_op_dispatcher_type_mapping,
_binary_op_dispatcher_type_mapping,
_invalid,
_recast_mutable,
NUMERIC_ARG_TYPE as ARG_TYPE,
)

# Note: pyggyback on expr.base's use of attempt_import(visitor)
from pyomo.core.expr.base import ExpressionBase, NPV_Mixin, visitor

# Note: There is a circular dependency between relational_expr and this
# module: relational_expr would like to reuse/build on
# _categorize_arg_type(), and NumericValue needs to call the relational
# dispatchers from relational_expr. Instead of ensuring that one of the
# modules is fully declared before importing into the other, we will
# have BOTH modules assume that the other module has NOT been declared.
import pyomo.core.expr.relational_expr as relational_expr


_ndarray, _ = attempt_import('pyomo.core.expr.ndarray')

Expand Down Expand Up @@ -92,15 +101,16 @@
version='6.6.2',
f_globals=globals(),
)
relocated_module_attribute(
'register_arg_type',
'pyomo.core.expr.expr_common',
version='6.8.3.dev0',
f_globals=globals(),
)

_zero_one_optimizations = {1}


# Stub in the dispatchers
def _generate_relational_expression(etype, lhs, rhs):
raise RuntimeError("incomplete import of Pyomo expression system")


def enable_expression_optimizations(zero=None, one=None):
"""Enable(disable) expression generation optimizations

Expand Down Expand Up @@ -375,7 +385,9 @@ def __lt__(self, other):
self < other
other > self
"""
return _generate_relational_expression(_lt, self, other)
return relational_expr._lt_dispatcher[self.__class__, other.__class__](
self, other
)

def __gt__(self, other):
"""
Expand All @@ -386,7 +398,9 @@ def __gt__(self, other):
self > other
other < self
"""
return _generate_relational_expression(_lt, other, self)
return relational_expr._lt_dispatcher[other.__class__, self.__class__](
other, self
)

def __le__(self, other):
"""
Expand All @@ -397,7 +411,9 @@ def __le__(self, other):
self <= other
other >= self
"""
return _generate_relational_expression(_le, self, other)
return relational_expr._le_dispatcher[self.__class__, other.__class__](
self, other
)

def __ge__(self, other):
"""
Expand All @@ -408,7 +424,9 @@ def __ge__(self, other):
self >= other
other <= self
"""
return _generate_relational_expression(_le, other, self)
return relational_expr._le_dispatcher[other.__class__, self.__class__](
other, self
)

def __eq__(self, other):
"""
Expand All @@ -418,7 +436,13 @@ def __eq__(self, other):

self == other
"""
return _generate_relational_expression(_eq, self, other)
# Note: While it would appear that keeping the attribute lookup
# into the relational_expr module would be a performance hit, we
# want that indirection as it allows us to selectively disable
# operator overloading for comparisons.
return relational_expr._eq_dispatcher[self.__class__, other.__class__](
self, other
)

def __add__(self, other):
"""
Expand Down Expand Up @@ -1637,21 +1661,6 @@ def _decompose_linear_terms(expr, multiplier=1):
#
# -------------------------------------------------------


class ARG_TYPE(enums.IntEnum):
MUTABLE = -2
ASNUMERIC = -1
INVALID = 0
NATIVE = 1
NPV = 2
PARAM = 3
VAR = 4
MONOMIAL = 5
LINEAR = 6
SUM = 7
OTHER = 8


_known_arg_types = {}


Expand Down Expand Up @@ -1715,91 +1724,6 @@ def _categorize_arg_types(*args):
return tuple(_categorize_arg_type(arg) for arg in args)


def _invalid(*args):
return NotImplemented


def _recast_mutable(expr):
expr.make_immutable()
if expr._nargs > 1:
return expr
elif not expr._nargs:
return 0
else:
return expr._args_[0]


def _unary_op_dispatcher_type_mapping(dispatcher, updates):
#
# Special case (wrapping) operators
#
def _asnumeric(a):
a = a.as_numeric()
return dispatcher[a.__class__](a)

def _mutable(a):
a = _recast_mutable(a)
return dispatcher[a.__class__](a)

mapping = {
ARG_TYPE.ASNUMERIC: _asnumeric,
ARG_TYPE.MUTABLE: _mutable,
ARG_TYPE.INVALID: _invalid,
}

mapping.update(updates)
return mapping


def _binary_op_dispatcher_type_mapping(dispatcher, updates):
#
# Special case (wrapping) operators
#
def _any_asnumeric(a, b):
b = b.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _asnumeric_any(a, b):
a = a.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _asnumeric_asnumeric(a, b):
a = a.as_numeric()
b = b.as_numeric()
return dispatcher[a.__class__, b.__class__](a, b)

def _any_mutable(a, b):
b = _recast_mutable(b)
return dispatcher[a.__class__, b.__class__](a, b)

def _mutable_any(a, b):
a = _recast_mutable(a)
return dispatcher[a.__class__, b.__class__](a, b)

def _mutable_mutable(a, b):
if a is b:
a = b = _recast_mutable(a)
else:
a = _recast_mutable(a)
b = _recast_mutable(b)
return dispatcher[a.__class__, b.__class__](a, b)

mapping = {}
mapping.update({(i, ARG_TYPE.ASNUMERIC): _any_asnumeric for i in ARG_TYPE})
mapping.update({(ARG_TYPE.ASNUMERIC, i): _asnumeric_any for i in ARG_TYPE})
mapping[ARG_TYPE.ASNUMERIC, ARG_TYPE.ASNUMERIC] = _asnumeric_asnumeric

mapping.update({(i, ARG_TYPE.MUTABLE): _any_mutable for i in ARG_TYPE})
mapping.update({(ARG_TYPE.MUTABLE, i): _mutable_any for i in ARG_TYPE})
mapping[ARG_TYPE.MUTABLE, ARG_TYPE.MUTABLE] = _mutable_mutable

mapping.update({(i, ARG_TYPE.INVALID): _invalid for i in ARG_TYPE})
mapping.update({(ARG_TYPE.INVALID, i): _invalid for i in ARG_TYPE})

mapping.update(updates)
return mapping


#
# ADD: NATIVE handlers
#
Expand Down
Loading
Loading