
Commit b8b78f1

Support unified file id (managed files) for batches (#10650)
* refactor(managed_files.py): move enterprise feature into enterprise folder to prevent unexpected surprises
* refactor: safely handle enterprise hooks
* fix: fix ruff check errors
* fix(files_endpoints.py): cleanup enterprise code from OSS
* refactor: complete cleanup
* fix(managed_files.py): complete cleanup
* fix(managed_files.py): instrument to be able to update deployment values post-router selection and just before making llm call
* fix(managed_files.py): instrument to be able to update deployment values post-router selection and just before making llm call
* fix: fix linting error
* fix: fix linting error
1 parent fcaa4a9 commit b8b78f1

17 files changed: +370 -179 lines changed

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
from typing import Dict, Literal, Type, Union
3+
4+
from litellm.integrations.custom_logger import CustomLogger
5+
6+
from .managed_files import _PROXY_LiteLLMManagedFiles
7+
from .parallel_request_limiter_v2 import _PROXY_MaxParallelRequestsHandler
8+
9+
ENTERPRISE_PROXY_HOOKS: Dict[str, Type[CustomLogger]] = {
10+
"managed_files": _PROXY_LiteLLMManagedFiles,
11+
}
12+
13+
14+
## FEATURE FLAG HOOKS ##
15+
16+
if os.getenv("EXPERIMENTAL_MULTI_INSTANCE_RATE_LIMITING", "false").lower() == "true":
17+
ENTERPRISE_PROXY_HOOKS["max_parallel_requests"] = _PROXY_MaxParallelRequestsHandler
18+
19+
20+
def get_enterprise_proxy_hook(
21+
hook_name: Union[
22+
Literal[
23+
"managed_files",
24+
"max_parallel_requests",
25+
],
26+
str,
27+
]
28+
):
29+
"""
30+
Factory method to get a enterprise hook instance by name
31+
"""
32+
if hook_name not in ENTERPRISE_PROXY_HOOKS:
33+
raise ValueError(
34+
f"Unknown hook: {hook_name}. Available hooks: {list(ENTERPRISE_PROXY_HOOKS.keys())}"
35+
)
36+
return ENTERPRISE_PROXY_HOOKS[hook_name]
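For context, a minimal sketch of how this factory can be consumed (the module path of the new file is not shown on this page, so the import below is an assumption; everything else follows the code above):

    # assumed import path for the new enterprise hooks module
    from enterprise.enterprise_hooks import get_enterprise_proxy_hook

    hook_cls = get_enterprise_proxy_hook("managed_files")  # returns the class, not an instance
    print(hook_cls.__name__)  # _PROXY_LiteLLMManagedFiles

    try:
        get_enterprise_proxy_hook("not_a_hook")
    except ValueError as err:
        print(err)  # Unknown hook: not_a_hook. Available hooks: ['managed_files']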

litellm/proxy/hooks/managed_files.py renamed to enterprise/enterprise_hooks/managed_files.py

Lines changed: 41 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
import base64
55
import json
66
import uuid
7-
from abc import ABC, abstractmethod
87
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
98

109
from litellm import Router, verbose_logger
1110
from litellm.caching.caching import DualCache
1211
from litellm.integrations.custom_logger import CustomLogger
1312
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
13+
from litellm.llms.base_llm.files.transformation import BaseFileEndpoints
1414
from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
15+
from litellm.proxy.openai_files_endpoints.common_utils import (
16+
_is_base64_encoded_unified_file_id,
17+
convert_b64_uid_to_unified_uid,
18+
)
1519
from litellm.types.llms.openai import (
1620
AllMessageValues,
1721
ChatCompletionFileObject,
@@ -36,29 +40,7 @@
3640
PrismaClient = Any
3741

3842

39-
class BaseFileEndpoints(ABC):
40-
@abstractmethod
41-
async def afile_retrieve(
42-
self,
43-
file_id: str,
44-
litellm_parent_otel_span: Optional[Span],
45-
) -> OpenAIFileObject:
46-
pass
47-
48-
@abstractmethod
49-
async def afile_list(
50-
self, custom_llm_provider: str, **data: dict
51-
) -> List[OpenAIFileObject]:
52-
pass
53-
54-
@abstractmethod
55-
async def afile_delete(
56-
self, custom_llm_provider: str, file_id: str, **data: dict
57-
) -> OpenAIFileObject:
58-
pass
59-
60-
61-
class _PROXY_LiteLLMManagedFiles(CustomLogger):
43+
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
6244
# Class variables or attributes
6345
def __init__(
6446
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
@@ -153,12 +135,14 @@ async def async_pre_call_hook(
153135
"audio_transcription",
154136
"pass_through_endpoint",
155137
"rerank",
138+
"acreate_batch",
156139
],
157140
) -> Union[Exception, str, Dict, None]:
158141
"""
159142
- Detect litellm_proxy/ file_id
160143
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
161144
"""
145+
print("REACHES async_pre_call_hook, call_type:", call_type)
162146
if call_type == CallTypes.completion.value:
163147
messages = data.get("messages")
164148
if messages:
@@ -169,9 +153,37 @@ async def async_pre_call_hook(
169153
)
170154

171155
data["model_file_id_mapping"] = model_file_id_mapping
156+
elif call_type == CallTypes.acreate_batch.value:
157+
input_file_id = cast(Optional[str], data.get("input_file_id"))
158+
if input_file_id:
159+
model_file_id_mapping = await self.get_model_file_id_mapping(
160+
[input_file_id], user_api_key_dict.parent_otel_span
161+
)
172162

163+
data["model_file_id_mapping"] = model_file_id_mapping
173164
return data
174165

166+
async def async_pre_call_deployment_hook(
167+
self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
168+
) -> Optional[dict]:
169+
"""
170+
Allow modifying the request just before it's sent to the deployment.
171+
"""
172+
if call_type and call_type == CallTypes.acreate_batch:
173+
input_file_id = cast(Optional[str], kwargs.get("input_file_id"))
174+
model_file_id_mapping = cast(
175+
Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping")
176+
)
177+
model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None))
178+
mapped_file_id: Optional[str] = None
179+
if input_file_id and model_file_id_mapping and model_id:
180+
mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get(
181+
model_id, None
182+
)
183+
if mapped_file_id:
184+
kwargs["input_file_id"] = mapped_file_id
185+
return kwargs
186+
175187
def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
176188
"""
177189
Gets file ids from messages
@@ -192,37 +204,6 @@ def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[s
192204
file_ids.append(file_id)
193205
return file_ids
194206

195-
@staticmethod
196-
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
197-
is_base64_unified_file_id = (
198-
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
199-
)
200-
if is_base64_unified_file_id:
201-
return is_base64_unified_file_id
202-
else:
203-
return b64_uid
204-
205-
@staticmethod
206-
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
207-
# Add padding back if needed
208-
padded = b64_uid + "=" * (-len(b64_uid) % 4)
209-
# Decode from base64
210-
try:
211-
decoded = base64.urlsafe_b64decode(padded).decode()
212-
if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
213-
return decoded
214-
else:
215-
return False
216-
except Exception:
217-
return False
218-
219-
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
220-
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
221-
if is_base64_unified_file_id:
222-
return is_base64_unified_file_id
223-
else:
224-
return b64_uid
225-
226207
async def get_model_file_id_mapping(
227208
self, file_ids: List[str], litellm_parent_otel_span: Span
228209
) -> dict:
@@ -247,7 +228,7 @@ async def get_model_file_id_mapping(
247228

248229
for file_id in file_ids:
249230
## CHECK IF FILE ID IS MANAGED BY LITELM
250-
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
231+
is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
251232

252233
if is_base64_unified_file_id:
253234
litellm_managed_file_ids.append(file_id)
@@ -300,6 +281,7 @@ async def acreate_file(
300281
create_file_request=create_file_request,
301282
internal_usage_cache=self.internal_usage_cache,
302283
litellm_parent_otel_span=litellm_parent_otel_span,
284+
target_model_names_list=target_model_names_list,
303285
)
304286

305287
## STORE MODEL MAPPINGS IN DB
@@ -328,14 +310,15 @@ async def return_unified_file_id(
328310
create_file_request: CreateFileRequest,
329311
internal_usage_cache: InternalUsageCache,
330312
litellm_parent_otel_span: Span,
313+
target_model_names_list: List[str],
331314
) -> OpenAIFileObject:
332315
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
333316
file_data = extract_file_data(create_file_request["file"])
334317

335318
file_type = file_data["content_type"]
336319

337320
unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
338-
file_type, str(uuid.uuid4())
321+
file_type, str(uuid.uuid4()), ",".join(target_model_names_list)
339322
)
340323

341324
# Convert to URL-safe base64 and strip padding
@@ -383,7 +366,7 @@ async def afile_delete(
383366
llm_router: Router,
384367
**data: Dict,
385368
) -> OpenAIFileObject:
386-
file_id = self.convert_b64_uid_to_unified_uid(file_id)
369+
file_id = convert_b64_uid_to_unified_uid(file_id)
387370
model_file_id_mapping = await self.get_model_file_id_mapping(
388371
[file_id], litellm_parent_otel_span
389372
)
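Taken together, the two hooks hand a per-deployment file-id mapping from request parsing to the moment the call is dispatched. A small sketch of that data flow for a batch request, using made-up ids (the real mapping is loaded from the managed-files table in get_model_file_id_mapping):

    # 1. async_pre_call_hook: the client sent a unified (base64) file id, and the hook
    #    attaches a mapping of unified id -> {deployment model_id: provider file id}
    data = {
        "input_file_id": "bGl0ZWxsbV9wcm94eS1leGFtcGxl",  # made-up unified id
        "model_file_id_mapping": {
            "bGl0ZWxsbV9wcm94eS1leGFtcGxl": {
                "deployment-uuid-1": "file-abc123",                      # e.g. OpenAI file id
                "deployment-uuid-2": "projects/p/locations/l/files/x",   # e.g. Vertex file id
            }
        },
    }

    # 2. async_pre_call_deployment_hook: after the router picks a deployment,
    #    the unified id is swapped for that deployment's provider file id
    chosen_model_id = "deployment-uuid-1"
    data["input_file_id"] = data["model_file_id_mapping"][data["input_file_id"]][chosen_model_id]
    assert data["input_file_id"] == "file-abc123"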

litellm/integrations/custom_logger.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
2222
from litellm.types.utils import (
2323
AdapterCompletionStreamWrapper,
24+
CallTypes,
2425
LLMResponseTypes,
2526
ModelResponse,
2627
ModelResponseStream,
@@ -127,6 +128,18 @@ async def async_filter_deployments(
127128
) -> List[dict]:
128129
return healthy_deployments
129130

131+
async def async_pre_call_deployment_hook(
132+
self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
133+
) -> Optional[dict]:
134+
"""
135+
Allow modifying the request just before it's sent to the deployment.
136+
137+
Use this instead of 'async_pre_call_hook' when you need to modify the request AFTER a deployment is selected, but BEFORE the request is sent.
138+
139+
Used in managed_files.py
140+
"""
141+
pass
142+
130143
async def async_pre_call_check(
131144
self, deployment: dict, parent_otel_span: Optional[Span]
132145
) -> Optional[dict]:
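A minimal sketch of overriding the new hook in a custom callback; the class name and the metadata tweak are illustrative, not part of this commit:

    from typing import Any, Dict, Optional

    from litellm.integrations.custom_logger import CustomLogger
    from litellm.types.utils import CallTypes


    class MyDeploymentHook(CustomLogger):
        async def async_pre_call_deployment_hook(
            self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
        ) -> Optional[dict]:
            # runs after the router has chosen a deployment, just before the LLM call
            if call_type == CallTypes.acreate_batch:
                kwargs.setdefault("metadata", {})["touched_by"] = "MyDeploymentHook"  # illustrative
            return kwargs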

litellm/litellm_core_utils/prompt_templates/common_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,14 +346,14 @@ def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
346346
unified_file_id = litellm_proxy:{};unified_id,{}
347347
If not a unified file id, returns 'file' as default format
348348
"""
349-
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
349+
from litellm.proxy.openai_files_endpoints.common_utils import (
350+
convert_b64_uid_to_unified_uid,
351+
)
350352

351353
if not file_id:
352354
return None
353355
try:
354-
transformed_file_id = (
355-
_PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(file_id)
356-
)
356+
transformed_file_id = convert_b64_uid_to_unified_uid(file_id)
357357
if transformed_file_id.startswith(
358358
SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
359359
):
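The helpers now imported from litellm.proxy.openai_files_endpoints.common_utils mirror the staticmethods deleted from managed_files.py above; the core of the check is a URL-safe base64 round-trip. A standalone sketch of that decode step (the prefix default here is illustrative; the real value comes from SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX):

    import base64
    from typing import Literal, Union


    def is_base64_encoded_unified_file_id(
        b64_uid: str, prefix: str = "litellm_proxy"
    ) -> Union[str, Literal[False]]:
        # restore the '=' padding that was stripped when the unified id was minted
        padded = b64_uid + "=" * (-len(b64_uid) % 4)
        try:
            decoded = base64.urlsafe_b64decode(padded).decode()
        except Exception:
            return False
        # a managed file id decodes to a string starting with the LiteLLM prefix
        return decoded if decoded.startswith(prefix) else False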

litellm/llms/base_llm/files/transformation.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from abc import abstractmethod
2-
from typing import TYPE_CHECKING, Any, List, Optional, Union
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
33

44
import httpx
55

@@ -8,17 +8,23 @@
88
CreateFileRequest,
99
OpenAICreateFileRequestOptionalParams,
1010
OpenAIFileObject,
11+
OpenAIFilesPurpose,
1112
)
1213
from litellm.types.utils import LlmProviders, ModelResponse
1314

1415
from ..chat.transformation import BaseConfig
1516

1617
if TYPE_CHECKING:
1718
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
19+
from litellm.router import Router as _Router
1820

1921
LiteLLMLoggingObj = _LiteLLMLoggingObj
22+
Span = Any
23+
Router = _Router
2024
else:
2125
LiteLLMLoggingObj = Any
26+
Span = Any
27+
Router = Any
2228

2329

2430
class BaseFilesConfig(BaseConfig):
@@ -99,3 +105,52 @@ def transform_response(
99105
raise NotImplementedError(
100106
"AudioTranscriptionConfig does not need a response transformation for audio transcription models"
101107
)
108+
109+
110+
class BaseFileEndpoints(ABC):
111+
@abstractmethod
112+
async def acreate_file(
113+
self,
114+
create_file_request: CreateFileRequest,
115+
llm_router: Router,
116+
target_model_names_list: List[str],
117+
litellm_parent_otel_span: Span,
118+
) -> OpenAIFileObject:
119+
pass
120+
121+
@abstractmethod
122+
async def afile_retrieve(
123+
self,
124+
file_id: str,
125+
litellm_parent_otel_span: Optional[Span],
126+
) -> OpenAIFileObject:
127+
pass
128+
129+
@abstractmethod
130+
async def afile_list(
131+
self,
132+
purpose: Optional[OpenAIFilesPurpose],
133+
litellm_parent_otel_span: Optional[Span],
134+
**data: Dict,
135+
) -> List[OpenAIFileObject]:
136+
pass
137+
138+
@abstractmethod
139+
async def afile_delete(
140+
self,
141+
file_id: str,
142+
litellm_parent_otel_span: Optional[Span],
143+
llm_router: Router,
144+
**data: Dict,
145+
) -> OpenAIFileObject:
146+
pass
147+
148+
@abstractmethod
149+
async def afile_content(
150+
self,
151+
file_id: str,
152+
litellm_parent_otel_span: Optional[Span],
153+
llm_router: Router,
154+
**data: Dict,
155+
) -> str:
156+
pass
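With acreate_file and afile_content added, BaseFileEndpoints now describes the full files surface the managed-files hook implements, and abc enforces that contract. A toy sketch of the guard (the stub class is hypothetical):

    from litellm.llms.base_llm.files.transformation import BaseFileEndpoints


    class IncompleteFileEndpoints(BaseFileEndpoints):  # hypothetical: implements nothing
        pass


    try:
        IncompleteFileEndpoints()  # type: ignore[abstract]
    except TypeError as err:
        # abc refuses to instantiate until all five abstract coroutines are implemented
        print(err)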

litellm/proxy/_new_secret_config.yaml

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,9 @@
11
model_list:
2-
- model_name: gpt-4o-mini-tts
2+
- model_name: "gemini-2.0-flash"
33
litellm_params:
4-
model: openai/gpt-4o-mini-tts
5-
api_key: os.environ/OPENAI_API_KEY
6-
- model_name: gpt-3.5-turbo
7-
litellm_params:
8-
model: azure/chatgpt-v-3
9-
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
10-
api_version: "2023-05-15"
11-
api_key: os.environ/AZURE_API_KEY
12-
- model_name: "gpt-4o-azure"
13-
litellm_params:
14-
model: azure/gpt-4o
15-
api_key: os.environ/AZURE_API_KEY
16-
api_base: os.environ/AZURE_API_BASE
17-
- model_name: fake-openai-endpoint
18-
litellm_params:
19-
model: openai/fake
20-
api_key: fake-key
21-
api_base: https://exampleopenaiendpoint-production.up.railway.app/
4+
model: vertex_ai/gemini-2.0-flash
5+
vertex_project: my-project-id
6+
vertex_location: us-central1
227
- model_name: "gpt-4o-mini-openai"
238
litellm_params:
249
model: gpt-4o-mini
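A hedged end-to-end sketch of what the commit enables against a proxy running this config: upload a file once and reference its unified id in a batch, letting the proxy swap in the provider-specific file id per deployment. The base URL, API key, and the target_model_names parameter are assumptions, not taken from this diff:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")  # placeholder proxy URL/key

    # Upload once through the proxy; the returned id is the unified (base64) managed file id.
    uploaded = client.files.create(
        file=open("batch_input.jsonl", "rb"),
        purpose="batch",
        extra_body={"target_model_names": "gpt-4o-mini-openai,gemini-2.0-flash"},  # assumed parameter
    )

    # Reference the unified id when creating the batch; the new hooks map it to the
    # chosen deployment's file id just before the provider call.
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print(batch.id)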
