-
Notifications
You must be signed in to change notification settings - Fork 44
s3 public buckets #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
s3 public buckets #144
Changes from all commits
ca7c974
25bde54
64d8649
4b237a6
b0355e9
8caa8fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import os | ||
| import unittest | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| from runai_model_streamer_s3.credentials.credentials import ( | ||
| get_credentials, | ||
| AWS_CA_BUNDLE_ENV, | ||
| RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR, | ||
| ) | ||
|
|
||
|
|
||
| def _env_without_unsigned(): | ||
| env = os.environ.copy() | ||
| env.pop(RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR, None) | ||
| env.pop(AWS_CA_BUNDLE_ENV, None) | ||
| env.pop("RUNAI_STREAMER_NO_BOTO3_SESSION", None) | ||
| return env | ||
|
|
||
|
|
||
| class TestGetCredentialsUnsigned(unittest.TestCase): | ||
| @patch("runai_model_streamer_s3.credentials.credentials.boto3") | ||
| def test_unsigned_returns_no_session(self, mock_boto3): | ||
| mock_boto3.Session.return_value._session.get_config_variable.return_value = None | ||
| with patch.dict(os.environ, {RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR: "1"}, clear=False): | ||
| session, _ = get_credentials(None) | ||
| self.assertIsNone(session) | ||
|
|
||
| @patch("runai_model_streamer_s3.credentials.credentials.boto3") | ||
| def test_unsigned_sets_ca_bundle(self, mock_boto3): | ||
| mock_boto3.Session.return_value._session.get_config_variable.return_value = "/etc/ssl/custom.pem" | ||
| env = _env_without_unsigned() | ||
| env[RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR] = "1" | ||
| with patch.dict(os.environ, env, clear=True): | ||
| get_credentials(None) | ||
| self.assertEqual(os.environ.get(AWS_CA_BUNDLE_ENV), "/etc/ssl/custom.pem") | ||
|
|
||
| @patch("runai_model_streamer_s3.credentials.credentials.boto3") | ||
| def test_unsigned_disabled_resolves_credentials(self, mock_boto3): | ||
| mock_session = MagicMock() | ||
| mock_session.get_credentials.return_value = None | ||
| mock_session._session.get_config_variable.return_value = None | ||
| mock_boto3.Session.return_value = mock_session | ||
| with patch.dict(os.environ, {RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR: "0"}, clear=False): | ||
| session, _ = get_credentials(None) | ||
| mock_session.get_credentials.assert_called_once() | ||
| self.assertIsNotNone(session) | ||
|
|
||
| @patch("runai_model_streamer_s3.credentials.credentials.boto3") | ||
| def test_unsigned_absent_resolves_credentials(self, mock_boto3): | ||
| mock_session = MagicMock() | ||
| mock_session.get_credentials.return_value = None | ||
| mock_session._session.get_config_variable.return_value = None | ||
| mock_boto3.Session.return_value = mock_session | ||
| with patch.dict(os.environ, _env_without_unsigned(), clear=True): | ||
| session, _ = get_credentials(None) | ||
| mock_session.get_credentials.assert_called_once() | ||
| self.assertIsNotNone(session) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,26 +1,31 @@ | ||
| from typing import Optional, List, Tuple | ||
| from runai_model_streamer_s3.credentials.credentials import get_credentials, S3Credentials | ||
| from runai_model_streamer_s3.credentials.credentials import get_credentials, S3Credentials, RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR | ||
| import fnmatch | ||
| import os | ||
| import boto3 | ||
| from botocore import UNSIGNED | ||
| from botocore.config import Config | ||
| from pathlib import Path | ||
| import posixpath | ||
|
|
||
| def glob(path: str, allow_pattern: Optional[List[str]] = None, credentials: Optional[S3Credentials] = None) -> List[str]: | ||
| session, _ = get_credentials(credentials) | ||
| use_virtual_addressing = os.getenv("RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING", "1") | ||
|
|
||
| client_config = None | ||
| if use_virtual_addressing == "0": | ||
| client_config = Config(s3={'addressing_style': 'path'}) | ||
|
|
||
| # Pass the config to the client constructor | ||
| def _build_client_config() -> Optional[Config]: | ||
| config_kwargs = {} | ||
| if os.getenv("RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING", "1") == "0": | ||
| config_kwargs["s3"] = {"addressing_style": "path"} | ||
| if os.getenv(RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR, "0") == "1": | ||
| config_kwargs["signature_version"] = UNSIGNED | ||
| return Config(**config_kwargs) if config_kwargs else None | ||
|
|
||
| def _build_s3_client(credentials: Optional[S3Credentials]): | ||
| session, _ = get_credentials(credentials) | ||
| client_config = _build_client_config() | ||
| if session is None: | ||
| s3 = boto3.client("s3", config=client_config) | ||
| else: | ||
| s3 = session.client("s3", config=client_config) | ||
|
|
||
| return boto3.client("s3", config=client_config) | ||
| return session.client("s3", config=client_config) | ||
|
coderabbitai[bot] marked this conversation as resolved.
Comment on lines
+20
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The fix is to short-circuit before credential resolution when unsigned mode is on:

