Skip to content

Commit 0b94b39

Browse files
authored
Merge pull request #3169 from shermanjasonaf/update-kwarg-parsing-validation
Simplify Custom PyROS Domain Validators
2 parents e30aea7 + 20a6360 commit 0b94b39

File tree

2 files changed

+46
-70
lines changed

2 files changed

+46
-70
lines changed

pyomo/contrib/pyros/config.py

+20-57
Original file line numberDiff line numberDiff line change
@@ -26,71 +26,34 @@
2626
default_pyros_solver_logger = setup_pyros_logger()
2727

2828

29-
class LoggerType:
29+
def logger_domain(obj):
3030
"""
31-
Domain validator for objects castable to logging.Logger.
32-
"""
33-
34-
def __call__(self, obj):
35-
"""
36-
Cast object to logger.
31+
Domain validator for logger-type arguments.
3732
38-
Parameters
39-
----------
40-
obj : object
41-
Object to be cast.
33+
This admits any object of type ``logging.Logger``,
34+
or which can be cast to ``logging.Logger``.
35+
"""
36+
if isinstance(obj, logging.Logger):
37+
return obj
38+
else:
39+
return logging.getLogger(obj)
4240

43-
Returns
44-
-------
45-
logging.Logger
46-
If `str_or_logger` is of type `logging.Logger`,then
47-
`str_or_logger` is returned.
48-
Otherwise, ``logging.getLogger(str_or_logger)``
49-
is returned.
50-
"""
51-
if isinstance(obj, logging.Logger):
52-
return obj
53-
else:
54-
return logging.getLogger(obj)
5541

56-
def domain_name(self):
57-
"""Return str briefly describing domain encompassed by self."""
58-
return "None, str or logging.Logger"
42+
logger_domain.domain_name = "None, str or logging.Logger"
5943

6044

61-
class PositiveIntOrMinusOne:
45+
def positive_int_or_minus_one(obj):
6246
"""
63-
Domain validator for objects castable to a
64-
strictly positive int or -1.
47+
Domain validator for objects castable to a strictly
48+
positive int or -1.
6549
"""
50+
ans = int(obj)
51+
if ans != float(obj) or (ans <= 0 and ans != -1):
52+
raise ValueError(f"Expected positive int or -1, but received value {obj!r}")
53+
return ans
6654

67-
def __call__(self, obj):
68-
"""
69-
Cast object to positive int or -1.
7055

71-
Parameters
72-
----------
73-
obj : object
74-
Object of interest.
75-
76-
Returns
77-
-------
78-
int
79-
Positive int, or -1.
80-
81-
Raises
82-
------
83-
ValueError
84-
If object not castable to positive int, or -1.
85-
"""
86-
ans = int(obj)
87-
if ans != float(obj) or (ans <= 0 and ans != -1):
88-
raise ValueError(f"Expected positive int or -1, but received value {obj!r}")
89-
return ans
90-
91-
def domain_name(self):
92-
"""Return str briefly describing domain encompassed by self."""
93-
return "positive int or -1"
56+
positive_int_or_minus_one.domain_name = "positive int or -1"
9457

9558

9659
def mutable_param_validator(param_obj):
@@ -721,7 +684,7 @@ def pyros_config():
721684
"max_iter",
722685
ConfigValue(
723686
default=-1,
724-
domain=PositiveIntOrMinusOne(),
687+
domain=positive_int_or_minus_one,
725688
description=(
726689
"""
727690
Iteration limit. If -1 is provided, then no iteration
@@ -766,7 +729,7 @@ def pyros_config():
766729
"progress_logger",
767730
ConfigValue(
768731
default=default_pyros_solver_logger,
769-
domain=LoggerType(),
732+
domain=logger_domain,
770733
doc=(
771734
"""
772735
Logger (or name thereof) used for reporting PyROS solver

pyomo/contrib/pyros/tests/test_config.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from pyomo.contrib.pyros.config import (
1313
InputDataStandardizer,
1414
mutable_param_validator,
15-
LoggerType,
15+
logger_domain,
1616
SolverNotResolvable,
17-
PositiveIntOrMinusOne,
17+
positive_int_or_minus_one,
1818
pyros_config,
1919
SolverIterable,
2020
SolverResolvable,
@@ -557,16 +557,29 @@ def test_positive_int_or_minus_one(self):
557557
"""
558558
Test positive int or -1 validator works as expected.
559559
"""
560-
standardizer_func = PositiveIntOrMinusOne()
561-
self.assertIs(
562-
standardizer_func(1.0),
560+
standardizer_func = positive_int_or_minus_one
561+
ans = standardizer_func(1.0)
562+
self.assertEqual(
563+
ans,
563564
1,
564-
msg=(f"{PositiveIntOrMinusOne.__name__} does not standardize as expected."),
565+
msg=f"{positive_int_or_minus_one.__name__} output value not as expected.",
566+
)
567+
self.assertIs(
568+
type(ans),
569+
int,
570+
msg=f"{positive_int_or_minus_one.__name__} output type not as expected.",
565571
)
572+
573+
ans = standardizer_func(-1.0)
566574
self.assertEqual(
567-
standardizer_func(-1.00),
575+
ans,
568576
-1,
569-
msg=(f"{PositiveIntOrMinusOne.__name__} does not standardize as expected."),
577+
msg=f"{positive_int_or_minus_one.__name__} output value not as expected.",
578+
)
579+
self.assertIs(
580+
type(ans),
581+
int,
582+
msg=f"{positive_int_or_minus_one.__name__} output type not as expected.",
570583
)
571584

572585
exc_str = r"Expected positive int or -1, but received value.*"
@@ -576,26 +589,26 @@ def test_positive_int_or_minus_one(self):
576589
standardizer_func(0)
577590

578591

579-
class TestLoggerType(unittest.TestCase):
592+
class TestLoggerDomain(unittest.TestCase):
580593
"""
581-
Test logger type validator.
594+
Test logger type domain validator.
582595
"""
583596

584597
def test_logger_type(self):
585598
"""
586599
Test logger type validator.
587600
"""
588-
standardizer_func = LoggerType()
601+
standardizer_func = logger_domain
589602
mylogger = logging.getLogger("example")
590603
self.assertIs(
591604
standardizer_func(mylogger),
592605
mylogger,
593-
msg=f"{LoggerType.__name__} output not as expected",
606+
msg=f"{standardizer_func.__name__} output not as expected",
594607
)
595608
self.assertIs(
596609
standardizer_func(mylogger.name),
597610
mylogger,
598-
msg=f"{LoggerType.__name__} output not as expected",
611+
msg=f"{standardizer_func.__name__} output not as expected",
599612
)
600613

601614
exc_str = r"A logger name must be a string"

0 commit comments

Comments
 (0)