|
4 | 4 |
|
5 | 5 | """Test the DVSim scheduler.""" |
6 | 6 |
|
| 7 | +import multiprocessing |
| 8 | +import os |
| 9 | +import sys |
| 10 | +import threading |
7 | 11 | import time |
8 | 12 | from collections.abc import Callable, Sequence |
9 | 13 | from dataclasses import dataclass |
10 | 14 | from pathlib import Path |
| 15 | +from signal import SIGINT, SIGTERM, signal |
| 16 | +from types import FrameType |
11 | 17 | from typing import Any |
12 | 18 |
|
13 | 19 | import pytest |
|
52 | 58 |
|
# Default scheduler test timeout to handle infinite loops in the scheduler
DEFAULT_TIMEOUT = 0.5
# Timeout given to `Process.join` when waiting for the child process of the signal
# tests; if the child is still alive afterwards it is killed and the test fails.
SIGNAL_TEST_TIMEOUT = 2.5
55 | 62 |
|
56 | 63 |
|
57 | 64 | @dataclass |
@@ -926,3 +933,102 @@ def test_blocked_weight_starvation(fxt: Fxt) -> None: |
926 | 933 | # |
927 | 934 | # Note also that DVSim currently assumes weights within a target are constant, |
928 | 935 | # which may not be the case with the current JobSpec model. |
| 936 | + |
| 937 | + |
class TestSignals:
    """Integration tests for the signal-handling of the scheduler."""

    @staticmethod
    def _run_signal_test(tmp_path: Path, sig: int, *, repeat: bool, long_poll: bool) -> None:
        """Run a scheduler in this process and interrupt it with `sig`.

        This is the body of the signal test. It sends `sig` to the *current* process
        (via `os.kill(os.getpid(), sig)`) from a helper thread while the scheduler is
        running, so it is used as the target of a separate process by
        `test_signal_kill` rather than being called directly.

        Args:
            tmp_path: directory handed to `make_many_jobs` when creating the test jobs.
            sig: the signal number to send to this process.
            repeat: if true, the mock jobs' `kill()` is made to hang and the signal is
                sent twice; the process is then expected to exit through the mock
                signal handler instead of returning from `scheduler.run()`.
            long_poll: if true, the launcher is given a very long poll frequency so
                that the signal has to interrupt the scheduler while it is sleeping.

        """

        # We cannot access the fixtures from the separate process, so define a minimal
        # mock launcher class here.
        class SignalTestMockLauncher(MockLauncher):
            pass

        mock_ctx = MockLauncherContext()
        SignalTestMockLauncher.mock_context = mock_ctx
        SignalTestMockLauncher.max_parallel = 2
        if long_poll:
            # Set a very long poll frequency to be sure that the signal interrupts the
            # scheduler from a sleep if configured with infrequent polls.
            SignalTestMockLauncher.poll_freq = 360000

        jobs = make_many_jobs(tmp_path, 3, ensure_paths_exist=True)
        # When testing non-graceful exits, we make `kill()` hang and send two signals.
        kill_time = 100.0 if repeat else None
        # Job 0 is permanently "dispatched", it never completes.
        mock_ctx.set_config(
            jobs[0], MockJob(default_status=JobStatus.DISPATCHED, kill_time=kill_time)
        )
        # Job 1 will pass, but after a long time (a large number of polls).
        mock_ctx.set_config(
            jobs[1],
            MockJob(
                status_thresholds=[(0, JobStatus.DISPATCHED), (1000000000, JobStatus.PASSED)],
                kill_time=kill_time,
            ),
        )
        # Job 2 is also permanently "dispatched", but will never run due to the
        # max parallelism limit on the launcher. It will instead be cancelled.
        mock_ctx.set_config(
            jobs[2], MockJob(default_status=JobStatus.DISPATCHED, kill_time=kill_time)
        )
        scheduler = Scheduler(jobs, SignalTestMockLauncher)

        def _get_signal(sig_received: int, _: FrameType | None) -> None:
            # Only reachable in the `repeat` case: check we got the signal we sent
            # and then exit cleanly so that the parent sees an exit code of 0.
            assert_that(sig_received, equal_to(sig))
            assert_that(repeat)
            sys.exit(0)

        if repeat:
            # Sending multiple signals will call the regular signal handler
            # which will kill the process. Register a mock handler to stop
            # that happening and we can check that we "killed the process".
            signal(sig, _get_signal)

        def _send_signals() -> None:
            # Give time for the handler to be installed and jobs to dispatch
            # and for the main loop to enter a sleep/wait.
            wait_time = 0.1
            time.sleep(wait_time)
            pid = os.getpid()
            os.kill(pid, sig)
            if repeat:
                time.sleep(wait_time)
                os.kill(pid, sig)

        # Send signals from a separate thread, since this thread is about to be
        # occupied by `scheduler.run()`.
        threading.Thread(target=_send_signals).start()
        result = scheduler.run()

        # If we didn't reach `_get_signal`, this should be a graceful exit
        assert_that(not repeat)
        _assert_result_status(result, 3, expected=JobStatus.KILLED)

    @staticmethod
    @pytest.mark.xfail(
        reason="This test passes ~95 percent of the time, but the logging & threading primitive "
        "logic used in the signal handler are not async-signal-safe and thus may deadlock, "
        "causing the process to hang and time out instead.",
        strict=False,
    )
    @pytest.mark.parametrize("long_poll", [False, True])
    @pytest.mark.parametrize(("sig", "repeat"), [(SIGTERM, False), (SIGINT, False), (SIGINT, True)])
    def test_signal_kill(tmp_path: Path, *, sig: int, repeat: bool, long_poll: bool) -> None:
        """Test that the scheduler can be gracefully killed by incoming signals.

        The scenario itself lives in `_run_signal_test`; this test runs it in a child
        process and checks that the child exits with code 0 within
        `SIGNAL_TEST_TIMEOUT`.
        """
        # We must test in a separate process, otherwise pytest interprets the SIGINT and SIGTERM
        # signals using its own signal handlers as signals to quit pytest itself...
        proc = multiprocessing.Process(
            target=TestSignals._run_signal_test,
            args=(tmp_path, sig),
            kwargs={"repeat": repeat, "long_poll": long_poll},
        )
        proc.start()
        proc.join(timeout=SIGNAL_TEST_TIMEOUT)
        if proc.is_alive():
            # The child did not finish in time: force it down and report a hang.
            proc.kill()  # SIGKILL instead of SIGINT or SIGTERM
            proc.join()
            pytest.fail("Scheduler hung and was terminated")
        assert_that(proc.exitcode, equal_to(0))
0 commit comments