|
1 | 1 | import logging |
2 | 2 | import subprocess |
| 3 | +from typing import Any |
3 | 4 |
|
4 | 5 | from dmoj.checkers import CheckerOutput |
5 | 6 | from dmoj.cptbox import TracedPopen |
6 | 7 | from dmoj.cptbox.lazy_bytes import LazyBytes |
| 8 | +from dmoj.cptbox.utils import MemoryIO, MmapableIO |
7 | 9 | from dmoj.error import OutputLimitExceeded |
8 | 10 | from dmoj.executors import executors |
9 | 11 | from dmoj.executors.base_executor import BaseExecutor |
10 | 12 | from dmoj.graders.base import BaseGrader |
11 | | -from dmoj.problem import TestCase |
| 13 | +from dmoj.judge import JudgeWorker |
| 14 | +from dmoj.problem import Problem, TestCase |
12 | 15 | from dmoj.result import CheckerResult, Result |
13 | 16 |
|
14 | 17 | log = logging.getLogger('dmoj.graders') |
15 | 18 |
|
16 | 19 |
|
17 | 20 | class StandardGrader(BaseGrader): |
| 21 | + _stdout_io: MmapableIO |
| 22 | + _stderr_io: MmapableIO |
| 23 | + _orig_fsize: int |
| 24 | + memfd_output: bool = True |
| 25 | + |
| 26 | + def __init__(self, judge: 'JudgeWorker', problem: Problem, language: str, source: bytes) -> None: |
| 27 | + super().__init__(judge, problem, language, source) |
| 28 | + self._orig_fsize = self.binary.fsize |
| 29 | + |
18 | 30 | def grade(self, case: TestCase) -> Result: |
19 | 31 | result = Result(case) |
20 | 32 |
|
@@ -83,34 +95,59 @@ def check_result(self, case: TestCase, result: Result) -> CheckerOutput: |
83 | 95 | return check |
84 | 96 |
|
85 | 97 | def _launch_process(self, case: TestCase, input_file=None) -> None: |
| 98 | + stdout: Any |
| 99 | + stderr: Any |
| 100 | + |
| 101 | + if self.memfd_output: |
| 102 | + stdout = self._stdout_io = MemoryIO() |
| 103 | + stderr = self._stderr_io = MemoryIO() |
| 104 | + self.binary.fsize = max(self._orig_fsize, case.config.output_limit_length + 1024, 1048576) |
| 105 | + else: |
| 106 | + stdout = subprocess.PIPE |
| 107 | + stderr = subprocess.PIPE |
| 108 | + |
86 | 109 | self._current_proc = self.binary.launch( |
87 | 110 | time=self.problem.time_limit, |
88 | 111 | memory=self.problem.memory_limit, |
89 | 112 | symlinks=case.config.symlinks, |
90 | 113 | stdin=input_file or subprocess.PIPE, |
91 | | - stdout=subprocess.PIPE, |
92 | | - stderr=subprocess.PIPE, |
| 114 | + stdout=stdout, |
| 115 | + stderr=stderr, |
93 | 116 | wall_time=case.config.wall_time_factor * self.problem.time_limit, |
94 | 117 | ) |
95 | 118 |
|
96 | 119 | def _interact_with_process(self, case: TestCase, result: Result) -> bytes: |
97 | 120 | process = self._current_proc |
98 | 121 | assert process is not None |
99 | | - try: |
100 | | - result.proc_output, error = process.communicate( |
101 | | - None, outlimit=case.config.output_limit_length, errlimit=1048576 |
102 | | - ) |
103 | | - except OutputLimitExceeded: |
104 | | - error = b'' |
105 | | - process.kill() |
106 | | - finally: |
| 122 | + |
| 123 | + if self.memfd_output: |
107 | 124 | process.wait() |
| 125 | + |
| 126 | + result.proc_output = self._stdout_io.to_bytes() |
| 127 | + self._stdout_io.close() |
| 128 | + |
| 129 | + if len(result.proc_output) > case.config.output_limit_length: |
| 130 | + process.mark_ole() |
| 131 | + |
| 132 | + error = self._stderr_io.to_bytes() |
| 133 | + self._stderr_io.close() |
| 134 | + else: |
| 135 | + try: |
| 136 | + result.proc_output, error = process.communicate( |
| 137 | + None, outlimit=case.config.output_limit_length, errlimit=1048576 |
| 138 | + ) |
| 139 | + except OutputLimitExceeded: |
| 140 | + error = b'' |
| 141 | + process.kill() |
| 142 | + finally: |
| 143 | + process.wait() |
108 | 144 | return error |
109 | 145 |
|
110 | 146 | def _generate_binary(self) -> BaseExecutor: |
111 | | - return executors[self.language].Executor( |
| 147 | + executor = executors[self.language].Executor( |
112 | 148 | self.problem.id, |
113 | 149 | self.source, |
114 | 150 | hints=self.problem.config.hints or [], |
115 | 151 | unbuffered=self.problem.config.unbuffered, |
116 | 152 | ) |
| 153 | + return executor |
0 commit comments