Skip to content

Commit 4acdb69

Browse files
authored
Merge pull request #3482 from mrmundt/ipopt-bugfix
BUGFIX: Validator for `tee` in `contrib.solver`
2 parents 2367e85 + dec5632 commit 4acdb69

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

pyomo/contrib/solver/config.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,45 @@
3131

3232

3333
def TextIO_or_Logger(val):
34-
ans = []
35-
if not isinstance(val, Sequence):
34+
"""
35+
Validates and converts input into a list of valid output streams.
36+
37+
Accepts:
38+
- sys.stdout
39+
- Instances of io.TextIOBase
40+
- logging.Logger (wrapped as LogStream)
41+
- Boolean values (`True` -> sys.stdout)
42+
43+
Returns:
44+
- A list of validated output streams.
45+
46+
Raises:
47+
- ValueError if an invalid type is provided.
48+
"""
49+
if isinstance(val, Sequence) and not isinstance(val, (str, bytes)):
50+
val = list(val)
51+
52+
else:
3653
val = [val]
54+
55+
ans = []
56+
3757
for v in val:
3858
if v.__class__ in native_logical_types:
3959
if v:
4060
ans.append(sys.stdout)
41-
elif isinstance(v, io.TextIOBase):
61+
elif isinstance(v, (sys.stdout.__class__, io.TextIOBase)):
62+
# We are guarding against file-like classes that do not derive from
63+
# TextIOBase but are assigned to stdout / stderr.
64+
# We still want to accept those classes.
4265
ans.append(v)
4366
elif isinstance(v, logging.Logger):
4467
ans.append(LogStream(level=logging.INFO, logger=v))
4568
else:
4669
raise ValueError(
47-
f"Expected bool, TextIOBase, or Logger, but received {v.__class__}"
70+
f"Expected sys.stdout, io.TextIOBase, Logger, or bool, but received {v.__class__}"
4871
)
72+
4973
return ans
5074

5175

pyomo/contrib/solver/tests/unit/test_config.py

+38
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,52 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
import logging
13+
import io
14+
15+
import pyomo.environ as pyo
1216
from pyomo.common import unittest
17+
from pyomo.common.log import LogStream
18+
from pyomo.common.tee import capture_output
19+
from pyomo.common.dependencies import attempt_import
1320
from pyomo.contrib.solver.config import (
1421
SolverConfig,
1522
BranchAndBoundConfig,
1623
AutoUpdateConfig,
1724
PersistentSolverConfig,
25+
TextIO_or_Logger,
1826
)
1927

28+
ipopt, ipopt_available = attempt_import('ipopt')
29+
30+
31+
class TestTextIO_or_LoggerValidator(unittest.TestCase):
32+
def test_booleans(self):
33+
ans = TextIO_or_Logger(True)
34+
self.assertTrue(isinstance(ans[0], io._io.TextIOWrapper))
35+
ans = TextIO_or_Logger(False)
36+
self.assertEqual(ans, [])
37+
38+
def test_logger(self):
39+
logger = logging.getLogger('contrib.solver.config.test.1')
40+
ans = TextIO_or_Logger(logger)
41+
self.assertTrue(isinstance(ans[0], LogStream))
42+
43+
@unittest.skipIf(not ipopt_available, 'ipopt is not available')
44+
def test_real_example(self):
45+
46+
m = pyo.ConcreteModel()
47+
m.x = pyo.Var([1, 2], initialize=1, bounds=(0, None))
48+
m.eq = pyo.Constraint(expr=m.x[1] * m.x[2] ** 1.5 == 3)
49+
m.obj = pyo.Objective(expr=m.x[1] ** 2 + m.x[2] ** 2)
50+
51+
solver = pyo.SolverFactory("ipopt_v2")
52+
with capture_output() as OUT:
53+
solver.solve(m, tee=True, timelimit=5)
54+
55+
contents = OUT.getvalue()
56+
self.assertIn('EXIT: Optimal Solution Found.', contents)
57+
2058

2159
class TestSolverConfig(unittest.TestCase):
2260
def test_interface_default_instantiation(self):

0 commit comments

Comments
 (0)