Skip to content

Commit ff68cbd

Browse files
Fix azure dall e 3 call with custom model name + Handle Bearer $LITELLM_API_KEY in x-litellm-api-key custom header (#10776)
* fix(main.py): use base model instead of user model if given. Fixes #10760
* feat(azure/image_generation/__init__.py): make azure image gen check more robust. Fixes #10760
* fix(user_api_key_auth.py): support bearer token auth for `x-litellm-api-key` header. Fixes earlier regression on vertex ai passthrough auth
* fix(user_api_key_auth.py): refactor get api key into separate function; enables easier testing
* fix: cleanup
* fix: fix linting error
* test: update tests
1 parent 53f6514 commit ff68cbd

File tree

11 files changed

+285
-143
lines changed

11 files changed

+285
-143
lines changed

litellm/llms/azure/image_generation/__init__.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from litellm._logging import verbose_logger
12
from litellm.llms.base_llm.image_generation.transformation import (
23
BaseImageGenerationConfig,
34
)
@@ -14,9 +15,15 @@
1415

1516

1617
def get_azure_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Select the Azure image-generation config for the given model name.

    The name is normalized (lowercased, dashes and underscores removed) so
    spelling variants like ``dall-e-3``, ``dalle-3`` and ``openai_dall_e_3``
    all resolve to the same config. An empty model name means DALL-E 2.
    Anything else falls through to the gpt-image-1 style config.
    """
    model = model.lower().replace("-", "").replace("_", "")
    if not model or "dalle2" in model:
        # empty model is dall-e-2
        return AzureDallE2ImageGenerationConfig()
    if "dalle3" in model:
        return AzureDallE3ImageGenerationConfig()
    verbose_logger.debug(
        f"Using AzureGPTImageGenerationConfig for model: {model}. This follows the gpt-image-1 model format."
    )
    return AzureGPTImageGenerationConfig()

litellm/main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4661,6 +4661,7 @@ def image_generation( # noqa: PLR0915
46614661
client = kwargs.get("client", None)
46624662
extra_headers = kwargs.get("extra_headers", None)
46634663
headers: dict = kwargs.get("headers", None) or {}
4664+
base_model = kwargs.get("base_model", None)
46644665
if extra_headers is not None:
46654666
headers.update(extra_headers)
46664667
model_response: ImageResponse = litellm.utils.ImageResponse()
@@ -4705,13 +4706,13 @@ def image_generation( # noqa: PLR0915
47054706
):
47064707
image_generation_config = (
47074708
ProviderConfigManager.get_provider_image_generation_config(
4708-
model=model,
4709+
model=base_model or model,
47094710
provider=LlmProviders(custom_llm_provider),
47104711
)
47114712
)
47124713

47134714
optional_params = get_optional_params_image_gen(
4714-
model=model,
4715+
model=base_model or model,
47154716
n=n,
47164717
quality=quality,
47174718
response_format=response_format,

litellm/proxy/_experimental/out/onboarding.html

-1
This file was deleted.

litellm/proxy/auth/user_api_key_auth.py

+70-23
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import asyncio
1111
import secrets
1212
from datetime import datetime, timezone
13-
from typing import Optional, cast
13+
from typing import List, Optional, Tuple, cast
1414

1515
import fastapi
1616
from fastapi import HTTPException, Request, WebSocket, status
@@ -89,6 +89,17 @@
8989
)
9090

9191

92+
def _get_bearer_token_or_received_api_key(api_key: str) -> str:
93+
if api_key.startswith("Bearer "): # ensure Bearer token passed in
94+
api_key = api_key.replace("Bearer ", "") # extract the token
95+
elif api_key.startswith("Basic "):
96+
api_key = api_key.replace("Basic ", "") # handle langfuse input
97+
elif api_key.startswith("bearer "):
98+
api_key = api_key.replace("bearer ", "")
99+
100+
return api_key
101+
102+
92103
def _get_bearer_token(
93104
api_key: str,
94105
):
@@ -217,6 +228,53 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
217228
return LitellmUserRoles.TEAM
218229

219230

231+
def get_api_key(
    custom_litellm_key_header: Optional[str],
    api_key: str,
    azure_api_key_header: Optional[str],
    anthropic_api_key_header: Optional[str],
    google_ai_studio_api_key_header: Optional[str],
    azure_apim_header: Optional[str],
    pass_through_endpoints: Optional[List[dict]],
    route: str,
    request: Request,
) -> Tuple[str, Optional[str]]:
    """
    Resolve the effective API key from the headers a request may carry.

    Precedence (first match wins):
        1. custom litellm key header (e.g. ``x-litellm-api-key``) — supports
           "Bearer "/"bearer "/"Basic " prefixes
        2. the Authorization-derived ``api_key`` — Bearer token extracted
        3. azure / anthropic / google-ai-studio / azure-apim provider headers
        4. a pass-through endpoint whose ``path`` matches ``route``, read from
           its configured ``litellm_user_api_key`` header

    Returns:
        Tuple[str, Optional[str]]: the resolved api_key and the raw
        passed-in value (before any prefix stripping), if one was found.
    """
    passed_in_key: Optional[str] = None
    if isinstance(custom_litellm_key_header, str):
        passed_in_key = custom_litellm_key_header
        api_key = _get_bearer_token_or_received_api_key(custom_litellm_key_header)
    elif isinstance(api_key, str):
        passed_in_key = api_key
        api_key = _get_bearer_token(api_key=api_key)
    elif isinstance(azure_api_key_header, str):
        passed_in_key = azure_api_key_header
        api_key = azure_api_key_header
    elif isinstance(anthropic_api_key_header, str):
        passed_in_key = anthropic_api_key_header
        api_key = anthropic_api_key_header
    elif isinstance(google_ai_studio_api_key_header, str):
        passed_in_key = google_ai_studio_api_key_header
        api_key = google_ai_studio_api_key_header
    elif isinstance(azure_apim_header, str):
        passed_in_key = azure_apim_header
        api_key = azure_apim_header
    elif pass_through_endpoints is not None:
        for endpoint in pass_through_endpoints:
            if endpoint.get("path", "") == route:
                headers: Optional[dict] = endpoint.get("headers", None)
                if headers is not None:
                    header_key: str = headers.get("litellm_user_api_key", "")
                    # NOTE(review): .get(key=...) may return None here even
                    # though api_key is typed str — preserved from original.
                    if request.headers.get(key=header_key) is not None:
                        api_key = request.headers.get(key=header_key)
                        passed_in_key = api_key
    return api_key, passed_in_key
276+
277+
220278
async def _user_api_key_auth_builder( # noqa: PLR0915
221279
request: Request,
222280
api_key: str,
@@ -260,28 +318,17 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
260318
)
261319
passed_in_key: Optional[str] = None
262320
## CHECK IF X-LITELM-API-KEY IS PASSED IN - supercedes Authorization header
263-
if isinstance(custom_litellm_key_header, str):
264-
api_key = custom_litellm_key_header
265-
elif isinstance(api_key, str):
266-
passed_in_key = api_key
267-
api_key = _get_bearer_token(api_key=api_key)
268-
elif isinstance(azure_api_key_header, str):
269-
api_key = azure_api_key_header
270-
elif isinstance(anthropic_api_key_header, str):
271-
api_key = anthropic_api_key_header
272-
elif isinstance(google_ai_studio_api_key_header, str):
273-
api_key = google_ai_studio_api_key_header
274-
elif isinstance(azure_apim_header, str):
275-
api_key = azure_apim_header
276-
elif pass_through_endpoints is not None:
277-
for endpoint in pass_through_endpoints:
278-
if endpoint.get("path", "") == route:
279-
headers: Optional[dict] = endpoint.get("headers", None)
280-
if headers is not None:
281-
header_key: str = headers.get("litellm_user_api_key", "")
282-
if request.headers.get(key=header_key) is not None:
283-
api_key = request.headers.get(key=header_key)
284-
321+
api_key, passed_in_key = get_api_key(
322+
custom_litellm_key_header=custom_litellm_key_header,
323+
api_key=api_key,
324+
azure_api_key_header=azure_api_key_header,
325+
anthropic_api_key_header=anthropic_api_key_header,
326+
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
327+
azure_apim_header=azure_apim_header,
328+
pass_through_endpoints=pass_through_endpoints,
329+
route=route,
330+
request=request,
331+
)
285332
# if user wants to pass LiteLLM_Master_Key as a custom header, example pass litellm keys as X-LiteLLM-Key: Bearer sk-1234
286333
custom_litellm_key_header_name = general_settings.get("litellm_key_header_name")
287334
if custom_litellm_key_header_name is not None:

litellm/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2271,7 +2271,7 @@ def _check_valid_arg(supported_params):
22712271
elif k not in supported_params:
22722272
raise UnsupportedParamsError(
22732273
status_code=500,
2274-
message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
2274+
message=f"Setting `{k}` is not supported by {custom_llm_provider}, {model}. To drop it from the call, set `litellm.drop_params = True`.",
22752275
)
22762276
return non_default_params
22772277

tests/image_gen_tests/test_image_generation.py

+1
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,4 @@ async def test_aimage_generation_bedrock_with_optional_params():
240240
pass
241241
else:
242242
pytest.fail(f"An exception occurred - {str(e)}")
243+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import json
2+
import os
3+
import sys
4+
import traceback
5+
from typing import Callable, Optional
6+
from unittest.mock import MagicMock, patch
7+
8+
import pytest
9+
10+
sys.path.insert(
11+
0, os.path.abspath("../../../../..")
12+
) # Adds the parent directory to the system path
13+
import litellm
14+
from litellm.llms.azure.image_generation import (
15+
AzureDallE3ImageGenerationConfig,
16+
get_azure_image_generation_config,
17+
)
18+
19+
20+
@pytest.mark.parametrize(
    "received_model, expected_config",
    [
        ("dall-e-3", AzureDallE3ImageGenerationConfig),
        ("dalle-3", AzureDallE3ImageGenerationConfig),
        ("openai_dall_e_3", AzureDallE3ImageGenerationConfig),
    ],
)
def test_azure_image_generation_config(received_model, expected_config):
    """All spelling variants of dall-e-3 resolve to the DALL-E 3 config."""
    resolved_config = get_azure_image_generation_config(received_model)
    assert isinstance(resolved_config, expected_config)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import asyncio
2+
import json
3+
import os
4+
import sys
5+
from typing import Tuple
6+
from unittest.mock import AsyncMock, MagicMock, patch
7+
8+
sys.path.insert(
9+
0, os.path.abspath("../../..")
10+
) # Adds the parent directory to the system path
11+
12+
from unittest.mock import MagicMock
13+
14+
import pytest
15+
16+
from litellm.proxy.auth.user_api_key_auth import get_api_key
17+
18+
19+
def test_get_api_key():
    """A Bearer-prefixed Authorization key is stripped; the raw value is echoed back."""
    raw_header_value = "Bearer sk-12345678"
    expected_key = "sk-12345678"
    result = get_api_key(
        custom_litellm_key_header=None,
        api_key=raw_header_value,
        azure_api_key_header=None,
        anthropic_api_key_header=None,
        google_ai_studio_api_key_header=None,
        azure_apim_header=None,
        pass_through_endpoints=None,
        route="",
        request=MagicMock(),
    )
    assert result == (expected_key, raw_header_value)
34+
35+
36+
@pytest.mark.parametrize(
    "custom_litellm_key_header, api_key, passed_in_key",
    [
        ("Bearer sk-12345678", "sk-12345678", "Bearer sk-12345678"),
        ("Basic sk-12345678", "sk-12345678", "Basic sk-12345678"),
        ("bearer sk-12345678", "sk-12345678", "bearer sk-12345678"),
        ("sk-12345678", "sk-12345678", "sk-12345678"),
    ],
)
def test_get_api_key_with_custom_litellm_key_header(
    custom_litellm_key_header, api_key, passed_in_key
):
    """The x-litellm-api-key header takes precedence and known auth prefixes are stripped."""
    resolved = get_api_key(
        custom_litellm_key_header=custom_litellm_key_header,
        api_key=None,
        azure_api_key_header=None,
        anthropic_api_key_header=None,
        google_ai_studio_api_key_header=None,
        azure_apim_header=None,
        pass_through_endpoints=None,
        route="",
        request=MagicMock(),
    )
    assert resolved == (api_key, passed_in_key)

tests/litellm/test_utils.py

-113
Original file line numberDiff line numberDiff line change
@@ -32,119 +32,6 @@ def test_get_optional_params_image_gen():
3232
assert optional_params["n"] == 3
3333

3434

35-
def return_mocked_response(model: str):
36-
if model == "bedrock/mistral.mistral-large-2407-v1:0":
37-
return {
38-
"metrics": {"latencyMs": 316},
39-
"output": {
40-
"message": {
41-
"content": [{"text": "Hello! How are you doing today? How can"}],
42-
"role": "assistant",
43-
}
44-
},
45-
"stopReason": "max_tokens",
46-
"usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15},
47-
}
48-
49-
50-
@pytest.mark.parametrize(
51-
"model",
52-
[
53-
"bedrock/mistral.mistral-large-2407-v1:0",
54-
],
55-
)
56-
@pytest.mark.asyncio()
57-
async def test_bedrock_max_completion_tokens(model: str):
58-
"""
59-
Tests that:
60-
- max_completion_tokens is passed as max_tokens to bedrock models
61-
"""
62-
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
63-
64-
litellm.set_verbose = True
65-
66-
client = AsyncHTTPHandler()
67-
68-
mock_response = return_mocked_response(model)
69-
_model = model.split("/")[1]
70-
print("\n\nmock_response: ", mock_response)
71-
72-
with patch.object(client, "post") as mock_client:
73-
try:
74-
response = await litellm.acompletion(
75-
model=model,
76-
max_completion_tokens=10,
77-
messages=[{"role": "user", "content": "Hello!"}],
78-
client=client,
79-
)
80-
except Exception as e:
81-
print(f"Error: {e}")
82-
83-
mock_client.assert_called_once()
84-
request_body = json.loads(mock_client.call_args.kwargs["data"])
85-
86-
print("request_body: ", request_body)
87-
88-
assert request_body == {
89-
"messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
90-
"additionalModelRequestFields": {},
91-
"system": [],
92-
"inferenceConfig": {"maxTokens": 10},
93-
}
94-
95-
96-
@pytest.mark.parametrize(
97-
"model",
98-
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"],
99-
)
100-
@pytest.mark.asyncio()
101-
async def test_anthropic_api_max_completion_tokens(model: str):
102-
"""
103-
Tests that:
104-
- max_completion_tokens is passed as max_tokens to anthropic models
105-
"""
106-
litellm.set_verbose = True
107-
from litellm.llms.custom_httpx.http_handler import HTTPHandler
108-
109-
mock_response = {
110-
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
111-
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
112-
"model": "claude-3-5-sonnet-20240620",
113-
"role": "assistant",
114-
"stop_reason": "end_turn",
115-
"stop_sequence": None,
116-
"type": "message",
117-
"usage": {"input_tokens": 2095, "output_tokens": 503},
118-
}
119-
120-
client = HTTPHandler()
121-
122-
print("\n\nmock_response: ", mock_response)
123-
124-
with patch.object(client, "post") as mock_client:
125-
try:
126-
response = await litellm.acompletion(
127-
model=model,
128-
max_completion_tokens=10,
129-
messages=[{"role": "user", "content": "Hello!"}],
130-
client=client,
131-
)
132-
except Exception as e:
133-
print(f"Error: {e}")
134-
mock_client.assert_called_once()
135-
request_body = mock_client.call_args.kwargs["json"]
136-
137-
print("request_body: ", request_body)
138-
139-
assert request_body == {
140-
"messages": [
141-
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
142-
],
143-
"max_tokens": 10,
144-
"model": model.split("/")[-1],
145-
}
146-
147-
14835
def test_all_model_configs():
14936
from litellm.llms.vertex_ai.vertex_ai_partner_models.ai21.transformation import (
15037
VertexAIAi21Config,

0 commit comments

Comments
 (0)