```python
def _build_s3_client(credentials: Optional[S3Credentials]):
    client_config = _build_client_config()
    if os.getenv(RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR, "0") == "1":
        return boto3.client("s3", config=client_config)
    session, _ = get_credentials(credentials)
    if session is None:
        return boto3.client("s3", config=client_config)
    return session.client("s3", config=client_config)
```
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get_credentials checks the unsigned flag internally and, if it is set, skips credential resolution. |
||
|
|
||
| def glob(path: str, allow_pattern: Optional[List[str]] = None, credentials: Optional[S3Credentials] = None) -> List[str]: | ||
| s3 = _build_s3_client(credentials) | ||
| if not path.endswith("/"): | ||
| path = f"{path}/" | ||
| bucket_name, _, keys = list_files(s3, | ||
|
|
@@ -33,18 +38,7 @@ def pull_files(model_path: str, | |
| allow_pattern: Optional[List[str]] = None, | ||
| ignore_pattern: Optional[List[str]] = None, | ||
| credentials: Optional[S3Credentials] = None,) -> None: | ||
| session, _ = get_credentials(credentials) | ||
| use_virtual_addressing = os.getenv("RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING", "1") | ||
|
|
||
| client_config = None | ||
| if use_virtual_addressing == "0": | ||
| client_config = Config(s3={'addressing_style': 'path'}) | ||
|
|
||
| # Pass the config to the client constructor | ||
| if session is None: | ||
| s3 = boto3.client("s3", config=client_config) | ||
| else: | ||
| s3 = session.client("s3", config=client_config) | ||
| s3 = _build_s3_client(credentials) | ||
|
|
||
| if not model_path.endswith("/"): | ||
| model_path = model_path + "/" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,25 @@ | ||
| import json | ||
| import shutil | ||
| import tempfile | ||
| import unittest | ||
| import os | ||
| import time | ||
| import boto3 | ||
| from unittest.mock import patch | ||
|
|
||
| from botocore.exceptions import NoCredentialsError, ClientError | ||
| from safetensors.torch import safe_open | ||
|
|
||
| from tests.cases.interface import ObjectStoreBackend | ||
| from tests.cases.testcases import compatibility_test_cases | ||
| from tests.safetensors.generator import create_random_safetensors | ||
| from tests.safetensors.comparison import tensor_maps_are_equal | ||
| from runai_model_streamer.safetensors_streamer.safetensors_streamer import ( | ||
| SafetensorsStreamer, | ||
| list_safetensors, | ||
| pull_files, | ||
| ) | ||
| RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR = "RUNAI_STREAMER_S3_UNSIGNED" | ||
|
|
||
|
|
||
| class MinioServer(ObjectStoreBackend): | ||
|
|
@@ -48,5 +61,82 @@ def upload_file(self, bucket, directory, file): | |
| bucket_name = os.getenv("AWS_BUCKET") | ||
| ) | ||
|
|
||
|
|
||
| class TestS3UnsignedPublicBucket(unittest.TestCase): | ||
| PUBLIC_BUCKET = "public-test-bucket" | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| cls.server = MinioServer() | ||
| cls.server.wait_for_startup() | ||
| cls.temp_dir = tempfile.mkdtemp() | ||
|
|
||
| s3_admin = boto3.client( | ||
| "s3", | ||
| endpoint_url=cls.server.url, | ||
| aws_access_key_id=cls.server.key, | ||
| aws_secret_access_key=cls.server.password, | ||
| ) | ||
| try: | ||
| s3_admin.create_bucket(Bucket=cls.PUBLIC_BUCKET) | ||
| except ClientError as e: | ||
| if e.response["Error"]["Code"] != "BucketAlreadyOwnedByYou": | ||
| raise | ||
|
|
||
| s3_admin.put_bucket_policy( | ||
| Bucket=cls.PUBLIC_BUCKET, | ||
| Policy=json.dumps({ | ||
| "Version": "2012-10-17", | ||
| "Statement": [{ | ||
| "Effect": "Allow", | ||
| "Principal": "*", | ||
| "Action": ["s3:GetObject", "s3:ListBucket"], | ||
| "Resource": [ | ||
| f"arn:aws:s3:::{cls.PUBLIC_BUCKET}", | ||
| f"arn:aws:s3:::{cls.PUBLIC_BUCKET}/*" | ||
| ] | ||
| }] | ||
| }) | ||
| ) | ||
|
|
||
| cls.file_path = create_random_safetensors(cls.temp_dir) | ||
| cls.server.upload_file(cls.PUBLIC_BUCKET, "", cls.file_path) | ||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| shutil.rmtree(cls.temp_dir) | ||
|
|
||
| def test_list_safetensors_unsigned(self): | ||
| with patch.dict(os.environ, {RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR: "1"}): | ||
| result = list_safetensors(f"s3://{self.PUBLIC_BUCKET}/") | ||
| self.assertIn(f"s3://{self.PUBLIC_BUCKET}/model.safetensors", result) | ||
|
Comment on lines
+109
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The test suite verifies the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test added |
||
|
|
||
| def test_pull_files_unsigned(self): | ||
| pull_dir = tempfile.mkdtemp() | ||
| try: | ||
| with patch.dict(os.environ, {RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR: "1"}): | ||
| pull_files(f"s3://{self.PUBLIC_BUCKET}/", pull_dir, allow_pattern=["*.safetensors"]) | ||
| self.assertIn("model.safetensors", os.listdir(pull_dir)) | ||
| finally: | ||
| shutil.rmtree(pull_dir) | ||
|
|
||
| def test_stream_file_unsigned(self): | ||
| our = {} | ||
| with patch.dict(os.environ, {RUNAI_STREAMER_S3_UNSIGNED_ENV_VAR: "1"}): | ||
| with SafetensorsStreamer() as streamer: | ||
| streamer.stream_file(f"s3://{self.PUBLIC_BUCKET}/model.safetensors", None, "cpu") | ||
| for name, tensor in streamer.get_tensors(): | ||
| our[name] = tensor | ||
|
|
||
| their = {} | ||
| with safe_open(self.file_path, framework="pt", device="cpu") as f: | ||
| for name in f.keys(): | ||
| their[name] = f.get_tensor(name) | ||
|
|
||
| equal, message = tensor_maps_are_equal(our, their) | ||
| if not equal: | ||
| self.fail(f"Tensor mismatch: {message}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Uh oh!
There was an error while loading. Please reload this page.