Skip to content

Commit 041136c

Browse files
committed
judge: implement instant aborts
The way this works is: - Worker creates a tempdir, and sets `tempfile.tempdir` to this directory. - Worker sends back the tempdir. The parent process is responsible for cleaning it up when the worker exits. Abortions are then implemented as sending `SIGKILL` to the worker. As a side benefit of this implementation, we also get to drop the hacky `CompiledExecutor` cache deletion.
1 parent 5eb5f59 commit 041136c

File tree

2 files changed

+48
-61
lines changed

2 files changed

+48
-61
lines changed

dmoj/graders/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ def grade(self, case):
2121
def _generate_binary(self):
2222
raise NotImplementedError
2323

24-
def abort_grading(self):
25-
self._abort_requested = True
26-
if self._current_proc:
27-
try:
28-
self._current_proc.kill()
29-
except OSError:
30-
pass
31-
3224
def _resolve_testcases(self, cfg, batch_no=0):
3325
cases = []
3426
for case_config in cfg:

dmoj/judge.py

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import logging
33
import multiprocessing
44
import os
5+
import shutil
56
import signal
67
import sys
8+
import tempfile
79
import threading
810
import traceback
911
from enum import Enum
@@ -41,14 +43,14 @@ class IPC(Enum):
4143
BATCH_END = 'BATCH-END'
4244
GRADING_BEGIN = 'GRADING-BEGIN'
4345
GRADING_END = 'GRADING-END'
44-
GRADING_ABORTED = 'GRADING-ABORTED'
4546
UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION'
46-
REQUEST_ABORT = 'REQUEST-ABORT'
47+
48+
49+
class JudgeWorkerAborted(Exception):
50+
pass
4751

4852

4953
# This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here.
50-
# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of
51-
# a `CompileError`.)
5254
IPC_TIMEOUT = 60 # seconds
5355

5456

@@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal
128130
)
129131
)
130132

131-
# FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send
132-
# an abort request down the pipe, possibly messing up the handshake.
133133
self.current_judge_worker = JudgeWorker(submission)
134134

135135
ipc_ready_signal = threading.Event()
@@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
147147
assert self.current_judge_worker is not None
148148

149149
try:
150+
worker_tempdir = None
151+
152+
def _ipc_hello(_report, tempdir: str):
153+
nonlocal worker_tempdir
154+
ipc_ready_signal.set()
155+
worker_tempdir = tempdir
156+
150157
ipc_handler_dispatch: Dict[IPC, Callable] = {
151-
IPC.HELLO: lambda _report: ipc_ready_signal.set(),
158+
IPC.HELLO: _ipc_hello,
152159
IPC.COMPILE_ERROR: self._ipc_compile_error,
153160
IPC.COMPILE_MESSAGE: self._ipc_compile_message,
154161
IPC.GRADING_BEGIN: self._ipc_grading_begin,
155162
IPC.GRADING_END: self._ipc_grading_end,
156-
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
157163
IPC.BATCH_BEGIN: self._ipc_batch_begin,
158164
IPC.BATCH_END: self._ipc_batch_end,
159165
IPC.RESULT: self._ipc_result,
@@ -176,12 +182,22 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
176182
% (self.current_submission.problem_id, self.current_submission.id)
177183
)
178184
)
185+
except JudgeWorkerAborted:
186+
self.packet_manager.submission_aborted_packet()
179187
except Exception: # noqa: E722, we want to catch everything
180188
self.log_internal_error()
181189
finally:
182190
self.current_judge_worker.wait_with_timeout()
183191
self.current_judge_worker = None
184192

193+
print('cleaning up', worker_tempdir)
194+
os.system('ls -al %s' % worker_tempdir)
195+
if worker_tempdir:
196+
try:
197+
shutil.rmtree(worker_tempdir)
198+
except: # noqa: E722
199+
pass
200+
185201
# Might not have been set if an exception was encountered before HELLO message, so signal here to keep the
186202
# other side from waiting forever.
187203
ipc_ready_signal.set()
@@ -232,10 +248,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None:
232248
def _ipc_batch_end(self, _report, _batch_number: int) -> None:
233249
self.packet_manager.batch_end_packet()
234250

235-
def _ipc_grading_aborted(self, report) -> None:
236-
self.packet_manager.submission_aborted_packet()
237-
report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)'))
238-
239251
def _ipc_unhandled_exception(self, _report, message: str) -> None:
240252
logger.error('Unhandled exception in worker process')
241253
self.log_internal_error(message=message)
@@ -254,10 +266,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None:
254266
'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id
255267
)
256268
else:
257-
logger.info('Received abortion request for %d', worker.submission.id)
258-
# These calls are idempotent, so it doesn't matter if we raced and the worker has exited already.
259-
worker.request_abort_grading()
260-
worker.wait_with_timeout()
269+
logger.info('Received abortion request for %d, killing worker', worker.submission.id)
270+
# This call is idempotent, so it doesn't matter if we raced and the worker has exited already.
271+
worker.abort_grading__kill_worker()
261272

