Skip to content

Commit 07cb11c

Browse files
committed
wip
1 parent 749cbcc commit 07cb11c

File tree

9 files changed

+495
-59
lines changed

9 files changed

+495
-59
lines changed

llama_stack/providers/registry/files.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
ProviderSpec,
1212
remote_provider_spec,
1313
)
14+
from llama_stack.providers.utils.kvstore import kvstore_dependencies
1415

1516

1617
def available_providers() -> list[ProviderSpec]:
@@ -19,7 +20,7 @@ def available_providers() -> list[ProviderSpec]:
1920
api=Api.files,
2021
adapter=AdapterSpec(
2122
adapter_type="s3",
22-
pip_packages=["aioboto3"],
23+
pip_packages=["aioboto3"] + kvstore_dependencies(),
2324
module="llama_stack.providers.remote.files.object.s3",
2425
config_class="llama_stack.providers.remote.files.object.s3.config.S3FilesImplConfig",
2526
provider_data_validator="llama_stack.providers.remote.files.object.s3.S3ProviderDataValidator",

llama_stack/providers/remote/files/object/s3/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
async def get_adapter_impl(config: S3FilesImplConfig, _deps):
1111
from .s3_files import S3FilesAdapter
1212

13-
impl = S3FilesAdapter(config)
13+
impl = S3FilesAdapter(
14+
config,
15+
_deps,
16+
)
1417
await impl.initialize()
1518
return impl

llama_stack/providers/remote/files/object/s3/persistence.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ class UploadSessionInfo(BaseModel):
2121

2222
upload_id: str
2323
bucket: str
24-
key: str
24+
key: str # Original key for file reading
25+
s3_key: str # S3 key for S3 operations
2526
mime_type: str
2627
size: int
2728
url: str
@@ -31,12 +32,12 @@ class UploadSessionInfo(BaseModel):
3132
class S3FilesPersistence:
3233
def __init__(self, kvstore: KVStore):
3334
self._kvstore = kvstore
34-
self._store = None
35+
self._store: KVStore | None = None
3536

3637
async def _get_store(self) -> KVStore:
3738
"""Get the kvstore instance, initializing it if needed."""
3839
if self._store is None:
39-
self._store = await anext(self._kvstore)
40+
self._store = self._kvstore
4041
return self._store
4142

4243
async def store_upload_session(
@@ -47,6 +48,7 @@ async def store_upload_session(
4748
upload_id=session_info.id,
4849
bucket=bucket,
4950
key=key,
51+
s3_key=key,
5052
mime_type=mime_type,
5153
size=size,
5254
url=session_info.url,

llama_stack/providers/remote/files/object/s3/s3_files.py

Lines changed: 140 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414
Files,
1515
FileUploadResponse,
1616
)
17+
from llama_stack.log import get_logger
18+
from llama_stack.providers.utils.kvstore import KVStore
1719
from llama_stack.providers.utils.pagination import paginate_records
1820

19-
from .config import S3ImplConfig
21+
from .config import S3FilesImplConfig
22+
from .persistence import S3FilesPersistence
23+
24+
logger = get_logger(name=__name__, category="files")
2025

2126

2227
class S3FilesAdapter(Files):
23-
def __init__(self, config: S3ImplConfig):
28+
def __init__(self, config: S3FilesImplConfig, kvstore: KVStore):
2429
self.config = config
25-
self.session = aioboto3.Session(
26-
aws_access_key_id=config.aws_access_key_id,
27-
aws_secret_access_key=config.aws_secret_access_key,
28-
region_name=config.region_name,
29-
)
30+
self.session = aioboto3.Session()
31+
self.persistence = S3FilesPersistence(kvstore)
3032

3133
async def initialize(self):
3234
# TODO: health check?
@@ -41,8 +43,16 @@ async def create_upload_session(
4143
) -> FileUploadResponse:
4244
"""Create a presigned URL for uploading a file to S3."""
4345
try:
46+
logger.debug(
47+
"create_upload_session",
48+
{"original_key": key, "s3_key": key, "bucket": bucket, "mime_type": mime_type, "size": size},
49+
)
50+
4451
async with self.session.client(
4552
"s3",
53+
aws_access_key_id=self.config.aws_access_key_id,
54+
aws_secret_access_key=self.config.aws_secret_access_key,
55+
region_name=self.config.region_name,
4656
endpoint_url=self.config.endpoint_url,
4757
) as s3:
4858
url = await s3.generate_presigned_url(
@@ -52,47 +62,108 @@ async def create_upload_session(
5262
"Key": key,
5363
"ContentType": mime_type,
5464
},
55-
ExpiresIn=3600, # URL expires in 1 hour
65+
ExpiresIn=3600, # URL expires in 1 hour - should it be longer?
5666
)
57-
return FileUploadResponse(
67+
logger.debug("Generated presigned URL", {"url": url})
68+
69+
response = FileUploadResponse(
5870
id=f"{bucket}/{key}",
5971
url=url,
6072
offset=0,
6173
size=size,
6274
)
75+
76+
# Store the session info
77+
await self.persistence.store_upload_session(
78+
session_info=response,
79+
bucket=bucket,
80+
key=key, # Store the original key for file reading
81+
mime_type=mime_type,
82+
size=size,
83+
)
84+
85+
return response
6386
except ClientError as e:
87+
logger.error("S3 ClientError in create_upload_session", {"error": str(e)})
6488
raise Exception(f"Failed to create upload session: {str(e)}") from e
6589

6690
async def upload_content_to_session(
6791
self,
6892
upload_id: str,
6993
) -> FileResponse | None:
7094
"""Upload content to S3 using the upload session."""
71-
bucket, key = upload_id.split("/", 1)
95+
7296
try:
97+
# Get the upload session info from persistence
98+
session_info = await self.persistence.get_upload_session(upload_id)
99+
if not session_info:
100+
raise Exception(f"Upload session {upload_id} not found")
101+
102+
logger.debug(
103+
"upload_content_to_session",
104+
{
105+
"upload_id": upload_id,
106+
"bucket": session_info.bucket,
107+
"key": session_info.key,
108+
"mime_type": session_info.mime_type,
109+
"size": session_info.size,
110+
},
111+
)
112+
113+
# Read the file content
114+
with open(session_info.key, "rb") as f:
115+
content = f.read()
116+
logger.debug("Read content", {"length": len(content)})
117+
118+
# Use a single S3 client for all operations
73119
async with self.session.client(
74120
"s3",
121+
aws_access_key_id=self.config.aws_access_key_id,
122+
aws_secret_access_key=self.config.aws_secret_access_key,
123+
region_name=self.config.region_name,
75124
endpoint_url=self.config.endpoint_url,
76125
) as s3:
77-
response = await s3.head_object(Bucket=bucket, Key=key)
126+
# Upload the content
127+
await s3.put_object(
128+
Bucket=session_info.bucket, Key=session_info.key, Body=content, ContentType=session_info.mime_type
129+
)
130+
logger.debug("Upload successful")
131+
132+
# Get the file info after upload
133+
response = await s3.head_object(Bucket=session_info.bucket, Key=session_info.key)
134+
logger.debug(
135+
"File info retrieved",
136+
{
137+
"ContentType": response.get("ContentType"),
138+
"ContentLength": response["ContentLength"],
139+
"LastModified": response["LastModified"],
140+
},
141+
)
142+
143+
# Generate a presigned URL for reading
78144
url = await s3.generate_presigned_url(
79145
"get_object",
80146
Params={
81-
"Bucket": bucket,
82-
"Key": key,
147+
"Bucket": session_info.bucket,
148+
"Key": session_info.key,
83149
},
84150
ExpiresIn=3600,
85151
)
152+
86153
return FileResponse(
87-
bucket=bucket,
88-
key=key,
154+
bucket=session_info.bucket,
155+
key=session_info.key, # Use the original key to match test expectations
89156
mime_type=response.get("ContentType", "application/octet-stream"),
90157
url=url,
91158
bytes=response["ContentLength"],
92159
created_at=int(response["LastModified"].timestamp()),
93160
)
94-
except ClientError:
95-
return None
161+
except ClientError as e:
162+
logger.error("S3 ClientError in upload_content_to_session", {"error": str(e)})
163+
raise Exception(f"Failed to upload content: {str(e)}") from e
164+
finally:
165+
# Clean up the upload session
166+
await self.persistence.delete_upload_session(upload_id)
96167

97168
async def get_upload_session_info(
98169
self,
@@ -103,6 +174,9 @@ async def get_upload_session_info(
103174
try:
104175
async with self.session.client(
105176
"s3",
177+
aws_access_key_id=self.config.aws_access_key_id,
178+
aws_secret_access_key=self.config.aws_secret_access_key,
179+
region_name=self.config.region_name,
106180
endpoint_url=self.config.endpoint_url,
107181
) as s3:
108182
response = await s3.head_object(Bucket=bucket, Key=key)
@@ -132,15 +206,17 @@ async def list_all_buckets(
132206
"""List all available S3 buckets."""
133207

134208
try:
135-
async with self.session.client(
209+
response = await self.session.client(
136210
"s3",
211+
aws_access_key_id=self.config.aws_access_key_id,
212+
aws_secret_access_key=self.config.aws_secret_access_key,
213+
region_name=self.config.region_name,
137214
endpoint_url=self.config.endpoint_url,
138-
) as s3:
139-
response = await s3.list_buckets()
140-
buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
141-
# Convert BucketResponse objects to dictionaries for pagination
142-
bucket_dicts = [bucket.model_dump() for bucket in buckets]
143-
return paginate_records(bucket_dicts, page, size)
215+
).list_buckets()
216+
buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
217+
# Convert BucketResponse objects to dictionaries for pagination
218+
bucket_dicts = [bucket.model_dump() for bucket in buckets]
219+
return paginate_records(bucket_dicts, page, size)
144220
except ClientError as e:
145221
raise Exception(f"Failed to list buckets: {str(e)}") from e
146222

@@ -152,37 +228,45 @@ async def list_files_in_bucket(
152228
) -> PaginatedResponse:
153229
"""List all files in an S3 bucket."""
154230
try:
155-
async with self.session.client(
231+
response = await self.session.client(
156232
"s3",
233+
aws_access_key_id=self.config.aws_access_key_id,
234+
aws_secret_access_key=self.config.aws_secret_access_key,
235+
region_name=self.config.region_name,
157236
endpoint_url=self.config.endpoint_url,
158-
) as s3:
159-
response = await s3.list_objects_v2(Bucket=bucket)
160-
files: list[FileResponse] = []
161-
162-
for obj in response.get("Contents", []):
163-
url = await s3.generate_presigned_url(
164-
"get_object",
165-
Params={
166-
"Bucket": bucket,
167-
"Key": obj["Key"],
168-
},
169-
ExpiresIn=3600,
170-
)
237+
).list_objects_v2(Bucket=bucket)
238+
files: list[FileResponse] = []
171239

172-
files.append(
173-
FileResponse(
174-
bucket=bucket,
175-
key=obj["Key"],
176-
mime_type="application/octet-stream", # Default mime type
177-
url=url,
178-
bytes=obj["Size"],
179-
created_at=int(obj["LastModified"].timestamp()),
180-
)
240+
for obj in response.get("Contents", []):
241+
url = await self.session.client(
242+
"s3",
243+
aws_access_key_id=self.config.aws_access_key_id,
244+
aws_secret_access_key=self.config.aws_secret_access_key,
245+
region_name=self.config.region_name,
246+
endpoint_url=self.config.endpoint_url,
247+
).generate_presigned_url(
248+
"get_object",
249+
Params={
250+
"Bucket": bucket,
251+
"Key": obj["Key"],
252+
},
253+
ExpiresIn=3600,
254+
)
255+
256+
files.append(
257+
FileResponse(
258+
bucket=bucket,
259+
key=obj["Key"],
260+
mime_type="application/octet-stream", # Default mime type
261+
url=url,
262+
bytes=obj["Size"],
263+
created_at=int(obj["LastModified"].timestamp()),
181264
)
265+
)
182266

183-
# Convert FileResponse objects to dictionaries for pagination
184-
file_dicts = [file.model_dump() for file in files]
185-
return paginate_records(file_dicts, page, size)
267+
# Convert FileResponse objects to dictionaries for pagination
268+
file_dicts = [file.model_dump() for file in files]
269+
return paginate_records(file_dicts, page, size)
186270
except ClientError as e:
187271
raise Exception(f"Failed to list files in bucket: {str(e)}") from e
188272

@@ -195,6 +279,9 @@ async def get_file(
195279
try:
196280
async with self.session.client(
197281
"s3",
282+
aws_access_key_id=self.config.aws_access_key_id,
283+
aws_secret_access_key=self.config.aws_secret_access_key,
284+
region_name=self.config.region_name,
198285
endpoint_url=self.config.endpoint_url,
199286
) as s3:
200287
response = await s3.head_object(Bucket=bucket, Key=key)
@@ -227,9 +314,11 @@ async def delete_file(
227314
try:
228315
async with self.session.client(
229316
"s3",
317+
aws_access_key_id=self.config.aws_access_key_id,
318+
aws_secret_access_key=self.config.aws_secret_access_key,
319+
region_name=self.config.region_name,
230320
endpoint_url=self.config.endpoint_url,
231321
) as s3:
232-
# Delete the file
233322
await s3.delete_object(Bucket=bucket, Key=key)
234323
except ClientError as e:
235324
raise Exception(f"Failed to delete file: {str(e)}") from e

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ unit = [
6969
"chardet",
7070
"qdrant-client",
7171
"opentelemetry-exporter-otlp-proto-http",
72+
"aioboto3",
7273
]
7374
# These are the core dependencies required for running integration tests. They are shared across all
7475
# providers. If a provider requires additional dependencies, please add them to your environment

tests/integration/files/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from typing import AsyncGenerator
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from collections.abc import AsyncGenerator
28

39
import pytest
410

0 commit comments

Comments
 (0)