Skip to content

Commit 7dcf589

Browse files
authored
Merge pull request #417 from lucasssvaz/fix/multi_print
fix: Add lock for simultaneous printing
2 parents 62edc22 + 5d225e6 commit 7dcf589

3 files changed

Lines changed: 238 additions & 7 deletions

File tree

pytest-embedded/pytest_embedded/dut_factory.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import datetime
23
import gc
34
import io
@@ -42,8 +43,19 @@ def _drop_none_kwargs(kwargs: dict[t.Any, t.Any]):
4243
PARAMETRIZED_FIXTURES_CACHE = {}
4344

4445

45-
def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1) -> None:
46+
_STDOUT_LOCK = None
47+
48+
49+
def set_stdout_lock(lock) -> None:
50+
global _STDOUT_LOCK
51+
_STDOUT_LOCK = lock
52+
53+
54+
def _listen(
55+
q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1, _stdout_lock=None
56+
) -> None:
4657
shall_add_prefix = True
58+
_pending = ''
4759
while True:
4860
msg = q.get()
4961
if not msg:
@@ -71,20 +83,25 @@ def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count:
7183
if _s.endswith('\n'): # complete line
7284
shall_add_prefix = True
7385
_s = _s[:-1].replace('\n', '\n' + prefix) + '\n'
86+
with _stdout_lock if _stdout_lock else contextlib.nullcontext():
87+
_stdout.write(_pending + _s)
88+
_stdout.flush()
89+
_pending = ''
7490
else:
7591
shall_add_prefix = False
7692
_s = _s.replace('\n', '\n' + prefix)
77-
78-
_stdout.write(_s)
79-
_stdout.flush()
93+
_pending += _s
8094

8195

