4
4
import base64
5
5
import json
6
6
import uuid
7
- from abc import ABC , abstractmethod
8
7
from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Union , cast
9
8
10
9
from litellm import Router , verbose_logger
11
10
from litellm .caching .caching import DualCache
12
11
from litellm .integrations .custom_logger import CustomLogger
13
12
from litellm .litellm_core_utils .prompt_templates .common_utils import extract_file_data
13
+ from litellm .llms .base_llm .files .transformation import BaseFileEndpoints
14
14
from litellm .proxy ._types import CallTypes , LiteLLM_ManagedFileTable , UserAPIKeyAuth
15
+ from litellm .proxy .openai_files_endpoints .common_utils import (
16
+ _is_base64_encoded_unified_file_id ,
17
+ convert_b64_uid_to_unified_uid ,
18
+ )
15
19
from litellm .types .llms .openai import (
16
20
AllMessageValues ,
17
21
ChatCompletionFileObject ,
36
40
PrismaClient = Any
37
41
38
42
39
class BaseFileEndpoints(ABC):
    """Abstract interface for provider-agnostic OpenAI-style file endpoints.

    Implementations provide retrieval, listing, and deletion of files that
    are managed on behalf of the proxy (e.g. litellm-managed files).
    """

    @abstractmethod
    async def afile_retrieve(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
    ) -> OpenAIFileObject:
        """Fetch a single file object by its id.

        Args:
            file_id: Identifier of the file to retrieve.
            litellm_parent_otel_span: Optional parent OTEL span for tracing.

        Returns:
            The OpenAI-compatible file object.
        """
        pass

    @abstractmethod
    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        """List files for the given provider.

        Args:
            custom_llm_provider: Provider name used to route the list call.
            **data: Provider-specific list parameters, passed through as-is.

        Returns:
            All matching OpenAI-compatible file objects.
        """
        pass

    @abstractmethod
    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        """Delete a file on the given provider.

        Args:
            custom_llm_provider: Provider name used to route the delete call.
            file_id: Identifier of the file to delete.
            **data: Provider-specific delete parameters, passed through as-is.

        Returns:
            The deleted file's OpenAI-compatible file object.
        """
        pass
61
- class _PROXY_LiteLLMManagedFiles (CustomLogger ):
43
+ class _PROXY_LiteLLMManagedFiles (CustomLogger , BaseFileEndpoints ):
62
44
# Class variables or attributes
63
45
def __init__ (
64
46
self , internal_usage_cache : InternalUsageCache , prisma_client : PrismaClient
@@ -153,12 +135,14 @@ async def async_pre_call_hook(
153
135
"audio_transcription" ,
154
136
"pass_through_endpoint" ,
155
137
"rerank" ,
138
+ "acreate_batch" ,
156
139
],
157
140
) -> Union [Exception , str , Dict , None ]:
158
141
"""
159
142
- Detect litellm_proxy/ file_id
160
143
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
161
144
"""
145
+ print ("REACHES async_pre_call_hook, call_type:" , call_type )
162
146
if call_type == CallTypes .completion .value :
163
147
messages = data .get ("messages" )
164
148
if messages :
@@ -169,9 +153,37 @@ async def async_pre_call_hook(
169
153
)
170
154
171
155
data ["model_file_id_mapping" ] = model_file_id_mapping
156
+ elif call_type == CallTypes .acreate_batch .value :
157
+ input_file_id = cast (Optional [str ], data .get ("input_file_id" ))
158
+ if input_file_id :
159
+ model_file_id_mapping = await self .get_model_file_id_mapping (
160
+ [input_file_id ], user_api_key_dict .parent_otel_span
161
+ )
172
162
163
+ data ["model_file_id_mapping" ] = model_file_id_mapping
173
164
return data
174
165
166
async def async_pre_call_deployment_hook(
    self, kwargs: Dict[str, Any], call_type: Optional["CallTypes"]
) -> Optional[dict]:
    """
    Allow modifying the request just before it's sent to the deployment.

    For batch-creation calls, swaps the unified (litellm-managed)
    ``input_file_id`` for the provider-specific file id of the deployment
    actually selected by the router, using the ``model_file_id_mapping``
    previously attached to the request.

    Args:
        kwargs: Request kwargs; may contain ``input_file_id``,
            ``model_file_id_mapping`` and ``model_info``.
        call_type: The type of call being made, if known.

    Returns:
        The (possibly mutated) kwargs dict.
    """
    if call_type and call_type == CallTypes.acreate_batch:
        input_file_id = cast(Optional[str], kwargs.get("input_file_id"))
        model_file_id_mapping = cast(
            Optional[Dict[str, Dict[str, str]]],
            kwargs.get("model_file_id_mapping"),
        )
        # `model_info` may be present but explicitly None; coalesce with `or {}`
        # so the `.get("id")` lookup never raises AttributeError.
        model_id = cast(
            Optional[str], (kwargs.get("model_info") or {}).get("id", None)
        )
        if input_file_id and model_file_id_mapping and model_id:
            mapped_file_id: Optional[str] = model_file_id_mapping.get(
                input_file_id, {}
            ).get(model_id, None)
            if mapped_file_id:
                # Replace the unified id with the provider-specific one.
                kwargs["input_file_id"] = mapped_file_id
    return kwargs
175
187
def get_file_ids_from_messages (self , messages : List [AllMessageValues ]) -> List [str ]:
176
188
"""
177
189
Gets file ids from messages
@@ -192,37 +204,6 @@ def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[s
192
204
file_ids .append (file_id )
193
205
return file_ids
194
206
195
@staticmethod
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
    """Return the decoded unified file id for ``b64_uid``.

    Falls back to returning ``b64_uid`` unchanged when it is not a
    base64-encoded litellm-managed file id.
    """
    decoded = _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(
        b64_uid
    )
    return decoded or b64_uid
205
@staticmethod
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
    """Check whether ``b64_uid`` is a base64-encoded unified file id.

    Returns:
        The decoded unified file id string when ``b64_uid`` decodes to a
        litellm-managed file id; ``False`` otherwise (including when the
        input is not valid URL-safe base64 / UTF-8).
    """
    # Restore any '=' padding stripped at encode time so the length is a
    # multiple of 4, as required by the base64 decoder.
    padding = "=" * (-len(b64_uid) % 4)
    try:
        candidate = base64.urlsafe_b64decode(b64_uid + padding).decode()
        if candidate.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
            return candidate
    except Exception:
        return False
    return False
219
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
    """Decode ``b64_uid`` to its unified file id when possible.

    Returns the decoded litellm-managed file id, or the input unchanged
    when it is not a base64-encoded unified file id.
    """
    decoded = self._is_base64_encoded_unified_file_id(b64_uid)
    return decoded if decoded else b64_uid
226
207
async def get_model_file_id_mapping (
227
208
self , file_ids : List [str ], litellm_parent_otel_span : Span
228
209
) -> dict :
@@ -247,7 +228,7 @@ async def get_model_file_id_mapping(
247
228
248
229
for file_id in file_ids :
249
230
## CHECK IF FILE ID IS MANAGED BY LITELM
250
- is_base64_unified_file_id = self . _is_base64_encoded_unified_file_id (file_id )
231
+ is_base64_unified_file_id = _is_base64_encoded_unified_file_id (file_id )
251
232
252
233
if is_base64_unified_file_id :
253
234
litellm_managed_file_ids .append (file_id )
@@ -300,6 +281,7 @@ async def acreate_file(
300
281
create_file_request = create_file_request ,
301
282
internal_usage_cache = self .internal_usage_cache ,
302
283
litellm_parent_otel_span = litellm_parent_otel_span ,
284
+ target_model_names_list = target_model_names_list ,
303
285
)
304
286
305
287
## STORE MODEL MAPPINGS IN DB
@@ -328,14 +310,15 @@ async def return_unified_file_id(
328
310
create_file_request : CreateFileRequest ,
329
311
internal_usage_cache : InternalUsageCache ,
330
312
litellm_parent_otel_span : Span ,
313
+ target_model_names_list : List [str ],
331
314
) -> OpenAIFileObject :
332
315
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
333
316
file_data = extract_file_data (create_file_request ["file" ])
334
317
335
318
file_type = file_data ["content_type" ]
336
319
337
320
unified_file_id = SpecialEnums .LITELLM_MANAGED_FILE_COMPLETE_STR .value .format (
338
- file_type , str (uuid .uuid4 ())
321
+ file_type , str (uuid .uuid4 ()), "," . join ( target_model_names_list )
339
322
)
340
323
341
324
# Convert to URL-safe base64 and strip padding
@@ -383,7 +366,7 @@ async def afile_delete(
383
366
llm_router : Router ,
384
367
** data : Dict ,
385
368
) -> OpenAIFileObject :
386
- file_id = self . convert_b64_uid_to_unified_uid (file_id )
369
+ file_id = convert_b64_uid_to_unified_uid (file_id )
387
370
model_file_id_mapping = await self .get_model_file_id_mapping (
388
371
[file_id ], litellm_parent_otel_span
389
372
)
0 commit comments