Skip to content

Commit 662a312

Browse files
committed
judge: implement instant aborts
The way this works is:

- The worker creates a tempdir, and sets `tempfile.tempdir` to this directory.
- The worker sends the tempdir path back to the parent process. The parent process is responsible for cleaning it up when the worker exits.

Abortions are then implemented by sending `SIGKILL` to the worker.

As a side benefit of this implementation, we also get to drop the hacky `CompiledExecutor` cache deletion.
1 parent 5eb5f59 commit 662a312

File tree

2 files changed

+43
-61
lines changed

2 files changed

+43
-61
lines changed

dmoj/graders/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ def grade(self, case):
2121
def _generate_binary(self):
2222
raise NotImplementedError
2323

24-
def abort_grading(self):
25-
self._abort_requested = True
26-
if self._current_proc:
27-
try:
28-
self._current_proc.kill()
29-
except OSError:
30-
pass
31-
3224
def _resolve_testcases(self, cfg, batch_no=0):
3325
cases = []
3426
for case_config in cfg:

dmoj/judge.py

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import logging
33
import multiprocessing
44
import os
5+
import shutil
56
import signal
67
import sys
8+
import tempfile
79
import threading
810
import traceback
911
from enum import Enum
@@ -41,14 +43,14 @@ class IPC(Enum):
4143
BATCH_END = 'BATCH-END'
4244
GRADING_BEGIN = 'GRADING-BEGIN'
4345
GRADING_END = 'GRADING-END'
44-
GRADING_ABORTED = 'GRADING-ABORTED'
4546
UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION'
46-
REQUEST_ABORT = 'REQUEST-ABORT'
47+
48+
49+
class JudgeWorkerAborted(Exception):
50+
pass
4751

4852

4953
# This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here.
50-
# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of
51-
# a `CompileError`.)
5254
IPC_TIMEOUT = 60 # seconds
5355

5456

@@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal
128130
)
129131
)
130132

131-
# FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send
132-
# an abort request down the pipe, possibly messing up the handshake.
133133
self.current_judge_worker = JudgeWorker(submission)
134134

135135
ipc_ready_signal = threading.Event()
@@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
147147
assert self.current_judge_worker is not None
148148

149149
try:
150+
worker_tempdir = None
151+
152+
def _ipc_hello(_report, tempdir: str):
153+
nonlocal worker_tempdir
154+
ipc_ready_signal.set()
155+
worker_tempdir = tempdir
156+
150157
ipc_handler_dispatch: Dict[IPC, Callable] = {
151-
IPC.HELLO: lambda _report: ipc_ready_signal.set(),
158+
IPC.HELLO: _ipc_hello,
152159
IPC.COMPILE_ERROR: self._ipc_compile_error,
153160
IPC.COMPILE_MESSAGE: self._ipc_compile_message,
154161
IPC.GRADING_BEGIN: self._ipc_grading_begin,
155162
IPC.GRADING_END: self._ipc_grading_end,
156-
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
157163
IPC.BATCH_BEGIN: self._ipc_batch_begin,
158164
IPC.BATCH_END: self._ipc_batch_end,
159165
IPC.RESULT: self._ipc_result,
@@ -176,12 +182,17 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
176182
% (self.current_submission.problem_id, self.current_submission.id)
177183
)
178184
)
185+
except JudgeWorkerAborted:
186+
self.packet_manager.submission_aborted_packet()
179187
except Exception: # noqa: E722, we want to catch everything
180188
self.log_internal_error()
181189
finally:
182190
self.current_judge_worker.wait_with_timeout()
183191
self.current_judge_worker = None
184192

193+
if worker_tempdir:
194+
shutil.rmtree(worker_tempdir)
195+
185196
# Might not have been set if an exception was encountered before HELLO message, so signal here to keep the
186197
# other side from waiting forever.
187198
ipc_ready_signal.set()
@@ -232,10 +243,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None:
232243
def _ipc_batch_end(self, _report, _batch_number: int) -> None:
233244
self.packet_manager.batch_end_packet()
234245

235-
def _ipc_grading_aborted(self, report) -> None:
236-
self.packet_manager.submission_aborted_packet()
237-
report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)'))
238-
239246
def _ipc_unhandled_exception(self, _report, message: str) -> None:
240247
logger.error('Unhandled exception in worker process')
241248
self.log_internal_error(message=message)
@@ -254,10 +261,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None:
254261
'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id
255262
)
256263
else:
257-
logger.info('Received abortion request for %d', worker.submission.id)
258-
# These calls are idempotent, so it doesn't matter if we raced and the worker has exited already.
259-
worker.request_abort_grading()
260-
worker.wait_with_timeout()
264+
logger.info('Received abortion request for %d, killing worker', worker.submission.id)
265+
# This call is idempotent, so it doesn't matter if we raced and the worker has exited already.
266+
worker.abort_grading__kill_worker()
261267

