Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions aim/storage/artifacts/artifact_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,74 @@
import os
import pathlib

from abc import abstractmethod
from typing import Optional


def safe_join(prefix: str, artifact_path: str) -> pathlib.Path:
"""Join ``prefix`` and ``artifact_path`` while rejecting any input that
would escape ``prefix`` via absolute paths, drive letters, or ``..``
traversal segments.

``pathlib.Path('/a') / '/etc/passwd'`` yields ``/etc/passwd`` because
``pathlib``'s ``/`` operator silently discards ``prefix`` when the right
operand is absolute. Likewise, ``pathlib.Path('/a') / '../etc/passwd'``
yields ``/a/../etc/passwd`` which the OS resolves to ``/etc/passwd``
when the path is later passed to ``shutil.copy``/``rmtree``/etc.

We reject both cases up front so callers can rely on the joined path
staying inside ``prefix``.
"""
if artifact_path is None:
raise ValueError('artifact_path must not be None')

candidate = pathlib.PurePosixPath(artifact_path)
if candidate.is_absolute() or candidate.drive or candidate.root:
raise ValueError(f'artifact_path must be relative: {artifact_path!r}')

if any(part == '..' for part in candidate.parts):
raise ValueError(f'artifact_path must not contain ".." segments: {artifact_path!r}')

base = pathlib.Path(prefix).resolve()
joined = (base / artifact_path).resolve()
try:
joined.relative_to(base)
except ValueError:
raise ValueError(f'artifact_path escapes the artifact root: {artifact_path!r}')

# Preserve the un-resolved form so symlinks inside ``prefix`` continue to
# work the same way as before; the ``relative_to`` check above only
# validates that no traversal happened.
return pathlib.Path(prefix) / artifact_path


def safe_join_posix(prefix: str, artifact_path: str) -> str:
"""Variant of :func:`safe_join` for storage backends (e.g. S3) where the
target is a POSIX-style key string rather than a local filesystem path.

Performs the same anti-traversal validation but does not touch the local
filesystem (no ``resolve()`` against the real FS).
"""
if artifact_path is None:
raise ValueError('artifact_path must not be None')

candidate = pathlib.PurePosixPath(artifact_path)
if candidate.is_absolute() or candidate.drive or candidate.root:
raise ValueError(f'artifact_path must be relative: {artifact_path!r}')

if any(part == '..' for part in candidate.parts):
raise ValueError(f'artifact_path must not contain ".." segments: {artifact_path!r}')

if prefix:
return f'{prefix.rstrip("/")}/{artifact_path}'
return artifact_path


# os is imported for downstream backends that may need it; keep the symbol
# exported for backwards compatibility.
__all__ = ['AbstractArtifactStorage', 'safe_join', 'safe_join_posix', 'os']


class AbstractArtifactStorage:
def __init__(self, url: str):
self.url = url
Expand Down
8 changes: 4 additions & 4 deletions aim/storage/artifacts/filesystem_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional
from urllib.parse import urlparse

from .artifact_storage import AbstractArtifactStorage
from .artifact_storage import AbstractArtifactStorage, safe_join


class FilesystemArtifactStorage(AbstractArtifactStorage):
Expand All @@ -17,7 +17,7 @@ def __init__(self, url: str):
self._prefix = path

def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False):
dest_path = pathlib.Path(self._prefix) / artifact_path
dest_path = safe_join(self._prefix, artifact_path)
dest_dir = os.path.dirname(dest_path)
os.makedirs(dest_dir, exist_ok=True)
shutil.copy(file_path, dest_path)
Expand All @@ -28,12 +28,12 @@ def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None)
else:
dest_dir = pathlib.Path(dest_dir)
dest_dir.mkdir(parents=True, exist_ok=True)
source_path = dest_path = pathlib.Path(self._prefix) / artifact_path
source_path = safe_join(self._prefix, artifact_path)
dest_path = dest_dir / source_path.name
shutil.copy(source_path, dest_path)

return dest_path.as_posix()

def delete_artifact(self, artifact_path: str):
path = pathlib.Path(self._prefix) / artifact_path
path = safe_join(self._prefix, artifact_path)
shutil.rmtree(path)
18 changes: 9 additions & 9 deletions aim/storage/artifacts/s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aim.ext.cleanup import AutoClean

from .artifact_storage import AbstractArtifactStorage
from .artifact_storage import AbstractArtifactStorage, safe_join_posix


class S3ArtifactsStorageAutoClean(AutoClean['S3ArtifactStorage']):
Expand Down Expand Up @@ -37,12 +37,12 @@ def __init__(self, url: str):
self._resources = S3ArtifactsStorageAutoClean(self)

