Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Custom PyROS Domain Validators #3169

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 20 additions & 57 deletions pyomo/contrib/pyros/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,71 +26,34 @@
default_pyros_solver_logger = setup_pyros_logger()


class LoggerType:
def logger_domain(obj):
"""
Domain validator for objects castable to logging.Logger.
"""

def __call__(self, obj):
"""
Cast object to logger.
Domain validator for logger-type arguments.

Parameters
----------
obj : object
Object to be cast.
This admits any object of type ``logging.Logger``,
or which can be cast to ``logging.Logger``.
"""
if isinstance(obj, logging.Logger):
return obj
else:
return logging.getLogger(obj)

Returns
-------
logging.Logger
If `str_or_logger` is of type `logging.Logger`,then
`str_or_logger` is returned.
Otherwise, ``logging.getLogger(str_or_logger)``
is returned.
"""
if isinstance(obj, logging.Logger):
return obj
else:
return logging.getLogger(obj)

def domain_name(self):
"""Return str briefly describing domain encompassed by self."""
return "None, str or logging.Logger"
logger_domain.domain_name = "None, str or logging.Logger"


class PositiveIntOrMinusOne:
def positive_int_or_minus_one(obj):
"""
Domain validator for objects castable to a
strictly positive int or -1.
Domain validator for objects castable to a strictly
positive int or -1.
"""
ans = int(obj)
if ans != float(obj) or (ans <= 0 and ans != -1):
raise ValueError(f"Expected positive int or -1, but received value {obj!r}")
return ans

def __call__(self, obj):
"""
Cast object to positive int or -1.

Parameters
----------
obj : object
Object of interest.

Returns
-------
int
Positive int, or -1.

Raises
------
ValueError
If object not castable to positive int, or -1.
"""
ans = int(obj)
if ans != float(obj) or (ans <= 0 and ans != -1):
raise ValueError(f"Expected positive int or -1, but received value {obj!r}")
return ans

def domain_name(self):
"""Return str briefly describing domain encompassed by self."""
return "positive int or -1"
positive_int_or_minus_one.domain_name = "positive int or -1"


def mutable_param_validator(param_obj):
Expand Down Expand Up @@ -721,7 +684,7 @@ def pyros_config():
"max_iter",
ConfigValue(
default=-1,
domain=PositiveIntOrMinusOne(),
domain=positive_int_or_minus_one,
description=(
"""
Iteration limit. If -1 is provided, then no iteration
Expand Down Expand Up @@ -766,7 +729,7 @@ def pyros_config():
"progress_logger",
ConfigValue(
default=default_pyros_solver_logger,
domain=LoggerType(),
domain=logger_domain,
doc=(
"""
Logger (or name thereof) used for reporting PyROS solver
Expand Down
26 changes: 16 additions & 10 deletions pyomo/contrib/pyros/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from pyomo.contrib.pyros.config import (
InputDataStandardizer,
mutable_param_validator,
LoggerType,
logger_domain,
SolverNotResolvable,
PositiveIntOrMinusOne,
positive_int_or_minus_one,
pyros_config,
SolverIterable,
SolverResolvable,
Expand Down Expand Up @@ -557,16 +557,22 @@ def test_positive_int_or_minus_one(self):
"""
Test positive int or -1 validator works as expected.
"""
standardizer_func = PositiveIntOrMinusOne()
standardizer_func = positive_int_or_minus_one
self.assertIs(
standardizer_func(1.0),
1,
msg=(f"{PositiveIntOrMinusOne.__name__} does not standardize as expected."),
msg=(
f"{positive_int_or_minus_one.__name__} "
"does not standardize as expected."
),
)
self.assertEqual(
standardizer_func(-1.00),
-1,
msg=(f"{PositiveIntOrMinusOne.__name__} does not standardize as expected."),
msg=(
f"{positive_int_or_minus_one.__name__} "
"does not standardize as expected."
),
)

exc_str = r"Expected positive int or -1, but received value.*"
Expand All @@ -576,26 +582,26 @@ def test_positive_int_or_minus_one(self):
standardizer_func(0)


class TestLoggerType(unittest.TestCase):
class TestLoggerDomain(unittest.TestCase):
"""
Test logger type validator.
Test logger type domain validator.
"""

def test_logger_type(self):
"""
Test logger type validator.
"""
standardizer_func = LoggerType()
standardizer_func = logger_domain
mylogger = logging.getLogger("example")
self.assertIs(
standardizer_func(mylogger),
mylogger,
msg=f"{LoggerType.__name__} output not as expected",
msg=f"{standardizer_func.__name__} output not as expected",
)
self.assertIs(
standardizer_func(mylogger.name),
mylogger,
msg=f"{LoggerType.__name__} output not as expected",
msg=f"{standardizer_func.__name__} output not as expected",
)

exc_str = r"A logger name must be a string"
Expand Down