     Files,
     FileUploadResponse,
 )
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.pagination import paginate_records
 
-from .config import S3ImplConfig
+from .config import S3FilesImplConfig
+from .persistence import S3FilesPersistence
+
+logger = get_logger(name=__name__, category="files")
 
 
 class S3FilesAdapter(Files):
-    def __init__(self, config: S3ImplConfig):
+    def __init__(self, config: S3FilesImplConfig, kvstore: KVStore):
         self.config = config
-        self.session = aioboto3.Session(
-            aws_access_key_id=config.aws_access_key_id,
-            aws_secret_access_key=config.aws_secret_access_key,
-            region_name=config.region_name,
-        )
+        self.session = aioboto3.Session()
+        self.persistence = S3FilesPersistence(kvstore)
 
     async def initialize(self):
         # TODO: health check?
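The new .persistence module is not shown in this diff. From the calls the adapter makes (store_upload_session, get_upload_session, delete_upload_session) and the attributes it reads (session_info.bucket, .key, .mime_type, .size), a minimal sketch of what it might look like, assuming the KVStore protocol exposes async get/set/delete and that FileUploadResponse comes from llama_stack.apis.files:

# Hypothetical sketch of the persistence interface implied by the calls in
# this diff; the actual .persistence module may differ.
from pydantic import BaseModel

from llama_stack.apis.files import FileUploadResponse
from llama_stack.providers.utils.kvstore import KVStore


class UploadSessionInfo(BaseModel):
    session_info: FileUploadResponse
    bucket: str
    key: str
    mime_type: str
    size: int


class S3FilesPersistence:
    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    async def store_upload_session(
        self, session_info: FileUploadResponse, bucket: str, key: str, mime_type: str, size: int
    ) -> None:
        # Key the record by the upload id so upload_content_to_session can find it.
        record = UploadSessionInfo(
            session_info=session_info, bucket=bucket, key=key, mime_type=mime_type, size=size
        )
        await self.kvstore.set(f"upload_session:{session_info.id}", record.model_dump_json())

    async def get_upload_session(self, upload_id: str) -> UploadSessionInfo | None:
        value = await self.kvstore.get(f"upload_session:{upload_id}")
        return UploadSessionInfo.model_validate_json(value) if value else None

    async def delete_upload_session(self, upload_id: str) -> None:
        await self.kvstore.delete(f"upload_session:{upload_id}")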
@@ -41,8 +43,16 @@ async def create_upload_session(
     ) -> FileUploadResponse:
         """Create a presigned URL for uploading a file to S3."""
         try:
+            logger.debug(
+                "create_upload_session",
+                {"original_key": key, "s3_key": key, "bucket": bucket, "mime_type": mime_type, "size": size},
+            )
+
             async with self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
             ) as s3:
                 url = await s3.generate_presigned_url(
@@ -52,47 +62,108 @@ async def create_upload_session(
                         "Key": key,
                         "ContentType": mime_type,
                     },
-                    ExpiresIn=3600,  # URL expires in 1 hour
+                    ExpiresIn=3600,  # URL expires in 1 hour - should it be longer?
                 )
-            return FileUploadResponse(
+            logger.debug("Generated presigned URL", {"url": url})
+
+            response = FileUploadResponse(
                 id=f"{bucket}/{key}",
                 url=url,
                 offset=0,
                 size=size,
             )
+
+            # Store the session info
+            await self.persistence.store_upload_session(
+                session_info=response,
+                bucket=bucket,
+                key=key,  # Store the original key for file reading
+                mime_type=mime_type,
+                size=size,
+            )
+
+            return response
         except ClientError as e:
+            logger.error("S3 ClientError in create_upload_session", {"error": str(e)})
             raise Exception(f"Failed to create upload session: {str(e)}") from e
 
     async def upload_content_to_session(
         self,
         upload_id: str,
     ) -> FileResponse | None:
         """Upload content to S3 using the upload session."""
-        bucket, key = upload_id.split("/", 1)
+
         try:
+            # Get the upload session info from persistence
+            session_info = await self.persistence.get_upload_session(upload_id)
+            if not session_info:
+                raise Exception(f"Upload session {upload_id} not found")
+
+            logger.debug(
+                "upload_content_to_session",
+                {
+                    "upload_id": upload_id,
+                    "bucket": session_info.bucket,
+                    "key": session_info.key,
+                    "mime_type": session_info.mime_type,
+                    "size": session_info.size,
+                },
+            )
+
+            # Read the file content
+            with open(session_info.key, "rb") as f:
+                content = f.read()
+                logger.debug("Read content", {"length": len(content)})
+
+            # Use a single S3 client for all operations
             async with self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
             ) as s3:
-                response = await s3.head_object(Bucket=bucket, Key=key)
+                # Upload the content
+                await s3.put_object(
+                    Bucket=session_info.bucket, Key=session_info.key, Body=content, ContentType=session_info.mime_type
+                )
+                logger.debug("Upload successful")
+
+                # Get the file info after upload
+                response = await s3.head_object(Bucket=session_info.bucket, Key=session_info.key)
+                logger.debug(
+                    "File info retrieved",
+                    {
+                        "ContentType": response.get("ContentType"),
+                        "ContentLength": response["ContentLength"],
+                        "LastModified": response["LastModified"],
+                    },
+                )
+
+                # Generate a presigned URL for reading
                 url = await s3.generate_presigned_url(
                     "get_object",
                     Params={
-                        "Bucket": bucket,
-                        "Key": key,
+                        "Bucket": session_info.bucket,
+                        "Key": session_info.key,
                     },
                     ExpiresIn=3600,
                 )
+
                 return FileResponse(
-                    bucket=bucket,
-                    key=key,
+                    bucket=session_info.bucket,
+                    key=session_info.key,  # Use the original key to match test expectations
                     mime_type=response.get("ContentType", "application/octet-stream"),
                     url=url,
                     bytes=response["ContentLength"],
                     created_at=int(response["LastModified"].timestamp()),
                 )
-        except ClientError:
-            return None
+        except ClientError as e:
+            logger.error("S3 ClientError in upload_content_to_session", {"error": str(e)})
+            raise Exception(f"Failed to upload content: {str(e)}") from e
+        finally:
+            # Clean up the upload session
+            await self.persistence.delete_upload_session(upload_id)
 
     async def get_upload_session_info(
         self,
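Worth noting: upload_content_to_session reads the bytes from local disk at session_info.key and uploads them server-side via put_object. The presigned URL returned by create_upload_session also supports the usual client-driven flow, where the caller PUTs directly to S3. A minimal sketch of that client side, assuming the requests library and an upload variable holding the returned FileUploadResponse:

# Client-side use of the presigned URL (sketch only; `upload` is the
# FileUploadResponse returned by create_upload_session).
import requests

def put_via_presigned_url(upload, path: str, mime_type: str) -> None:
    with open(path, "rb") as f:
        resp = requests.put(
            upload.url,
            data=f,
            # Must match the ContentType the URL was signed with, or S3
            # rejects the request with a signature mismatch.
            headers={"Content-Type": mime_type},
            timeout=60,
        )
    resp.raise_for_status()  # S3 returns 200 on a successful presigned PUT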
@@ -103,6 +174,9 @@ async def get_upload_session_info(
         try:
             async with self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
             ) as s3:
                 response = await s3.head_object(Bucket=bucket, Key=key)
@@ -132,15 +206,17 @@ async def list_all_buckets(
         """List all available S3 buckets."""
 
         try:
-            async with self.session.client(
+            response = await self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
-            ) as s3:
-                response = await s3.list_buckets()
-                buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
-                # Convert BucketResponse objects to dictionaries for pagination
-                bucket_dicts = [bucket.model_dump() for bucket in buckets]
-                return paginate_records(bucket_dicts, page, size)
+            ).list_buckets()
+            buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
+            # Convert BucketResponse objects to dictionaries for pagination
+            bucket_dicts = [bucket.model_dump() for bucket in buckets]
+            return paginate_records(bucket_dicts, page, size)
         except ClientError as e:
             raise Exception(f"Failed to list buckets: {str(e)}") from e
 
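One caution on this hunk: in aioboto3, session.client(...) returns an async context manager, not a ready client, so chaining .list_buckets() onto it directly will fail at runtime even though the equivalent chained call works in synchronous boto3. A sketch of a variant that keeps the async with pattern without repeating the credential kwargs in every method (the _client helper name is invented here, not part of the diff):

# Sketch: hypothetical `_client` helper on S3FilesAdapter so each method can
# do `async with self._client() as s3:`; aioboto3's client() must be entered
# as an async context manager before any API call is made on it.
def _client(self):
    return self.session.client(
        "s3",
        aws_access_key_id=self.config.aws_access_key_id,
        aws_secret_access_key=self.config.aws_secret_access_key,
        region_name=self.config.region_name,
        endpoint_url=self.config.endpoint_url,
    )

async def list_all_buckets(self, page: int | None = None, size: int | None = None):
    async with self._client() as s3:
        response = await s3.list_buckets()
    buckets = [BucketResponse(name=b["Name"]) for b in response["Buckets"]]
    return paginate_records([b.model_dump() for b in buckets], page, size)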
@@ -152,37 +228,45 @@ async def list_files_in_bucket(
     ) -> PaginatedResponse:
         """List all files in an S3 bucket."""
         try:
-            async with self.session.client(
+            response = await self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
-            ) as s3:
-                response = await s3.list_objects_v2(Bucket=bucket)
-                files: list[FileResponse] = []
-
-                for obj in response.get("Contents", []):
-                    url = await s3.generate_presigned_url(
-                        "get_object",
-                        Params={
-                            "Bucket": bucket,
-                            "Key": obj["Key"],
-                        },
-                        ExpiresIn=3600,
-                    )
+            ).list_objects_v2(Bucket=bucket)
+            files: list[FileResponse] = []
 
-                    files.append(
-                        FileResponse(
-                            bucket=bucket,
-                            key=obj["Key"],
-                            mime_type="application/octet-stream",  # Default mime type
-                            url=url,
-                            bytes=obj["Size"],
-                            created_at=int(obj["LastModified"].timestamp()),
-                        )
+            for obj in response.get("Contents", []):
+                url = await self.session.client(
+                    "s3",
+                    aws_access_key_id=self.config.aws_access_key_id,
+                    aws_secret_access_key=self.config.aws_secret_access_key,
+                    region_name=self.config.region_name,
+                    endpoint_url=self.config.endpoint_url,
+                ).generate_presigned_url(
+                    "get_object",
+                    Params={
+                        "Bucket": bucket,
+                        "Key": obj["Key"],
+                    },
+                    ExpiresIn=3600,
+                )
+
+                files.append(
+                    FileResponse(
+                        bucket=bucket,
+                        key=obj["Key"],
+                        mime_type="application/octet-stream",  # Default mime type
+                        url=url,
+                        bytes=obj["Size"],
+                        created_at=int(obj["LastModified"].timestamp()),
                     )
+                )
 
-                # Convert FileResponse objects to dictionaries for pagination
-                file_dicts = [file.model_dump() for file in files]
-                return paginate_records(file_dicts, page, size)
+            # Convert FileResponse objects to dictionaries for pagination
+            file_dicts = [file.model_dump() for file in files]
+            return paginate_records(file_dicts, page, size)
         except ClientError as e:
             raise Exception(f"Failed to list files in bucket: {str(e)}") from e
 
@@ -195,6 +279,9 @@ async def get_file(
         try:
             async with self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
             ) as s3:
                 response = await s3.head_object(Bucket=bucket, Key=key)
@@ -227,9 +314,11 @@ async def delete_file(
         try:
             async with self.session.client(
                 "s3",
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region_name,
                 endpoint_url=self.config.endpoint_url,
             ) as s3:
-                # Delete the file
                 await s3.delete_object(Bucket=bucket, Key=key)
         except ClientError as e:
             raise Exception(f"Failed to delete file: {str(e)}") from e
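Taken together, the new flow is: create_upload_session generates the presigned URL and persists the session in the KVStore; upload_content_to_session later looks the session up, uploads the bytes, and removes the session record in its finally block. A rough end-to-end sketch; the create_upload_session keyword names are inferred from how the diff body uses bucket, key, mime_type, and size, and the config/kvstore wiring is assumed to come from the stack:

# Sketch of the two-step flow; `config` and `kvstore` construction is assumed
# to be handled by the stack's provider wiring.
async def demo(config, kvstore):
    adapter = S3FilesAdapter(config, kvstore)
    await adapter.initialize()

    # Step 1: a presigned URL is returned and the session is stored in the KVStore.
    upload = await adapter.create_upload_session(
        bucket="my-bucket", key="docs/report.pdf", mime_type="application/pdf", size=1024
    )
    print(upload.id)  # "my-bucket/docs/report.pdf"

    # Step 2: the adapter reads the file at the stored key, puts it to S3,
    # and deletes the session record in its `finally` block.
    file_info = await adapter.upload_content_to_session(upload.id)
    print(file_info.bytes, file_info.url)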