Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions streaming/base/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@

# Default download timeout
DEFAULT_TIMEOUT = 60.0

# Exclusive upper bound on the shared-memory prefix integer: the prefix search
# raises ValueError once a candidate prefix_int is at or above this value.
MAX_PREFIX_INT = 1000
8 changes: 7 additions & 1 deletion streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import numpy as np
from torch import distributed as dist

from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
from streaming.base.constant import (BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, MAX_PREFIX_INT,
SHM_TO_CLEAN, TICK)
from streaming.base.shared import SharedMemory
from streaming.base.world import World

Expand Down Expand Up @@ -113,6 +114,11 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No

for prefix_int in _each_prefix_int():

if prefix_int >= MAX_PREFIX_INT:
raise ValueError(f'prefix_int exceeds {MAX_PREFIX_INT}. This may happen ' +
f'when you mock os.path.exists or os.stat so the filelock ' +
f'checks always returns ``True`` ' + f'you need to clean up TMPDIR.')

name = _get_path(prefix_int, shm_name)

# Check if any shared memory filelocks exist for the current prefix
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,13 @@ def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
with patch('os.path.exists', return_value=False):
next_prefix = _check_and_find(['local'], [None], LOCALS)
assert next_prefix == 1


@pytest.mark.usefixtures('local_remote_dir')
def test_shared_memory_infinity_exception(local_remote_dir: tuple[str, str]):
    """Check that ``get_shm_prefix`` raises when ``os.path.exists`` always says yes.

    ``os.path.exists`` is patched to return ``True`` for every path, and the
    call is expected to fail with the ``prefix_int exceeds ...`` ValueError.
    """
    local_dir, remote_dir = local_remote_dir
    # The patch is entered first, then the raises-context, mirroring a nested ``with``.
    with patch('os.path.exists', return_value=True), \
            pytest.raises(ValueError, match='prefix_int exceeds .*clean up TMPDIR.'):
        _prefix_int, _shm = get_shm_prefix(streams_local=[local_dir],
                                           streams_remote=[remote_dir],
                                           world=World.detect())