Skip to content

Commit 8883d75

Browse files
authored
Merge pull request #3497 from jsiirola/capture_output_jupyter
Resolve errors in `TeeStream` and `capture_output`
2 parents 093efa2 + bc2bdbd commit 8883d75

File tree

2 files changed

+176
-45
lines changed

2 files changed

+176
-45
lines changed

pyomo/common/tee.py

+101-34
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,12 @@ class capture_output(object):
194194
195195
capture_fd : bool
196196
197-
If True, we will also redirect the low-level file descriptors
198-
associated with stdout (1) and stderr (2) to the ``output``.
199-
This is useful for capturing output emitted directly to the
200-
process stdout / stderr by external compiled modules.
197+
If True, we will also redirect the process file descriptors
198+
``1`` (stdout), ``2`` (stderr), and the file descriptors from
199+
``sys.stdout.fileno()`` and ``sys.stderr.fileno()`` to the
200+
``output``. This is useful for capturing output emitted
201+
directly to the process stdout / stderr by external compiled
202+
modules.
201203
202204
Returns
203205
-------
@@ -231,19 +233,59 @@ def __enter__(self):
231233
sys.stdout = self.tee.STDOUT
232234
sys.stderr = self.tee.STDERR
233235
if self.capture_fd:
234-
self.fd_redirect = (
235-
redirect_fd(1, self.tee.STDOUT.fileno(), synchronize=False),
236-
redirect_fd(2, self.tee.STDERR.fileno(), synchronize=False),
236+
tee_fd = (self.tee.STDOUT.fileno(), self.tee.STDERR.fileno())
237+
self.fd_redirect = []
238+
for i in range(2):
239+
# Redirect the standard process file descriptor (1 or 2)
240+
self.fd_redirect.append(
241+
redirect_fd(i + 1, tee_fd[i], synchronize=False)
242+
)
243+
# Redirect the file descriptor currently associated with
244+
# sys.stdout / sys.stderr
245+
try:
246+
fd = self.old[i].fileno()
247+
except (AttributeError, OSError):
248+
pass
249+
else:
250+
if fd != i + 1:
251+
self.fd_redirect.append(
252+
redirect_fd(fd, tee_fd[i], synchronize=False)
253+
)
254+
for fdr in self.fd_redirect:
255+
fdr.__enter__()
256+
# We have an issue where we are (very aggressively)
257+
# commandeering the terminal. This is what we intend, but the
258+
# side effect is that any errors generated by this module (e.g.,
259+
# because the user gave us an invalid output stream) get
260+
# completely suppressed. So, we will make an exception to the
261+
# output that we are catching and let messages logged to THIS
262+
# logger to still be emitted.
263+
if self.capture_fd:
264+
# Because we are also comandeering the FD that underlies
265+
# self.old[1], we cannot just write to that stream and
266+
# instead open a new stream to the original FD.
267+
#
268+
# Note that we need to duplicate the FD from the redirector,
269+
# as it will close the (temporary) `original_fd` descriptor
270+
# when it restores the actual original descriptor
271+
self.temp_log_stream = os.fdopen(
272+
os.dup(self.fd_redirect[-1].original_fd), mode="w", closefd=True
237273
)
238-
self.fd_redirect[0].__enter__()
239-
self.fd_redirect[1].__enter__()
274+
else:
275+
self.temp_log_stream = self.old[1]
276+
self.temp_log_handler = logging.StreamHandler(self.temp_log_stream)
277+
logger.addHandler(self.temp_log_handler)
278+
self._propagate = logger.propagate
279+
logger.propagate = False
240280
return self.output_stream
241281

242282
def __exit__(self, et, ev, tb):
283+
# Restore any file descriptors we commandeered
243284
if self.fd_redirect is not None:
244-
self.fd_redirect[1].__exit__(et, ev, tb)
245-
self.fd_redirect[0].__exit__(et, ev, tb)
285+
for fdr in reversed(self.fd_redirect):
286+
fdr.__exit__(et, ev, tb)
246287
self.fd_redirect = None
288+
# Check and restore sys.stderr / sys.stdout
247289
FAIL = self.tee.STDOUT is not sys.stdout
248290
self.tee.__exit__(et, ev, tb)
249291
if self.output_stream is not self.output:
@@ -252,6 +294,15 @@ def __exit__(self, et, ev, tb):
252294
self.old = None
253295
self.tee = None
254296
self.output_stream = None
297+
# Clean up our temporary override of the local logger
298+
self.temp_log_handler.flush()
299+
logger.removeHandler(self.temp_log_handler)
300+
if self.capture_fd:
301+
self.temp_log_stream.flush()
302+
self.temp_log_stream.close()
303+
logger.propagate = self._propagate
304+
self.temp_log_stream = None
305+
self.temp_log_handler = None
255306
if FAIL:
256307
raise RuntimeError('Captured output does not match sys.stdout.')
257308

@@ -378,47 +429,66 @@ def writeOutputBuffer(self, ostreams, flush):
378429
if not ostring:
379430
return
380431

381-
for local_stream, user_stream in ostreams:
432+
for stream in ostreams:
382433
try:
383-
written = local_stream.write(ostring)
434+
written = stream.write(ostring)
384435
except:
385-
written = 0
436+
my_repr = "<%s.%s @ %s>" % (
437+
stream.__class__.__module__,
438+
stream.__class__.__name__,
439+
hex(id(stream)),
440+
)
441+
if my_repr in ostring:
442+
# In the case of nested capture_outputs, all the
443+
# handlers are left on the logger. We want to make
444+
# sure that we don't create an infinite loop by
445+
# re-printing a message *this* object generated.
446+
continue
447+
et, e, tb = sys.exc_info()
448+
msg = "Error writing to output stream %s:\n %s: %s\n" % (
449+
my_repr,
450+
et.__name__,
451+
e,
452+
)
453+
if getattr(stream, 'closed', False):
454+
msg += "Output stream closed before all output was written to it."
455+
else:
456+
msg += "Is this a writeable TextIOBase object?"
457+
logger.error(
458+
f"{msg}\nThe following was left in the output buffer:\n"
459+
f" {ostring!r}"
460+
)
461+
continue
386462
if flush or (written and not self.buffering):
387-
local_stream.flush()
388-
if local_stream is not user_stream:
389-
user_stream.flush()
463+
stream.flush()
390464
# Note: some derived file-like objects fail to return the
391465
# number of characters written (and implicitly return None).
392466
# If we get None, we will just assume that everything was
393467
# fine (as opposed to tossing the incomplete write error).
394468
if written is not None and written != len(ostring):
469+
my_repr = "<%s.%s @ %s>" % (
470+
stream.__class__.__module__,
471+
stream.__class__.__name__,
472+
hex(id(stream)),
473+
)
474+
if my_repr in ostring[written:]:
475+
continue
395476
logger.error(
396-
"Output stream (%s) closed before all output was "
397-
"written to it. The following was left in "
398-
"the output buffer:\n\t%r" % (local_stream, ostring[written:])
477+
"Incomplete write to output stream %s.\nThe following was "
478+
"left in the output buffer:\n %r" % (my_repr, ostring[written:])
399479
)
400480

401481

402482
class TeeStream(object):
403483
def __init__(self, *ostreams, encoding=None, buffering=-1):
404-
self.ostreams = []
484+
self.ostreams = ostreams
405485
self.encoding = encoding
406486
self.buffering = buffering
407487
self._stdout = None
408488
self._stderr = None
409489
self._handles = []
410490
self._active_handles = []
411491
self._threads = []
412-
for user_stream in ostreams:
413-
try:
414-
fileno = user_stream.fileno()
415-
except:
416-
self.ostreams.append((user_stream, user_stream))
417-
continue
418-
local_stream = os.fdopen(
419-
os.dup(fileno), mode=getattr(user_stream, 'mode', None), closefd=True
420-
)
421-
self.ostreams.append((local_stream, user_stream))
422492

423493
@property
424494
def STDOUT(self):
@@ -499,9 +569,6 @@ def close(self, in_exception=False):
499569
self._active_handles.clear()
500570
self._stdout = None
501571
self._stderr = None
502-
for local, orig in self.ostreams:
503-
if orig is not local:
504-
local.close()
505572

506573
def __enter__(self):
507574
return self

pyomo/common/tests/test_tee.py

+75-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import gc
1414
import itertools
15+
import logging
1516
import os
1617
import platform
1718
import time
@@ -35,7 +36,10 @@ def __init__(self):
3536

3637
def write(self, data):
3738
for line in data.splitlines():
38-
self.buf.append((time.time(), float(line.strip())))
39+
line = line.strip()
40+
if not line:
41+
continue
42+
self.buf.append((time.time(), float(line)))
3943

4044
def writelines(self, data):
4145
for line in data:
@@ -216,10 +220,43 @@ def test_decoder_and_buffer_errors(self):
216220
with tee.TeeStream(out) as t:
217221
out.close()
218222
t.STDOUT.write("hi\n")
223+
_id = hex(id(out))
219224
self.assertRegex(
220225
log.getvalue(),
221-
r"^Output stream \(<.*?>\) closed before all output was written "
222-
r"to it. The following was left in the output buffer:\n\t'hi\\n'\n$",
226+
f"Error writing to output stream <_?io.StringIO @ {_id}>:"
227+
r"\n.*\nOutput stream closed before all output was written to it.\n"
228+
r"The following was left in the output buffer:\n 'hi\\n'\n$",
229+
)
230+
231+
# TeeStream expects stream-like objects
232+
out = logging.getLogger()
233+
log = StringIO()
234+
with LoggingIntercept(log):
235+
with tee.TeeStream(out) as t:
236+
t.STDOUT.write("hi\n")
237+
_id = hex(id(out))
238+
self.assertRegex(
239+
log.getvalue(),
240+
f"Error writing to output stream <logging.RootLogger @ {_id}>:"
241+
r"\n.*\nIs this a writeable TextIOBase object\?\n"
242+
r"The following was left in the output buffer:\n 'hi\\n'\n$",
243+
)
244+
245+
# Catch partial writes
246+
class fake_stream:
247+
def write(self, data):
248+
return 1
249+
250+
out = fake_stream()
251+
log = StringIO()
252+
with LoggingIntercept(log):
253+
with tee.TeeStream(out) as t:
254+
t.STDOUT.write("hi\n")
255+
_id = hex(id(out))
256+
self.assertRegex(
257+
log.getvalue(),
258+
f"Incomplete write to output stream <.*fake_stream @ {_id}>."
259+
r"\nThe following was left in the output buffer:\n 'i\\n'\n$",
223260
)
224261

225262
def test_capture_output(self):
@@ -267,6 +304,32 @@ def test_capture_output_stack_error(self):
267304
finally:
268305
sys.stdout, sys.stderr = old
269306

307+
def test_capture_output_invalid_ostream(self):
308+
# Test that capture_output does not suppress errors from the tee
309+
# module
310+
_id = hex(id(15))
311+
with tee.capture_output(capture_fd=True) as OUT:
312+
with tee.capture_output(15):
313+
sys.stderr.write("hi\n")
314+
self.assertEqual(
315+
OUT.getvalue(),
316+
f"Error writing to output stream <builtins.int @ {_id}>:\n"
317+
" AttributeError: 'int' object has no attribute 'write'\n"
318+
"Is this a writeable TextIOBase object?\n"
319+
"The following was left in the output buffer:\n 'hi\\n'\n",
320+
)
321+
322+
with tee.capture_output(capture_fd=True) as OUT:
323+
with tee.capture_output(15, capture_fd=True):
324+
print("hi")
325+
self.assertEqual(
326+
OUT.getvalue(),
327+
f"Error writing to output stream <builtins.int @ {_id}>:\n"
328+
" AttributeError: 'int' object has no attribute 'write'\n"
329+
"Is this a writeable TextIOBase object?\n"
330+
"The following was left in the output buffer:\n 'hi\\n'\n",
331+
)
332+
270333
def test_deadlock(self):
271334
class MockStream(object):
272335
def write(self, data):
@@ -323,14 +386,15 @@ def test_buffered_stdout(self):
323386
sys.stdout.write(f"{time.time()}\n")
324387
time.sleep(self.dt)
325388
ts.write(f"{time.time()}\n")
326-
baseline = [[(0, 0), (1, 0), (1, 0), (1, 1)]]
327-
if fd:
328-
# TODO: [JDS] If we are capturing the file descriptor, the
329-
# stdout channel is sometimes no longer buffered. I am not
330-
# exactly sure why (my guess is because the underlying pipe
331-
# is not buffered), but as it is generally not a problem to
332-
# not buffer, we will put off "fixing" it.
333-
baseline.append([(0, 0), (0, 0), (0, 0), (1, 1)])
389+
baseline = [
390+
[(0, 0), (1, 0), (1, 0), (1, 1)],
391+
# TODO: [JDS] The stdout channel appears to sometimes be no
392+
# longer buffered. I am not exactly sure why (my guess is
393+
# because the underlying pipe is not buffered), but as it is
394+
# generally not a problem to not buffer, we will put off
395+
# "fixing" it.
396+
[(0, 0), (0, 0), (0, 0), (1, 1)],
397+
]
334398
if not ts.check(*baseline):
335399
self.fail(ts.error)
336400

0 commit comments

Comments
 (0)