Skip to content
Merged
63 changes: 62 additions & 1 deletion aeon/testing/utils/output_suppression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,68 @@

@contextmanager
def suppress_output(suppress_stdout=True, suppress_stderr=True):
"""Redirects stdout and/or stderr to devnull."""
"""
Context manager to suppress stdout and/or stderr output.
This function redirects standard output (stdout) and standard error (stderr)
to `devnull`, effectively silencing any print statements or error messages
within its context.
Parameters
----------
suppress_stdout : bool, optional, default=True
If True, redirects stdout to null, suppressing print statements.
suppress_stderr : bool, optional, default=True
If True, redirects stderr to null, suppressing error messages.
Examples
--------
Suppressing both stdout and stderr:
>>> with suppress_output():
... print("This will not be displayed")
... import sys
... print("Error messages will be hidden", file=sys.stderr)
Suppressing only stdout:
>>> import sys
>>> sys.stderr = sys.stdout # Needed so doctest can capture stderr
>>> with suppress_output(suppress_stdout=True, suppress_stderr=False):
... print("This will not be shown")
... print("Error messages will still be visible", file=sys.stderr)
Error messages will still be visible
Suppressing only stderr:
>>> with suppress_output(suppress_stdout=False, suppress_stderr=True):
... print("This will be shown")
... import sys
... print("Error messages will be hidden", file=sys.stderr)
This will be shown
Using as a function wrapper:
Suppressing both stdout and stderr:
>>> @suppress_output()
... def noisy_function():
... print("Noisy output")
... import sys
... print("Noisy error", file=sys.stderr)
>>> noisy_function()
Suppressing only stdout:
>>> import sys
>>> sys.stderr = sys.stdout # Needed so doctest can capture stderr
>>> @suppress_output(suppress_stderr=False)
... def noisy_function():
... print("Noisy output")
... print("Noisy error", file=sys.stderr)
>>> noisy_function()
Noisy error
"""
with open(devnull, "w") as null:
stdout = sys.stdout
stderr = sys.stderr
Expand Down
51 changes: 42 additions & 9 deletions aeon/testing/utils/tests/test_output_supression.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,55 @@
"""Test output suppression decorator."""

import io
import sys

from aeon.testing.utils.output_suppression import suppress_output


@suppress_output()
def test_suppress_output():
"""Test suppress_output method with True inputs."""
print( # noqa: T201
"Hello world! If this is visible suppress_output is not working!"
)
print( # noqa: T201
"Error! If this is visible suppress_output is not working!", file=sys.stderr
)

@suppress_output()
def inner_test():

print( # noqa: T201
"Hello world! If this is visible suppress_output is not working!"
)
print( # noqa: T201
"Error! If this is visible suppress_output is not working!", file=sys.stderr
)

stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
sys.stdout = stdout_capture
sys.stderr = stderr_capture

inner_test()

assert stdout_capture.getvalue() == "", "stdout was not suppressed!"
assert stderr_capture.getvalue() == "", "stderr was not suppressed!"


@suppress_output(suppress_stdout=False, suppress_stderr=False)
def test_suppress_output_false():
"""Test suppress_output method with False inputs."""
pass

@suppress_output(suppress_stdout=False, suppress_stderr=False)
def inner_test():
print("This should be visible.") # noqa: T201
print( # noqa: T201
"This error message should also be visible.", file=sys.stderr
)

stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
sys.stdout = stdout_capture
sys.stderr = stderr_capture

inner_test()

assert ( # noqa: T201
"This should be visible." in stdout_capture.getvalue()
), "stdout was incorrectly suppressed!"
assert ( # noqa: T201
"This error message should also be visible." in stderr_capture.getvalue()
), "stderr was incorrectly suppressed!"
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def find_source():
import inspect
import os

obj = inspect.unwrap(obj)

fn = inspect.getsourcefile(obj)
fn = os.path.relpath(fn, start=os.path.dirname(aeon.__file__))
source, lineno = inspect.getsourcelines(obj)
Expand Down