82-
def _listener_gn(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process:
96+
def _listener_gn(
97+
msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock=None
98+
) -> multiprocessing.Process:
8399
os.makedirs(os.path.dirname(_pexpect_logfile), exist_ok=True)
84100
kwargs = {
85101
'with_timestamp': with_timestamp,
86102
'count': dut_index,
87103
'total': dut_total,
104+
'_stdout_lock': _stdout_lock,
88105
}
89106

90107
return _ctx.Process(
@@ -753,7 +770,9 @@ def create(
753770
)
754771
logging.debug('You can get your custom DUT log file at the following path: %s.', _pexpect_logfile)
755772

756-
_listener = _listener_gn(msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1)
773+
_listener = _listener_gn(
774+
msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1, _stdout_lock=_STDOUT_LOCK
775+
)
757776
layout.append(_listener)
758777

759778
_pexpect_fr = _pexpect_fr_gn(_pexpect_logfile, _listener)

pytest-embedded/pytest_embedded/plugin.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .dut import Dut
3131
from .dut_factory import (
3232
DutFactory,
33+
_ctx,
3334
_fixture_classes_and_options_fn,
3435
_listener_gn,
3536
_pexpect_fr_gn,
@@ -41,6 +42,7 @@
4142
qemu_gn,
4243
serial_gn,
4344
set_parametrized_fixtures_cache,
45+
set_stdout_lock,
4446
wokwi_gn,
4547
)
4648
from .log import MessageQueue, MessageQueueManager, PexpectProcess
@@ -695,6 +697,23 @@ def _mp_manager():
695697
manager.shutdown()
696698

697699

700+
@pytest.fixture(scope='session', autouse=True)
701+
def _stdout_lock():
702+
"""
703+
A session-scoped multiprocessing lock used to serialize stdout writes across
704+
all DUT listener processes, preventing garbled output when multiple DUTs
705+
print to stdout simultaneously.
706+
707+
It is marked ``autouse=True`` so that the lock is created and registered
708+
globally (via ``set_stdout_lock``) before any DUT fixture is instantiated,
709+
ensuring every listener process receives a valid lock reference regardless
710+
of test ordering.
711+
"""
712+
lock = _ctx.Lock()
713+
set_stdout_lock(lock)
714+
yield lock
715+
716+
698717
@pytest.fixture
699718
def test_case_tempdir(test_case_name: str, session_tempdir: str) -> str:
700719
"""Function scoped temp dir for pytest-embedded"""
@@ -746,13 +765,17 @@ def with_timestamp(request: FixtureRequest) -> bool:
746765

747766
@pytest.fixture
748767
@multi_dut_generator_fixture
749-
def _listener(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process:
768+
def _listener(
769+
msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock
770+
) -> multiprocessing.Process:
750771
"""
751772
The listener would create a `_listen` process. The `_listen` process would get the string from the message queue,
752773
and do two things together:
753774
754775
1. print the string to `sys.stdout`
755776
2. write the string to `_pexpect_logfile`
777+
778+
A shared lock (_stdout_lock) is used to prevent interleaved output when multiple DUTs print simultaneously.
756779
"""
757780
return _listener_gn(**locals())
758781

pytest-embedded/tests/test_base.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,3 +821,192 @@ def test_metric_no_path(log_metric):
821821

822822
result = pytester.runpytest()
823823
result.assert_outcomes(passed=1)
824+
825+
826+
# ---------------------------------------------------------------------------
827+
# Tests for the stdout-lock feature (_stdout_lock / set_stdout_lock / _listen)
828+
# ---------------------------------------------------------------------------
829+
830+
831+
def test_set_stdout_lock():
832+
"""set_stdout_lock updates the module-level _STDOUT_LOCK variable."""
833+
import pytest_embedded.dut_factory as m
834+
from pytest_embedded.dut_factory import set_stdout_lock
835+
836+
original = m._STDOUT_LOCK
837+
try:
838+
sentinel = object()
839+
set_stdout_lock(sentinel)
840+
assert m._STDOUT_LOCK is sentinel
841+
842+
set_stdout_lock(None)
843+
assert m._STDOUT_LOCK is None
844+
finally:
845+
set_stdout_lock(original)
846+
847+
848+
def test_listen_no_data_loss_without_lock(tmp_path):
849+
"""_listen writes every queued message to the logfile when no lock is used."""
850+
import time
851+
852+
from pytest_embedded.dut_factory import _ctx, _listen
853+
from pytest_embedded.log import MessageQueue
854+
855+
logfile = str(tmp_path / 'test.log')
856+
q = MessageQueue()
857+
messages = [f'line_{i}\n'.encode() for i in range(20)]
858+
859+
p = _ctx.Process(target=_listen, args=(q, logfile), kwargs={'with_timestamp': False})
860+
p.start()
861+
try:
862+
for msg in messages:
863+
q.put(msg)
864+
865+
deadline = time.monotonic() + 10
866+
while time.monotonic() < deadline:
867+
try:
868+
content = open(logfile, 'rb').read()
869+
if all(msg in content for msg in messages):
870+
break
871+
except OSError:
872+
pass
873+
time.sleep(0.05)
874+
finally:
875+
p.terminate()
876+
p.join(timeout=5)
877+
assert p.exitcode is not None, 'listener process did not terminate'
878+
879+
content = open(logfile, 'rb').read()
880+
for msg in messages:
881+
assert msg in content, f'{msg!r} missing from logfile'
882+
883+
884+
def test_listen_no_data_loss_with_lock(tmp_path):
885+
"""_listen writes every queued message to the logfile when a Manager lock is used."""
886+
import multiprocessing
887+
import time
888+
889+
from pytest_embedded.dut_factory import _ctx, _listen
890+
from pytest_embedded.log import MessageQueue
891+
892+
logfile = str(tmp_path / 'test.log')
893+
q = MessageQueue()
894+
messages = [f'line_{i}\n'.encode() for i in range(20)]
895+
896+
manager = multiprocessing.Manager()
897+
try:
898+
lock = manager.Lock()
899+
p = _ctx.Process(
900+
target=_listen,
901+
args=(q, logfile),
902+
kwargs={'with_timestamp': False, '_stdout_lock': lock},
903+
)
904+
p.start()
905+
try:
906+
for msg in messages:
907+
q.put(msg)
908+
909+
deadline = time.monotonic() + 10
910+
while time.monotonic() < deadline:
911+
try:
912+
content = open(logfile, 'rb').read()
913+
if all(msg in content for msg in messages):
914+
break
915+
except OSError:
916+
pass
917+
time.sleep(0.05)
918+
finally:
919+
p.terminate()
920+
p.join(timeout=5)
921+
assert p.exitcode is not None, 'listener process did not terminate'
922+
finally:
923+
manager.shutdown()
924+
925+
content = open(logfile, 'rb').read()
926+
for msg in messages:
927+
assert msg in content, f'{msg!r} missing from logfile'
928+
929+
930+
def test_stdout_lock_concurrent_no_data_loss(tmp_path):
931+
"""Two concurrent _listen processes sharing a Manager lock both preserve all data."""
932+
import multiprocessing
933+
import time
934+
935+
from pytest_embedded.dut_factory import _ctx, _listen
936+
from pytest_embedded.log import MessageQueue
937+
938+
logfile0 = str(tmp_path / 'dut0.log')
939+
logfile1 = str(tmp_path / 'dut1.log')
940+
q0 = MessageQueue()
941+
q1 = MessageQueue()
942+
messages0 = [f'dut0_line_{i}\n'.encode() for i in range(20)]
943+
messages1 = [f'dut1_line_{i}\n'.encode() for i in range(20)]
944+
945+
manager = multiprocessing.Manager()
946+
try:
947+
lock = manager.Lock()
948+
p0 = _ctx.Process(
949+
target=_listen,
950+
args=(q0, logfile0),
951+
kwargs={'with_timestamp': False, 'count': 1, 'total': 2, '_stdout_lock': lock},
952+
)
953+
p1 = _ctx.Process(
954+
target=_listen,
955+
args=(q1, logfile1),
956+
kwargs={'with_timestamp': False, 'count': 2, 'total': 2, '_stdout_lock': lock},
957+
)
958+
p0.start()
959+
p1.start()
960+
try:
961+
# interleave writes from both DUTs to maximize lock contention
962+
for msg0, msg1 in zip(messages0, messages1):
963+
q0.put(msg0)
964+
q1.put(msg1)
965+
966+
deadline = time.monotonic() + 15
967+
while time.monotonic() < deadline:
968+
try:
969+
c0 = open(logfile0, 'rb').read()
970+
c1 = open(logfile1, 'rb').read()
971+
if all(m in c0 for m in messages0) and all(m in c1 for m in messages1):
972+
break
973+
except OSError:
974+
pass
975+
time.sleep(0.05)
976+
finally:
977+
p0.terminate()
978+
p1.terminate()
979+
p0.join(timeout=5)
980+
p1.join(timeout=5)
981+
assert p0.exitcode is not None, 'dut0 listener process did not terminate'
982+
assert p1.exitcode is not None, 'dut1 listener process did not terminate'
983+
finally:
984+
manager.shutdown()
985+
986+
c0 = open(logfile0, 'rb').read()
987+
c1 = open(logfile1, 'rb').read()
988+
for msg in messages0:
989+
assert msg in c0, f'{msg!r} missing from dut0 logfile'
990+
for msg in messages1:
991+
assert msg in c1, f'{msg!r} missing from dut1 logfile'
992+
993+
994+
def test_multi_dut_no_data_loss(testdir):
995+
"""In a 2-DUT test, all messages written by each DUT can be expected - nothing is dropped."""
996+
testdir.makepyfile(r"""
997+
import pytest
998+
999+
@pytest.mark.parametrize('count', [2], indirect=True)
1000+
def test_concurrent_dut_writes(dut):
1001+
n = 15
1002+
for i in range(n):
1003+
dut[0].write(f'dut0_msg_{i}')
1004+
dut[1].write(f'dut1_msg_{i}')
1005+
1006+
for i in range(n):
1007+
dut[0].expect_exact(f'dut0_msg_{i}')
1008+
dut[1].expect_exact(f'dut1_msg_{i}')
1009+
""")
1010+
1011+
result = testdir.runpytest()
1012+
result.assert_outcomes(passed=1)

0 commit comments

Comments
 (0)