Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions dmoj/checkers/bridged.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess

from dmoj.contrib import contrib_modules
from dmoj.cptbox.filesystem_policies import ExactFile
from dmoj.error import InternalError
from dmoj.judgeenv import env, get_problem_root
from dmoj.result import CheckerResult
Expand All @@ -25,10 +26,10 @@ def get_executor(problem_id, files, flags, lang, compiler_time_limit):
def check(
process_output,
judge_output,
judge_input,
problem_id,
files,
lang,
case,
time_limit=env['generator_time_limit'],
memory_limit=env['generator_memory_limit'],
compiler_time_limit=env['generator_compiler_limit'],
Expand All @@ -46,16 +47,23 @@ def check(

args_format_string = args_format_string or contrib_modules[type].ContribModule.get_checker_args_format_string()

with mktemp(judge_input) as input_file, mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
with mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
input_path = case.input_data_io().to_path()

checker_args = shlex.split(
args_format_string.format(
input_file=shlex.quote(input_file.name),
input_file=shlex.quote(input_path),
output_file=shlex.quote(output_file.name),
answer_file=shlex.quote(answer_file.name),
)
)
process = executor.launch(
*checker_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, memory=memory_limit, time=time_limit
*checker_args,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
memory=memory_limit,
time=time_limit,
extra_fs=[ExactFile(input_path)],
)

proc_output, error = process.communicate()
Expand Down
7 changes: 5 additions & 2 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,8 @@ AT_FDCWD: int
bsd_get_proc_cwd: Callable[[int], str]
bsd_get_proc_fdno: Callable[[int, int], str]

memory_fd_create: Callable[[], int]
memory_fd_seal: Callable[[int], None]
memfd_create: Callable[[], int]
memfd_seal: Callable[[int], None]

class BufferProxy:
def _get_real_buffer(self): ...
21 changes: 15 additions & 6 deletions dmoj/cptbox/_cptbox.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# cython: language_level=3
from cpython.exc cimport PyErr_NoMemory, PyErr_SetFromErrno
from cpython.buffer cimport PyObject_GetBuffer
from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize
from libc.stdio cimport FILE, fopen, fclose, fgets, sprintf
from libc.stdlib cimport malloc, free, strtoul
Expand Down Expand Up @@ -133,8 +134,8 @@ cdef extern from 'helper.h' nogil:
PTBOX_SPAWN_FAIL_EXECVE
PTBOX_SPAWN_FAIL_SETAFFINITY

int _memory_fd_create "memory_fd_create"()
int _memory_fd_seal "memory_fd_seal"(int fd)
int cptbox_memfd_create()
int cptbox_memfd_seal(int fd)


cdef extern from 'fcntl.h' nogil:
Expand Down Expand Up @@ -214,14 +215,14 @@ def bsd_get_proc_fdno(pid_t pid, int fd):
free(buf)
return res

def memory_fd_create():
cdef int fd = _memory_fd_create()
def memfd_create():
cdef int fd = cptbox_memfd_create()
if fd < 0:
PyErr_SetFromErrno(OSError)
return fd

def memory_fd_seal(int fd):
cdef int result = _memory_fd_seal(fd)
def memfd_seal(int fd):
cdef int result = cptbox_memfd_seal(fd)
if result == -1:
PyErr_SetFromErrno(OSError)

Expand Down Expand Up @@ -600,3 +601,11 @@ cdef class Process:
if not self._exited:
return None
return self._exitcode


cdef class BufferProxy:
    # Extension type that exposes the buffer protocol on behalf of another
    # object: subclasses supply that object through _get_real_buffer().

    def _get_real_buffer(self):
        # Hook for subclasses; the base class has nothing to expose.
        raise NotImplementedError

    def __getbuffer__(self, Py_buffer *buffer, int flags):
        # Fill the caller's Py_buffer from the object the subclass returns.
        backing = self._get_real_buffer()
        PyObject_GetBuffer(backing, buffer, flags)
11 changes: 4 additions & 7 deletions dmoj/cptbox/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,16 @@ char *bsd_get_proc_fdno(pid_t pid, int fdno) {
return bsd_get_proc_fd(pid, 0, fdno);
}

int memory_fd_create(void) {
int cptbox_memfd_create(void) {
#ifdef __FreeBSD__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, is this function called on FreeBSD anymore? Are you creating the tempfile in Python instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this around for now, since I'd rather this function work on all platforms, as the detection logic for the FreeBSD case is now different. If FreeBSD implements /proc/[pid]/fd some day, this will magically work.

char filename[] = "/tmp/cptbox-memoryfd-XXXXXXXX";
int fd = mkstemp(filename);
if (fd > 0)
unlink(filename);
return fd;
errno = ENOSYS;
return -1;
#else
return memfd_create("cptbox memory_fd", MFD_ALLOW_SEALING);
#endif
}

int memory_fd_seal(int fd) {
int cptbox_memfd_seal(int fd) {
#ifdef __FreeBSD__
errno = ENOSYS;
return -1;
Expand Down
4 changes: 2 additions & 2 deletions dmoj/cptbox/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ int cptbox_child_run(const struct child_config *config);
char *bsd_get_proc_cwd(pid_t pid);
char *bsd_get_proc_fdno(pid_t pid, int fdno);

int memory_fd_create(void);
int memory_fd_seal(int fd);
int cptbox_memfd_create(void);
int cptbox_memfd_seal(int fd);

#endif
4 changes: 2 additions & 2 deletions dmoj/cptbox/isolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _access_check(self, debugger: Debugger, file: str, fs_jail: FilesystemPolicy
real = os.path.realpath(file)

try:
same = normalized == real or os.path.samefile(projected, real)
same = normalized == real or real.startswith('/memfd:') or os.path.samefile(projected, real)
except OSError:
raise DeniedSyscall(ACCESS_ENOENT, f'Cannot stat, file: {file}, projected: {projected}, real: {real}')

Expand All @@ -385,7 +385,7 @@ def _access_check(self, debugger: Debugger, file: str, fs_jail: FilesystemPolicy
else:
real = os.path.join('/proc/self', relpath)

if not fs_jail.check(real):
if not real.startswith('/memfd:') and not fs_jail.check(real):
raise DeniedSyscall(ACCESS_EACCES, f'Denying {file}, real path {real}')

def handle_kill(self, debugger: Debugger) -> None:
Expand Down
88 changes: 88 additions & 0 deletions dmoj/cptbox/lazy_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Based off https://github.com/django/django/blob/main/django/utils/functional.py, licensed under 3-clause BSD.
from functools import total_ordering

from dmoj.cptbox._cptbox import BufferProxy

_SENTINEL = object()


@total_ordering
class LazyBytes(BufferProxy):
    """
    Stand-in for a ``bytes`` value that is produced on demand.

    The wrapped callable is not invoked until some operation needs the
    underlying data. Its result is then cached, and this and all later
    operations are applied to that cached result.
    """

    def __init__(self, func):
        self.__func = func
        self.__value = _SENTINEL

    def __get_value(self):
        # Serve the cached result when there is one; otherwise evaluate the
        # callable now and remember what it returned.
        if self.__value is not _SENTINEL:
            return self.__value
        self.__value = self.__func()
        return self.__value

    @classmethod
    def _create_promise(cls, method_name):
        # Produce a method that forces evaluation and then forwards the call
        # to the attribute called ``method_name`` on the evaluated value.
        def wrapper(self, *args, **kw):
            target = self.__get_value()
            forwarded = getattr(target, method_name)
            return forwarded(*args, **kw)

        return wrapper

    def __cast(self):
        evaluated = self.__get_value()
        return bytes(evaluated)

    def _get_real_buffer(self):
        # Object whose buffer BufferProxy exposes for this instance.
        return self.__cast()

    def __bytes__(self):
        return self.__cast()

    def __repr__(self):
        as_bytes = self.__cast()
        return repr(as_bytes)

    def __str__(self):
        as_bytes = self.__cast()
        return str(as_bytes)

    def __eq__(self, other):
        # Unwrap another LazyBytes so the comparison is between plain values.
        rhs = other.__cast() if isinstance(other, LazyBytes) else other
        return self.__cast() == rhs

    def __lt__(self, other):
        rhs = other.__cast() if isinstance(other, LazyBytes) else other
        return self.__cast() < rhs

    def __hash__(self):
        as_bytes = self.__cast()
        return hash(as_bytes)

    def __mod__(self, rhs):
        return self.__cast() % rhs

    def __add__(self, other):
        return self.__cast() + other

    def __radd__(self, other):
        return other + self.__cast()

    def __deepcopy__(self, memo):
        # A deep copy hands back this very object instead of duplicating
        # it; record that in ``memo`` so nested references agree.
        memo[id(self)] = self
        return self


for type_ in bytes.mro():
    for method_name in type_.__dict__:
        # Give LazyBytes a forwarding method for every name found on bytes
        # and its bases that LazyBytes does not already have (defined or
        # inherited). The wrapper looks up the real implementation on the
        # evaluated value each time it is called.
        if not hasattr(LazyBytes, method_name):
            setattr(LazyBytes, method_name, LazyBytes._create_promise(method_name))
115 changes: 110 additions & 5 deletions dmoj/cptbox/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,116 @@
import io
import mmap
import os
from abc import ABCMeta, abstractmethod
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Optional

from dmoj.cptbox._cptbox import memory_fd_create, memory_fd_seal
from dmoj.cptbox._cptbox import memfd_create, memfd_seal


class MemoryIO(io.FileIO):
def __init__(self) -> None:
super().__init__(memory_fd_create(), 'r+')
def _make_fd_readonly(fd):
    """
    Replace descriptor ``fd`` with a read-only descriptor for the same file.

    The file is reopened with ``O_RDONLY`` through ``/proc/self/fd/<fd>`` and
    the new descriptor is duplicated onto ``fd``; the temporary descriptor is
    closed whether or not the duplication succeeds.
    """
    proc_path = f'/proc/self/fd/{fd}'
    readonly_fd = os.open(proc_path, os.O_RDONLY)
    try:
        os.dup2(readonly_fd, fd)
    finally:
        os.close(readonly_fd)


class MmapableIO(io.FileIO, metaclass=ABCMeta):
def __init__(self, fd, *, prefill: Optional[bytes] = None, seal=False) -> None:
    """
    Wrap ``fd`` as a read/write raw file, optionally filling and sealing it.

    :param fd: open file descriptor to take over (opened in ``'r+'`` mode).
    :param prefill: initial contents to write, if any.
    :param seal: when true, call :meth:`seal` once the contents are written.
    :raises OSError: if the file stops accepting data before all of
        ``prefill`` has been written.
    """
    super().__init__(fd, 'r+')

    if prefill:
        # FileIO is a raw stream: one write() may accept only part of a large
        # buffer, so keep writing until every byte has been stored rather
        # than silently truncating the prefill data.
        remaining = memoryview(prefill).cast('B')
        while remaining.nbytes:
            written = self.write(remaining)
            if not written:
                raise OSError('short write while prefilling file')
            remaining = remaining[written:]
    if seal:
        self.seal()

@classmethod
@abstractmethod
def usable_with_name(cls) -> bool:
...

@abstractmethod
def seal(self) -> None:
...

@abstractmethod
def to_path(self) -> str:
...

def to_bytes(self) -> bytes:
try:
with mmap.mmap(self.fileno(), 0, access=mmap.ACCESS_READ) as f:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often do we expect this will be called? Should we madvise(..., MADV_SEQUENTIAL) here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not very often, it's mostly for compatibility with old checkers etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like "very often" to me, but happy to punt on this. I worry we'll hit issues with gigabyte-sized generator inputs that also have checkers, since this doubles the memory requirement.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's less of a problem than it looks. In the standard grader, we pass this magic to checkers: judge_input=LazyBytes(case.input_data). We only pay for this if the checker actually reads judge_input.

return bytes(f)
except ValueError as e:
if e.args[0] == 'cannot mmap an empty file':
return b''
raise


class NamedFileIO(MmapableIO):
    """
    MmapableIO backed by a named temporary file.

    The file keeps its name on disk for as long as the object is open, so
    :meth:`to_path` returns that name directly. The file is removed when the
    object is closed.
    """

    # Path of the backing temporary file.
    _name: str

    def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
        # delete=False keeps the file on disk after the NamedTemporaryFile
        # wrapper is closed; this object owns a duplicate of its descriptor.
        with NamedTemporaryFile(delete=False) as f:
            self._name = f.name
            super().__init__(os.dup(f.fileno()), prefill=prefill, seal=seal)

    def seal(self) -> None:
        """Rewind to the start of the file."""
        self.seek(0, os.SEEK_SET)

    def close(self) -> None:
        """
        Close the descriptor and remove the backing file.

        Calling this again after the file is closed does nothing, matching
        the contract of ``io.IOBase.close``; previously a second call tried
        to unlink the already-removed file and raised ``FileNotFoundError``.
        """
        if self.closed:
            return
        try:
            super().close()
        finally:
            # Remove the file even if closing the descriptor failed, so the
            # temporary file is never left behind.
            os.unlink(self._name)

    def to_path(self) -> str:
        """Return the path of the backing temporary file."""
        return self._name

    @classmethod
    def usable_with_name(cls):
        """Always usable: the file has a real name on disk."""
        return True


class UnnamedFileIO(MmapableIO):
    """
    MmapableIO backed by an anonymous ``TemporaryFile``.

    The file has no usable name of its own, so it is addressed through this
    process's ``/proc/<pid>/fd`` entry instead.
    """

    def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
        # Keep a private duplicate of the descriptor; the TemporaryFile
        # wrapper itself is closed as soon as the ``with`` block ends.
        with TemporaryFile() as scratch:
            own_fd = os.dup(scratch.fileno())
            super().__init__(own_fd, prefill=prefill, seal=seal)

    def seal(self) -> None:
        """Rewind to the start, then swap in a read-only descriptor."""
        self.seek(0, os.SEEK_SET)
        _make_fd_readonly(self.fileno())

    def to_path(self) -> str:
        """Return the ``/proc/<pid>/fd/<fd>`` path for this file."""
        return f'/proc/{os.getpid()}/fd/{self.fileno()}'

    @classmethod
    def usable_with_name(cls):
        """Report whether a fresh instance's ``to_path()`` exists on this system."""
        with cls() as probe:
            return os.path.exists(probe.to_path())


class MemfdIO(MmapableIO):
    """
    MmapableIO backed by a descriptor obtained from ``memfd_create()``.

    The file is addressed through this process's ``/proc/<pid>/fd`` entry.
    """

    def __init__(self, *, prefill: Optional[bytes] = None, seal=False) -> None:
        super().__init__(memfd_create(), prefill=prefill, seal=seal)

    def seal(self) -> None:
        """Seal the memfd, then swap in a read-only descriptor."""
        # Seal exactly once through the current ``memfd_seal`` helper; the
        # earlier extra call through the old ``memory_fd_seal`` name was a
        # leftover from before the rename.
        fd = self.fileno()
        memfd_seal(fd)
        _make_fd_readonly(fd)

    def to_path(self) -> str:
        """Return the ``/proc/<pid>/fd/<fd>`` path for this file."""
        return f'/proc/{os.getpid()}/fd/{self.fileno()}'

    @classmethod
    def usable_with_name(cls):
        """
        Report whether a memfd can be created and reached by path.

        Returns ``False`` when creating or probing the instance raises
        ``OSError``.
        """
        try:
            with cls() as f:
                return os.path.exists(f.to_path())
        except OSError:
            return False


# Try to use memfd if possible, otherwise fallback to unlinked temporary files
# (UnnamedFileIO). On FreeBSD and some other systems, /proc/[pid]/fd doesn't
# exist, so to_path() will not work. We fall back to NamedFileIO in that case.
MemoryIO = next((i for i in (MemfdIO, UnnamedFileIO, NamedFileIO) if i.usable_with_name()))
9 changes: 6 additions & 3 deletions dmoj/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,11 @@ def _add_syscalls(self, sec: IsolateTracer, handlers: List[Union[str, Tuple[str,
sec[getattr(syscalls, f'sys_{name}')] = handler
return sec

def get_security(self, launch_kwargs=None) -> IsolateTracer:
sec = IsolateTracer(read_fs=self.get_fs(), write_fs=self.get_write_fs())
def get_security(self, launch_kwargs=None, extra_fs=None) -> IsolateTracer:
read_fs = self.get_fs()
if extra_fs:
read_fs += extra_fs
sec = IsolateTracer(read_fs=read_fs, write_fs=self.get_write_fs())
return self._add_syscalls(sec, self.get_allowed_syscalls())

def get_fs(self) -> List[FilesystemAccessRule]:
Expand Down Expand Up @@ -299,7 +302,7 @@ def launch(self, *args, **kwargs) -> TracedPopen:
return TracedPopen(
[utf8bytes(a) for a in self.get_cmdline(**kwargs) + list(args)],
executable=utf8bytes(executable),
security=self.get_security(launch_kwargs=kwargs),
security=self.get_security(launch_kwargs=kwargs, extra_fs=kwargs.get('extra_fs')),
address_grace=self.get_address_grace(),
data_grace=self.data_grace,
personality=self.personality,
Expand Down
Loading
Loading