262268
def listen(self) -> None:
263269
"""
@@ -270,7 +276,8 @@ def murder(self) -> None:
270276
"""
271277
End any submission currently executing, and exit the judge.
272278
"""
273-
self.abort_grading()
279+
if self.current_judge_worker:
280+
self.current_judge_worker.abort_grading__kill_worker()
274281
self.updater_exit = True
275282
self.updater_signal.set()
276283
if self.packet_manager:
@@ -304,8 +311,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio
304311
class JudgeWorker:
305312
def __init__(self, submission: Submission) -> None:
306313
self.submission = submission
307-
self._abort_requested = False
308-
self._sent_sigkill_to_worker_process = False
314+
self._aborted = False
315+
self._timed_out = False
309316
# FIXME(tbrindus): marked Any pending grader cleanups.
310317
self.grader: Any = None
311318

@@ -331,8 +338,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]:
331338
self.worker_process.kill()
332339
raise
333340
except EOFError:
334-
if self._sent_sigkill_to_worker_process:
335-
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT)
341+
if self._aborted:
342+
raise JudgeWorkerAborted() from None
343+
344+
if self._timed_out:
345+
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None
346+
336347
raise
337348
except Exception:
338349
logger.error('Failed to read IPC message from worker!')
@@ -354,16 +365,14 @@ def wait_with_timeout(self) -> None:
354365
finally:
355366
if self.worker_process.is_alive():
356367
logger.error('Worker is still alive, sending SIGKILL!')
357-
self._sent_sigkill_to_worker_process = True
368+
self._timed_out = True
358369
self.worker_process.kill()
359370

360-
def request_abort_grading(self) -> None:
361-
assert self.worker_process_conn
362-
363-
try:
364-
self.worker_process_conn.send((IPC.REQUEST_ABORT, ()))
365-
except Exception:
366-
logger.exception('Failed to send abort request to worker, did it race?')
371+
def abort_grading__kill_worker(self) -> None:
372+
if self.worker_process and self.worker_process.is_alive():
373+
self._aborted = True
374+
self.worker_process.kill()
375+
self.worker_process.join(timeout=1)
367376

368377
def _worker_process_main(
369378
self,
@@ -384,15 +393,12 @@ def _ipc_recv_thread_main() -> None:
384393
while True:
385394
try:
386395
ipc_type, data = judge_process_conn.recv()
387-
except: # noqa: E722, whatever happened, we have to abort now.
396+
except: # noqa: E722, whatever happened, we have to exit now.
388397
logger.exception('Judge unexpectedly hung up!')
389-
self._do_abort()
390398
return
391399

392400
if ipc_type == IPC.BYE:
393401
return
394-
elif ipc_type == IPC.REQUEST_ABORT:
395-
self._do_abort()
396402
else:
397403
raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),))
398404

@@ -402,9 +408,12 @@ def _report_unhandled_exception() -> None:
402408
judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,)))
403409
judge_process_conn.send((IPC.BYE, ()))
404410

411+
tempdir = tempfile.mkdtemp('dmoj-judge-worker')
412+
tempfile.tempdir = tempdir
413+
405414
ipc_recv_thread = None
406415
try:
407-
judge_process_conn.send((IPC.HELLO, ()))
416+
judge_process_conn.send((IPC.HELLO, (tempdir,)))
408417

409418
ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True)
410419
ipc_recv_thread.start()
@@ -439,15 +448,6 @@ def _report_unhandled_exception() -> None:
439448
if ipc_recv_thread.is_alive():
440449
logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!')
441450

442-
# FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which
443-
# won't get called if we exit the process right now (so we'd leak all files created by the grader). This
444-
# should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting
445-
# working out.
446-
from dmoj.executors.compiled_executor import _CompiledExecutorMeta
447-
448-
for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values():
449-
cached_executor.is_cached = False
450-
cached_executor.cleanup()
451451
self.grader = None
452452

453453
def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
@@ -503,11 +503,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
503503
else:
504504
result = self.grader.grade(case)
505505

506-
# If the submission was killed due to a user-initiated abort, any result is meaningless.
507-
if self._abort_requested:
508-
yield IPC.GRADING_ABORTED, ()
509-
return
510-
511506
if result.result_flag & Result.WA:
512507
# If we failed a 0-point case, we will short-circuit every case after this.
513508
is_short_circuiting_enabled |= not case.points
@@ -532,11 +527,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
532527

533528
yield IPC.GRADING_END, ()
534529

535-
def _do_abort(self) -> None:
536-
self._abort_requested = True
537-
if self.grader:
538-
self.grader.abort_grading()
539-
540530

541531
class ClassicJudge(Judge):
542532
def __init__(self, host, port, **kwargs) -> None:

0 commit comments

Comments
 (0)