def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False):
dest_path = pathlib.Path(self._prefix) / artifact_path
dest_key = safe_join_posix(self._prefix, artifact_path)
if block:
self._client.upload_file(Filename=file_path, Bucket=self._bucket, Key=dest_path.as_posix())
self._client.upload_file(Filename=file_path, Bucket=self._bucket, Key=dest_key)
else:
future = self._thread_pool.submit(
self._client.upload_file, Filename=file_path, Bucket=self._bucket, Key=dest_path.as_posix()
self._client.upload_file, Filename=file_path, Bucket=self._bucket, Key=dest_key
)
future.add_done_callback(self._upload_complete)
self._futures.add(future)
Expand All @@ -53,15 +53,15 @@ def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None)
else:
dest_dir = pathlib.Path(dest_dir)
dest_dir.mkdir(parents=True, exist_ok=True)
source_path = pathlib.Path(self._prefix) / artifact_path
dest_path = dest_dir / source_path.name
self._client.download_file(Bucket=self._bucket, Key=source_path.as_posix(), Filename=dest_path.as_posix())
source_key = safe_join_posix(self._prefix, artifact_path)
dest_path = dest_dir / pathlib.PurePosixPath(source_key).name
self._client.download_file(Bucket=self._bucket, Key=source_key, Filename=dest_path.as_posix())

return dest_path.as_posix()

def delete_artifact(self, artifact_path: str):
path = pathlib.Path(self._prefix) / artifact_path
self._client.delete_object(Bucket=self._bucket, Key=path.as_posix())
key = safe_join_posix(self._prefix, artifact_path)
self._client.delete_object(Bucket=self._bucket, Key=key)

def _upload_complete(self, future):
self._futures.remove(future)
Expand Down
100 changes: 100 additions & 0 deletions tests/storage/test_artifact_storage_safe_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import pathlib
import shutil
import tempfile

import pytest

from aim.storage.artifacts.artifact_storage import safe_join, safe_join_posix
from aim.storage.artifacts.filesystem_storage import FilesystemArtifactStorage


class TestSafeJoin:
def test_relative_path_is_allowed(self):
with tempfile.TemporaryDirectory() as tmp:
joined = safe_join(tmp, 'subdir/file.txt')
assert pathlib.Path(joined).is_relative_to(pathlib.Path(tmp).resolve())

def test_absolute_path_is_rejected(self):
with tempfile.TemporaryDirectory() as tmp:
with pytest.raises(ValueError):
safe_join(tmp, '/etc/passwd')

def test_dotdot_traversal_is_rejected(self):
with tempfile.TemporaryDirectory() as tmp:
with pytest.raises(ValueError):
safe_join(tmp, '../../etc/passwd')

def test_embedded_dotdot_is_rejected(self):
with tempfile.TemporaryDirectory() as tmp:
with pytest.raises(ValueError):
safe_join(tmp, 'foo/../../etc/passwd')

def test_safe_join_posix_relative_ok(self):
assert safe_join_posix('artifacts', 'sub/file.txt') == 'artifacts/sub/file.txt'

def test_safe_join_posix_rejects_absolute(self):
with pytest.raises(ValueError):
safe_join_posix('artifacts', '/etc/passwd')

def test_safe_join_posix_rejects_dotdot(self):
with pytest.raises(ValueError):
safe_join_posix('artifacts', '../escape.txt')


class TestFilesystemArtifactStoragePathTraversal:
@pytest.fixture
def workspace(self):
root = tempfile.mkdtemp()
artifact_root = os.path.join(root, 'artifacts')
os.makedirs(artifact_root, exist_ok=True)

source_file = os.path.join(root, 'payload.txt')
with open(source_file, 'w') as fh:
fh.write('attacker payload')

yield {'root': root, 'artifact_root': artifact_root, 'source_file': source_file}

shutil.rmtree(root, ignore_errors=True)

def test_upload_rejects_absolute_artifact_path(self, workspace, tmp_path):
storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}')

outside_target = tmp_path / 'outside_pwn.txt'

with pytest.raises(ValueError):
storage.upload_artifact(workspace['source_file'], str(outside_target))

assert not outside_target.exists(), 'absolute artifact_path must not write outside the artifact root'

def test_upload_rejects_dotdot_artifact_path(self, workspace):
storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}')

with pytest.raises(ValueError):
storage.upload_artifact(workspace['source_file'], '../outside_pwn.txt')

outside_target = pathlib.Path(workspace['root']) / 'outside_pwn.txt'
assert not outside_target.exists()

def test_delete_rejects_absolute_artifact_path(self, workspace, tmp_path):
storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}')

# Create a directory we want to ensure is NOT touched.
protected = tmp_path / 'protected'
protected.mkdir()
(protected / 'file').write_text('do not delete me')

with pytest.raises(ValueError):
storage.delete_artifact(str(protected))

assert protected.exists()
assert (protected / 'file').exists()

def test_upload_relative_path_still_works(self, workspace):
storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}')

storage.upload_artifact(workspace['source_file'], 'sub/dir/legit.txt')

landed = pathlib.Path(workspace['artifact_root']) / 'sub' / 'dir' / 'legit.txt'
assert landed.is_file()
assert landed.read_text() == 'attacker payload'