-
Notifications
You must be signed in to change notification settings - Fork 51
Support AI Gateway URLs in DatabricksOpenAI #373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
fdbd45f
22311ff
e9435f6
a3e264d
f5ef30b
b129558
61a7cc8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import os | ||
| from typing import Any, Generator | ||
| from urllib.parse import urlparse | ||
|
|
||
| from databricks.sdk import WorkspaceClient | ||
| from httpx import AsyncClient, Auth, Client, Request, Response | ||
|
|
@@ -102,6 +103,59 @@ def _fix_empty_assistant_content_in_messages(messages: Any) -> None: | |
| message["content"] = " " | ||
|
|
||
|
|
||
| def _get_ai_gateway_base_url( | ||
| http_client: Client, | ||
| host: str, | ||
| ) -> str | None: | ||
| """Check if AI Gateway V2 is enabled and return its base URL. | ||
|
|
||
| Calls GET /api/ai-gateway/v2/endpoints. If successful and endpoints exist, | ||
| extracts the ai_gateway_url from the first endpoint response. | ||
| Returns None if gateway is not available. | ||
| """ | ||
| try: | ||
| response = http_client.get(f"{host}/api/ai-gateway/v2/endpoints") | ||
| if response.status_code != 200: | ||
| return None | ||
| data = response.json() | ||
| endpoints = data.get("endpoints", []) | ||
| if not endpoints: | ||
| return None | ||
| gateway_url = endpoints[0].get("ai_gateway_url") | ||
| if not gateway_url: | ||
| return None | ||
| parsed = urlparse(gateway_url) | ||
| return f"{parsed.scheme}://{parsed.netloc}/mlflow/v1" | ||
| except Exception: | ||
| return None | ||
|
|
||
|
|
||
| def _resolve_base_url( | ||
| workspace_client: WorkspaceClient, | ||
| base_url: str | None, | ||
| use_ai_gateway: bool, | ||
| http_client: Client, | ||
| ) -> str: | ||
| """Resolve the target base URL for the OpenAI client.""" | ||
| if base_url is not None: | ||
| if _DATABRICKS_APPS_DOMAIN in base_url: | ||
| _validate_oauth_for_apps(workspace_client) | ||
| return base_url | ||
|
|
||
| # Prioritize using AI Gateway endpoints | ||
| if use_ai_gateway: | ||
| gateway_url = _get_ai_gateway_base_url(http_client, workspace_client.config.host) | ||
| if gateway_url: | ||
| return gateway_url | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Instead of falling through, let's throw an error here. If the user is specifically asking for use_ai_gateway and it's not enabled or something, we should not silently swallow it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch - i'll raise a value error here
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we say: "Please ensure AI Gateway V2 is enabled for the workspace when |
||
| raise ValueError( | ||
| "use_ai_gateway=True but AI Gateway V2 is not available for this workspace. " | ||
| "Set use_ai_gateway=False to use serving endpoints directly." | ||
| ) | ||
|
|
||
| # Fallback to using serving endpoints | ||
| return f"{workspace_client.config.host}/serving-endpoints" | ||
|
|
||
|
|
||
def _get_authorized_http_client(workspace_client: WorkspaceClient) -> Client:
    """Create a synchronous HTTP client authenticated via the workspace client.

    The workspace client's ``config.authenticate`` attribute is handed to
    ``BearerAuth``, and the resulting auth object is attached to a new
    ``Client``.
    """
    return Client(auth=BearerAuth(workspace_client.config.authenticate))
|
|
@@ -262,8 +316,10 @@ class DatabricksOpenAI(OpenAI): | |
| base_url: Optional base URL to override the default serving endpoints URL. When the URL | ||
| points to a Databricks App (contains "databricksapps"), OAuth authentication is | ||
| required. | ||
| use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route | ||
| requests through it. Defaults to False. | ||
|
|
||
| Example - Query a serving endpoint: | ||
| Example - Query a serving or AI gateway endpoint: | ||
| >>> client = DatabricksOpenAI() | ||
| >>> response = client.chat.completions.create( | ||
| ... model="databricks-meta-llama-3-1-70b-instruct", | ||
|
|
@@ -295,26 +351,21 @@ def __init__( | |
| self, | ||
| workspace_client: WorkspaceClient | None = None, | ||
| base_url: str | None = None, | ||
| use_ai_gateway: bool = False, | ||
| ): | ||
| if workspace_client is None: | ||
| workspace_client = WorkspaceClient() | ||
|
|
||
| self._workspace_client = workspace_client | ||
|
|
||
| if base_url is not None: | ||
| # Only validate OAuth for Databricks App URLs | ||
| if _DATABRICKS_APPS_DOMAIN in base_url: | ||
| _validate_oauth_for_apps(workspace_client) | ||
| target_base_url = base_url | ||
| else: | ||
| # Default: Serving endpoints | ||
| target_base_url = f"{workspace_client.config.host}/serving-endpoints" | ||
| http_client = _get_authorized_http_client(workspace_client) | ||
| target_base_url = _resolve_base_url(workspace_client, base_url, use_ai_gateway, http_client) | ||
|
|
||
| # Authentication is handled via http_client, not api_key | ||
| super().__init__( | ||
| base_url=target_base_url, | ||
| api_key=_get_openai_api_key(), | ||
| http_client=_get_authorized_http_client(workspace_client), | ||
| http_client=http_client, | ||
| ) | ||
|
|
||
| @override | ||
|
|
@@ -409,8 +460,10 @@ class AsyncDatabricksOpenAI(AsyncOpenAI): | |
| base_url: Optional base URL to override the default serving endpoints URL. When the URL | ||
| points to a Databricks App (contains "databricksapps"), OAuth authentication is | ||
| required. | ||
| use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route | ||
| requests through it. Defaults to False. | ||
|
|
||
| Example - Query a serving endpoint: | ||
| Example - Query a serving or AI gateway endpoint: | ||
| >>> client = AsyncDatabricksOpenAI() | ||
| >>> response = await client.chat.completions.create( | ||
| ... model="databricks-meta-llama-3-1-70b-instruct", | ||
|
|
@@ -442,20 +495,17 @@ def __init__( | |
| self, | ||
| workspace_client: WorkspaceClient | None = None, | ||
| base_url: str | None = None, | ||
| use_ai_gateway: bool = False, | ||
| ): | ||
| if workspace_client is None: | ||
| workspace_client = WorkspaceClient() | ||
|
|
||
| self._workspace_client = workspace_client | ||
|
|
||
| if base_url is not None: | ||
| # Only validate OAuth for Databricks App URLs | ||
| if _DATABRICKS_APPS_DOMAIN in base_url: | ||
| _validate_oauth_for_apps(workspace_client) | ||
| target_base_url = base_url | ||
| else: | ||
| # Default: Serving endpoints | ||
| target_base_url = f"{workspace_client.config.host}/serving-endpoints" | ||
| sync_http_client = _get_authorized_http_client(workspace_client) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using a sync http client here to make endpoint calls to gateway since it does a blocking http_client.get call in init |
||
| target_base_url = _resolve_base_url( | ||
| workspace_client, base_url, use_ai_gateway, sync_http_client | ||
| ) | ||
|
|
||
| # Authentication is handled via http_client, not api_key | ||
| super().__init__( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using direct http request to list ai gateway endpoints to check ai gateway status + get url so we don't assume hardcoded gateway url value here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For context, an example response from calling
https://dogfood.staging.databricks.com/ajax-api/ai-gateway/v2/endpoints looks like: