Skip to content

Commit bff6e85

Browse files
committed
support vertex_ai global endpoints for chat
1 parent b8b78f1 commit bff6e85

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

docs/my-website/docs/providers/vertex.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import TabItem from '@theme/TabItem';
1111
| Description | Vertex AI is a fully-managed AI development platform for building and using generative AI. |
1212
| Provider Route on LiteLLM | `vertex_ai/` |
1313
| Link to Provider Doc | [Vertex AI ↗](https://cloud.google.com/vertex-ai) |
14-
| Base URL | [https://{vertex_location}-aiplatform.googleapis.com/](https://{vertex_location}-aiplatform.googleapis.com/) |
14+
| Base URL | 1. Regional endpoints<br/>[https://{vertex_location}-aiplatform.googleapis.com/](https://{vertex_location}-aiplatform.googleapis.com/)<br/>2. Global endpoints (limited availability)<br/>[https://aiplatform.googleapis.com/](https://aiplatform.googleapis.com/) |
1515
| Supported Operations | [`/chat/completions`](#sample-usage), `/completions`, [`/embeddings`](#embedding-models), [`/audio/speech`](#text-to-speech-apis), [`/fine_tuning`](#fine-tuning-apis), [`/batches`](#batch-apis), [`/files`](#batch-apis), [`/images`](#image-generation-models) |
1616

1717

@@ -832,7 +832,7 @@ OR
832832

833833
You can set:
834834
- `vertex_credentials` (str) - can be a json string or filepath to your vertex ai service account.json
835-
- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.)
835+
- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.). Some models also support the `global` location; see the [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#supported_models) for the list of supported models
836836
- `vertex_project` Optional[str] - use if vertex project different from the one in vertex_credentials
837837

838838
as dynamic params for a `litellm.completion` call.

litellm/llms/vertex_ai/common_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,15 @@ def _get_vertex_url(
8484
endpoint = "generateContent"
8585
if stream is True:
8686
endpoint = "streamGenerateContent"
87-
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
87+
if vertex_location== "global":
88+
url = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}:{endpoint}?alt=sse"
89+
else:
90+
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
8891
else:
89-
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
92+
if vertex_location == "global":
93+
url = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}:{endpoint}"
94+
else:
95+
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
9096

9197
# if model is only numeric chars then it's a fine tuned gemini model
9298
# model = 4965075652664360960

tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_vertex_location_from_url,
1818
get_vertex_project_id_from_url,
1919
set_schema_property_ordering,
20+
_get_vertex_url
2021
)
2122

2223

@@ -292,3 +293,45 @@ def test_process_items_basic():
292293
}
293294
process_items(schema)
294295
assert schema["properties"]["nested"]["items"] == {"type": "object"}
296+
297+
@pytest.mark.parametrize(
298+
"stream, expected_endpoint_suffix",
299+
[
300+
(True, "streamGenerateContent?alt=sse"),
301+
(False, "generateContent"),
302+
],
303+
)
304+
def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
305+
"""
306+
Test _get_vertex_url when vertex_location is 'global' for chat mode.
307+
"""
308+
mode = "chat"
309+
model = "gemini-1.5-pro-preview-0409"
310+
vertex_project = "test-g-project"
311+
vertex_location = "global"
312+
vertex_api_version = "v1"
313+
314+
# Mock litellm.VertexGeminiConfig.get_model_for_vertex_ai_url to return model as is
315+
# as we are not testing that part here, just the URL construction
316+
with patch("litellm.VertexGeminiConfig.get_model_for_vertex_ai_url", side_effect=lambda model: model):
317+
url, endpoint = _get_vertex_url(
318+
mode=mode,
319+
model=model,
320+
stream=stream,
321+
vertex_project=vertex_project,
322+
vertex_location=vertex_location,
323+
vertex_api_version=vertex_api_version,
324+
)
325+
326+
expected_url_base = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}"
327+
328+
if stream:
329+
expected_endpoint = "streamGenerateContent"
330+
expected_url = f"{expected_url_base}:{expected_endpoint}?alt=sse"
331+
else:
332+
expected_endpoint = "generateContent"
333+
expected_url = f"{expected_url_base}:{expected_endpoint}"
334+
335+
336+
assert endpoint == expected_endpoint
337+
assert url == expected_url

0 commit comments

Comments
 (0)