262273
def listen(self) -> None:
263274
"""
@@ -270,7 +281,8 @@ def murder(self) -> None:
270281
"""
271282
End any submission currently executing, and exit the judge.
272283
"""
273-
self.abort_grading()
284+
if self.current_judge_worker:
285+
self.current_judge_worker.abort_grading__kill_worker()
274286
self.updater_exit = True
275287
self.updater_signal.set()
276288
if self.packet_manager:
@@ -304,8 +316,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio
304316
class JudgeWorker:
305317
def __init__(self, submission: Submission) -> None:
306318
self.submission = submission
307-
self._abort_requested = False
308-
self._sent_sigkill_to_worker_process = False
319+
self._aborted = False
320+
self._timed_out = False
309321
# FIXME(tbrindus): marked Any pending grader cleanups.
310322
self.grader: Any = None
311323

@@ -331,8 +343,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]:
331343
self.worker_process.kill()
332344
raise
333345
except EOFError:
334-
if self._sent_sigkill_to_worker_process:
335-
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT)
346+
if self._aborted:
347+
raise JudgeWorkerAborted() from None
348+
349+
if self._timed_out:
350+
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None
351+
336352
raise
337353
except Exception:
338354
logger.error('Failed to read IPC message from worker!')
@@ -354,16 +370,14 @@ def wait_with_timeout(self) -> None:
354370
finally:
355371
if self.worker_process.is_alive():
356372
logger.error('Worker is still alive, sending SIGKILL!')
357-
self._sent_sigkill_to_worker_process = True
373+
self._timed_out = True
358374
self.worker_process.kill()
359375

360-
def request_abort_grading(self) -> None:
361-
assert self.worker_process_conn
362-
363-
try:
364-
self.worker_process_conn.send((IPC.REQUEST_ABORT, ()))
365-
except Exception:
366-
logger.exception('Failed to send abort request to worker, did it race?')
376+
def abort_grading__kill_worker(self) -> None:
377+
if self.worker_process and self.worker_process.is_alive():
378+
self._aborted = True
379+
self.worker_process.kill()
380+
self.worker_process.join(timeout=1)
367381

368382
def _worker_process_main(
369383
self,
@@ -384,15 +398,12 @@ def _ipc_recv_thread_main() -> None:
384398
while True:
385399
try:
386400
ipc_type, data = judge_process_conn.recv()
387-
except: # noqa: E722, whatever happened, we have to abort now.
401+
except: # noqa: E722, whatever happened, we have to exit now.
388402
logger.exception('Judge unexpectedly hung up!')
389-
self._do_abort()
390403
return
391404

392405
if ipc_type == IPC.BYE:
393406
return
394-
elif ipc_type == IPC.REQUEST_ABORT:
395-
self._do_abort()
396407
else:
397408
raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),))
398409

@@ -402,9 +413,12 @@ def _report_unhandled_exception() -> None:
402413
judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,)))
403414
judge_process_conn.send((IPC.BYE, ()))
404415

416+
tempdir = tempfile.mkdtemp('dmoj-judge-worker')
417+
tempfile.tempdir = tempdir
418+
405419
ipc_recv_thread = None
406420
try:
407-
judge_process_conn.send((IPC.HELLO, ()))
421+
judge_process_conn.send((IPC.HELLO, (tempdir,)))
408422

409423
ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True)
410424
ipc_recv_thread.start()
@@ -439,15 +453,6 @@ def _report_unhandled_exception() -> None:
439453
if ipc_recv_thread.is_alive():
440454
logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!')
441455

442-
# FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which
443-
# won't get called if we exit the process right now (so we'd leak all files created by the grader). This
444-
# should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting
445-
# working out.
446-
from dmoj.executors.compiled_executor import _CompiledExecutorMeta
447-
448-
for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values():
449-
cached_executor.is_cached = False
450-
cached_executor.cleanup()
451456
self.grader = None
452457

453458
def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
@@ -503,11 +508,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
503508
else:
504509
result = self.grader.grade(case)
505510

506-
# If the submission was killed due to a user-initiated abort, any result is meaningless.
507-
if self._abort_requested:
508-
yield IPC.GRADING_ABORTED, ()
509-
return
510-
511511
if result.result_flag & Result.WA:
512512
# If we failed a 0-point case, we will short-circuit every case after this.
513513
is_short_circuiting_enabled |= not case.points
@@ -532,11 +532,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
532532

533533
yield IPC.GRADING_END, ()
534534

535-
def _do_abort(self) -> None:
536-
self._abort_requested = True
537-
if self.grader:
538-
self.grader.abort_grading()
539-
540535

541536
class ClassicJudge(Judge):
542537
def __init__(self, host, port, **kwargs) -> None:

0 commit comments

Comments
 (0)