Skip to content

Commit 03926a7

Browse files
committed
judge: implement instant aborts
The way this works is: - Worker creates a tempdir, and sets `tempfile.tempdir` to this directory. - Worker sends back the tempdir. The parent process is responsible for cleaning it up when the worker exits. Abortions are then implemented as sending `SIGKILL` to the worker. As a side benefit of this implementation, we also get to drop the hacky `CompiledExecutor` cache deletion.
1 parent 5eb5f59 commit 03926a7

File tree

2 files changed

+47
-61
lines changed

2 files changed

+47
-61
lines changed

dmoj/graders/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ def grade(self, case):
2121
def _generate_binary(self):
2222
raise NotImplementedError
2323

24-
def abort_grading(self):
25-
self._abort_requested = True
26-
if self._current_proc:
27-
try:
28-
self._current_proc.kill()
29-
except OSError:
30-
pass
31-
3224
def _resolve_testcases(self, cfg, batch_no=0):
3325
cases = []
3426
for case_config in cfg:

dmoj/judge.py

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import logging
33
import multiprocessing
44
import os
5+
import shutil
56
import signal
67
import sys
8+
import tempfile
79
import threading
810
import traceback
911
from enum import Enum
@@ -41,14 +43,13 @@ class IPC(Enum):
4143
BATCH_END = 'BATCH-END'
4244
GRADING_BEGIN = 'GRADING-BEGIN'
4345
GRADING_END = 'GRADING-END'
44-
GRADING_ABORTED = 'GRADING-ABORTED'
4546
UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION'
46-
REQUEST_ABORT = 'REQUEST-ABORT'
4747

4848

49+
class JudgeWorkerAborted(Exception):
50+
pass
51+
4952
# This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here.
50-
# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of
51-
# a `CompileError`.)
5253
IPC_TIMEOUT = 60 # seconds
5354

5455

@@ -128,8 +129,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal
128129
)
129130
)
130131

131-
# FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send
132-
# an abort request down the pipe, possibly messing up the handshake.
133132
self.current_judge_worker = JudgeWorker(submission)
134133

135134
ipc_ready_signal = threading.Event()
@@ -147,13 +146,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
147146
assert self.current_judge_worker is not None
148147

149148
try:
149+
worker_tempdir = None
150+
151+
def _ipc_hello(_report, tempdir: str):
152+
nonlocal worker_tempdir
153+
ipc_ready_signal.set()
154+
worker_tempdir = tempdir
155+
150156
ipc_handler_dispatch: Dict[IPC, Callable] = {
151-
IPC.HELLO: lambda _report: ipc_ready_signal.set(),
157+
IPC.HELLO: _ipc_hello,
152158
IPC.COMPILE_ERROR: self._ipc_compile_error,
153159
IPC.COMPILE_MESSAGE: self._ipc_compile_message,
154160
IPC.GRADING_BEGIN: self._ipc_grading_begin,
155161
IPC.GRADING_END: self._ipc_grading_end,
156-
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
157162
IPC.BATCH_BEGIN: self._ipc_batch_begin,
158163
IPC.BATCH_END: self._ipc_batch_end,
159164
IPC.RESULT: self._ipc_result,
@@ -176,12 +181,22 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
176181
% (self.current_submission.problem_id, self.current_submission.id)
177182
)
178183
)
184+
except JudgeWorkerAborted:
185+
self.packet_manager.submission_aborted_packet()
179186
except Exception: # noqa: E722, we want to catch everything
180187
self.log_internal_error()
181188
finally:
182189
self.current_judge_worker.wait_with_timeout()
183190
self.current_judge_worker = None
184191

192+
print('cleaning up', worker_tempdir)
193+
os.system('ls -al %s' % worker_tempdir)
194+
if worker_tempdir:
195+
try:
196+
shutil.rmtree(worker_tempdir)
197+
except: # noqa: E722
198+
pass
199+
185200
# Might not have been set if an exception was encountered before HELLO message, so signal here to keep the
186201
# other side from waiting forever.
187202
ipc_ready_signal.set()
@@ -232,10 +247,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None:
232247
def _ipc_batch_end(self, _report, _batch_number: int) -> None:
233248
self.packet_manager.batch_end_packet()
234249

235-
def _ipc_grading_aborted(self, report) -> None:
236-
self.packet_manager.submission_aborted_packet()
237-
report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)'))
238-
239250
def _ipc_unhandled_exception(self, _report, message: str) -> None:
240251
logger.error('Unhandled exception in worker process')
241252
self.log_internal_error(message=message)
@@ -254,10 +265,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None:
254265
'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id
255266
)
256267
else:
257-
logger.info('Received abortion request for %d', worker.submission.id)
258-
# These calls are idempotent, so it doesn't matter if we raced and the worker has exited already.
259-
worker.request_abort_grading()
260-
worker.wait_with_timeout()
268+
logger.info('Received abortion request for %d, killing worker', worker.submission.id)
269+
# This call is idempotent, so it doesn't matter if we raced and the worker has exited already.
270+
worker.abort_grading__kill_worker()
261271

