diff --git a/aim/storage/artifacts/artifact_storage.py b/aim/storage/artifacts/artifact_storage.py index efa73cbd19..3a50fb0ab3 100644 --- a/aim/storage/artifacts/artifact_storage.py +++ b/aim/storage/artifacts/artifact_storage.py @@ -1,7 +1,74 @@ +import os +import pathlib + from abc import abstractmethod from typing import Optional +def safe_join(prefix: str, artifact_path: str) -> pathlib.Path: + """Join ``prefix`` and ``artifact_path`` while rejecting any input that + would escape ``prefix`` via absolute paths, drive letters, or ``..`` + traversal segments. + + ``pathlib.Path('/a') / '/etc/passwd'`` yields ``/etc/passwd`` because + ``pathlib``'s ``/`` operator silently discards ``prefix`` when the right + operand is absolute. Likewise, ``pathlib.Path('/a') / '../etc/passwd'`` + yields ``/a/../etc/passwd`` which the OS resolves to ``/etc/passwd`` + when the path is later passed to ``shutil.copy``/``rmtree``/etc. + + We reject both cases up front so callers can rely on the joined path + staying inside ``prefix``. + """ + if artifact_path is None: + raise ValueError('artifact_path must not be None') + + candidate = pathlib.PurePosixPath(artifact_path) + if candidate.is_absolute() or candidate.drive or candidate.root: + raise ValueError(f'artifact_path must be relative: {artifact_path!r}') + + if any(part == '..' for part in candidate.parts): + raise ValueError(f'artifact_path must not contain ".." segments: {artifact_path!r}') + + base = pathlib.Path(prefix).resolve() + joined = (base / artifact_path).resolve() + try: + joined.relative_to(base) + except ValueError: + raise ValueError(f'artifact_path escapes the artifact root: {artifact_path!r}') + + # Preserve the un-resolved form so symlinks inside ``prefix`` continue to + # work the same way as before; the ``relative_to`` check above only + # validates that no traversal happened. + return pathlib.Path(prefix) / artifact_path + + +def safe_join_posix(prefix: str, artifact_path: str) -> str: + """Variant of :func:`safe_join` for storage backends (e.g. S3) where the + target is a POSIX-style key string rather than a local filesystem path. + + Performs the same anti-traversal validation but does not touch the local + filesystem (no ``resolve()`` against the real FS). + """ + if artifact_path is None: + raise ValueError('artifact_path must not be None') + + candidate = pathlib.PurePosixPath(artifact_path) + if candidate.is_absolute() or candidate.drive or candidate.root: + raise ValueError(f'artifact_path must be relative: {artifact_path!r}') + + if any(part == '..' for part in candidate.parts): + raise ValueError(f'artifact_path must not contain ".." segments: {artifact_path!r}') + + if prefix: + return f'{prefix.rstrip("/")}/{artifact_path}' + return artifact_path + + +# os is imported for downstream backends that may need it; keep the symbol +# exported for backwards compatibility. +__all__ = ['AbstractArtifactStorage', 'safe_join', 'safe_join_posix', 'os'] + + class AbstractArtifactStorage: def __init__(self, url: str): self.url = url diff --git a/aim/storage/artifacts/filesystem_storage.py b/aim/storage/artifacts/filesystem_storage.py index 80adfb953c..e727562789 100644 --- a/aim/storage/artifacts/filesystem_storage.py +++ b/aim/storage/artifacts/filesystem_storage.py @@ -6,7 +6,7 @@ from typing import Optional from urllib.parse import urlparse -from .artifact_storage import AbstractArtifactStorage +from .artifact_storage import AbstractArtifactStorage, safe_join class FilesystemArtifactStorage(AbstractArtifactStorage): @@ -17,7 +17,7 @@ def __init__(self, url: str): self._prefix = path def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): - dest_path = pathlib.Path(self._prefix) / artifact_path + dest_path = safe_join(self._prefix, artifact_path) dest_dir = os.path.dirname(dest_path) os.makedirs(dest_dir, exist_ok=True) shutil.copy(file_path, dest_path) @@ -28,12 +28,12 @@ def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) else: dest_dir = pathlib.Path(dest_dir) dest_dir.mkdir(parents=True, exist_ok=True) - source_path = dest_path = pathlib.Path(self._prefix) / artifact_path + source_path = safe_join(self._prefix, artifact_path) dest_path = dest_dir / source_path.name shutil.copy(source_path, dest_path) return dest_path.as_posix() def delete_artifact(self, artifact_path: str): - path = pathlib.Path(self._prefix) / artifact_path + path = safe_join(self._prefix, artifact_path) shutil.rmtree(path) diff --git a/aim/storage/artifacts/s3_storage.py b/aim/storage/artifacts/s3_storage.py index d30951bb18..4c1c9f94e7 100644 --- a/aim/storage/artifacts/s3_storage.py +++ b/aim/storage/artifacts/s3_storage.py @@ -8,7 +8,7 @@ from aim.ext.cleanup import AutoClean -from .artifact_storage import AbstractArtifactStorage +from .artifact_storage import AbstractArtifactStorage, safe_join_posix class S3ArtifactsStorageAutoClean(AutoClean['S3ArtifactStorage']): @@ -37,12 +37,12 @@ def __init__(self, url: str): self._resources = S3ArtifactsStorageAutoClean(self) def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): - dest_path = pathlib.Path(self._prefix) / artifact_path + dest_key = safe_join_posix(self._prefix, artifact_path) if block: - self._client.upload_file(Filename=file_path, Bucket=self._bucket, Key=dest_path.as_posix()) + self._client.upload_file(Filename=file_path, Bucket=self._bucket, Key=dest_key) else: future = self._thread_pool.submit( - self._client.upload_file, Filename=file_path, Bucket=self._bucket, Key=dest_path.as_posix() + self._client.upload_file, Filename=file_path, Bucket=self._bucket, Key=dest_key ) future.add_done_callback(self._upload_complete) self._futures.add(future) @@ -53,15 +53,15 @@ def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) else: dest_dir = pathlib.Path(dest_dir) dest_dir.mkdir(parents=True, exist_ok=True) - source_path = pathlib.Path(self._prefix) / artifact_path - dest_path = dest_dir / source_path.name - self._client.download_file(Bucket=self._bucket, Key=source_path.as_posix(), Filename=dest_path.as_posix()) + source_key = safe_join_posix(self._prefix, artifact_path) + dest_path = dest_dir / pathlib.PurePosixPath(source_key).name + self._client.download_file(Bucket=self._bucket, Key=source_key, Filename=dest_path.as_posix()) return dest_path.as_posix() def delete_artifact(self, artifact_path: str): - path = pathlib.Path(self._prefix) / artifact_path - self._client.delete_object(Bucket=self._bucket, Key=path.as_posix()) + key = safe_join_posix(self._prefix, artifact_path) + self._client.delete_object(Bucket=self._bucket, Key=key) def _upload_complete(self, future): self._futures.remove(future) diff --git a/tests/storage/test_artifact_storage_safe_join.py b/tests/storage/test_artifact_storage_safe_join.py new file mode 100644 index 0000000000..35658642e7 --- /dev/null +++ b/tests/storage/test_artifact_storage_safe_join.py @@ -0,0 +1,100 @@ +import os +import pathlib +import shutil +import tempfile + +import pytest + +from aim.storage.artifacts.artifact_storage import safe_join, safe_join_posix +from aim.storage.artifacts.filesystem_storage import FilesystemArtifactStorage + + +class TestSafeJoin: + def test_relative_path_is_allowed(self): + with tempfile.TemporaryDirectory() as tmp: + joined = safe_join(tmp, 'subdir/file.txt') + assert pathlib.Path(joined).is_relative_to(pathlib.Path(tmp).resolve()) + + def test_absolute_path_is_rejected(self): + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(ValueError): + safe_join(tmp, '/etc/passwd') + + def test_dotdot_traversal_is_rejected(self): + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(ValueError): + safe_join(tmp, '../../etc/passwd') + + def test_embedded_dotdot_is_rejected(self): + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(ValueError): + safe_join(tmp, 'foo/../../etc/passwd') + + def test_safe_join_posix_relative_ok(self): + assert safe_join_posix('artifacts', 'sub/file.txt') == 'artifacts/sub/file.txt' + + def test_safe_join_posix_rejects_absolute(self): + with pytest.raises(ValueError): + safe_join_posix('artifacts', '/etc/passwd') + + def test_safe_join_posix_rejects_dotdot(self): + with pytest.raises(ValueError): + safe_join_posix('artifacts', '../escape.txt') + + +class TestFilesystemArtifactStoragePathTraversal: + @pytest.fixture + def workspace(self): + root = tempfile.mkdtemp() + artifact_root = os.path.join(root, 'artifacts') + os.makedirs(artifact_root, exist_ok=True) + + source_file = os.path.join(root, 'payload.txt') + with open(source_file, 'w') as fh: + fh.write('attacker payload') + + yield {'root': root, 'artifact_root': artifact_root, 'source_file': source_file} + + shutil.rmtree(root, ignore_errors=True) + + def test_upload_rejects_absolute_artifact_path(self, workspace, tmp_path): + storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}') + + outside_target = tmp_path / 'outside_pwn.txt' + + with pytest.raises(ValueError): + storage.upload_artifact(workspace['source_file'], str(outside_target)) + + assert not outside_target.exists(), 'absolute artifact_path must not write outside the artifact root' + + def test_upload_rejects_dotdot_artifact_path(self, workspace): + storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}') + + with pytest.raises(ValueError): + storage.upload_artifact(workspace['source_file'], '../outside_pwn.txt') + + outside_target = pathlib.Path(workspace['root']) / 'outside_pwn.txt' + assert not outside_target.exists() + + def test_delete_rejects_absolute_artifact_path(self, workspace, tmp_path): + storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}') + + # Create a directory we want to ensure is NOT touched. + protected = tmp_path / 'protected' + protected.mkdir() + (protected / 'file').write_text('do not delete me') + + with pytest.raises(ValueError): + storage.delete_artifact(str(protected)) + + assert protected.exists() + assert (protected / 'file').exists() + + def test_upload_relative_path_still_works(self, workspace): + storage = FilesystemArtifactStorage(f'file://{workspace["artifact_root"]}') + + storage.upload_artifact(workspace['source_file'], 'sub/dir/legit.txt') + + landed = pathlib.Path(workspace['artifact_root']) / 'sub' / 'dir' / 'legit.txt' + assert landed.is_file() + assert landed.read_text() == 'attacker payload'