262272
def listen(self) -> None:
263273
"""
@@ -270,7 +280,8 @@ def murder(self) -> None:
270280
"""
271281
End any submission currently executing, and exit the judge.
272282
"""
273-
self.abort_grading()
283+
if self.current_judge_worker:
284+
self.current_judge_worker.abort_grading__kill_worker()
274285
self.updater_exit = True
275286
self.updater_signal.set()
276287
if self.packet_manager:
@@ -304,8 +315,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio
304315
class JudgeWorker:
305316
def __init__(self, submission: Submission) -> None:
306317
self.submission = submission
307-
self._abort_requested = False
308-
self._sent_sigkill_to_worker_process = False
318+
self._aborted = False
319+
self._timed_out = False
309320
# FIXME(tbrindus): marked Any pending grader cleanups.
310321
self.grader: Any = None
311322

@@ -331,8 +342,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]:
331342
self.worker_process.kill()
332343
raise
333344
except EOFError:
334-
if self._sent_sigkill_to_worker_process:
335-
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT)
345+
if self._aborted:
346+
raise JudgeWorkerAborted() from None
347+
348+
if self._timed_out:
349+
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None
350+
336351
raise
337352
except Exception:
338353
logger.error('Failed to read IPC message from worker!')
@@ -354,16 +369,14 @@ def wait_with_timeout(self) -> None:
354369
finally:
355370
if self.worker_process.is_alive():
356371
logger.error('Worker is still alive, sending SIGKILL!')
357-
self._sent_sigkill_to_worker_process = True
372+
self._timed_out = True
358373
self.worker_process.kill()
359374

360-
def request_abort_grading(self) -> None:
361-
assert self.worker_process_conn
362-
363-
try:
364-
self.worker_process_conn.send((IPC.REQUEST_ABORT, ()))
365-
except Exception:
366-
logger.exception('Failed to send abort request to worker, did it race?')
375+
def abort_grading__kill_worker(self) -> None:
376+
if self.worker_process and self.worker_process.is_alive():
377+
self._aborted = True
378+
self.worker_process.kill()
379+
self.worker_process.join(timeout=1)
367380

368381
def _worker_process_main(
369382
self,
@@ -384,15 +397,12 @@ def _ipc_recv_thread_main() -> None:
384397
while True:
385398
try:
386399
ipc_type, data = judge_process_conn.recv()
387-
except: # noqa: E722, whatever happened, we have to abort now.
400+
except: # noqa: E722, whatever happened, we have to exit now.
388401
logger.exception('Judge unexpectedly hung up!')
389-
self._do_abort()
390402
return
391403

392404
if ipc_type == IPC.BYE:
393405
return
394-
elif ipc_type == IPC.REQUEST_ABORT:
395-
self._do_abort()
396406
else:
397407
raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),))
398408

@@ -402,9 +412,12 @@ def _report_unhandled_exception() -> None:
402412
judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,)))
403413
judge_process_conn.send((IPC.BYE, ()))
404414

415+
tempdir = tempfile.mkdtemp('dmoj-judge-worker')
416+
tempfile.tempdir = tempdir
417+
405418
ipc_recv_thread = None
406419
try:
407-
judge_process_conn.send((IPC.HELLO, ()))
420+
judge_process_conn.send((IPC.HELLO, (tempdir,)))
408421

409422
ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True)
410423
ipc_recv_thread.start()
@@ -439,15 +452,6 @@ def _report_unhandled_exception() -> None:
439452
if ipc_recv_thread.is_alive():
440453
logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!')
441454

442-
# FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which
443-
# won't get called if we exit the process right now (so we'd leak all files created by the grader). This
444-
# should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting
445-
# working out.
446-
from dmoj.executors.compiled_executor import _CompiledExecutorMeta
447-
448-
for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values():
449-
cached_executor.is_cached = False
450-
cached_executor.cleanup()
451455
self.grader = None
452456

453457
def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
@@ -503,11 +507,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
503507
else:
504508
result = self.grader.grade(case)
505509

506-
# If the submission was killed due to a user-initiated abort, any result is meaningless.
507-
if self._abort_requested:
508-
yield IPC.GRADING_ABORTED, ()
509-
return
510-
511510
if result.result_flag & Result.WA:
512511
# If we failed a 0-point case, we will short-circuit every case after this.
513512
is_short_circuiting_enabled |= not case.points
@@ -532,11 +531,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
532531

533532
yield IPC.GRADING_END, ()
534533

535-
def _do_abort(self) -> None:
536-
self._abort_requested = True
537-
if self.grader:
538-
self.grader.abort_grading()
539-
540534

541535
class ClassicJudge(Judge):
542536
def __init__(self, host, port, **kwargs) -> None:

0 commit comments

Comments
 (0)