Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 69 additions & 19 deletions integrations/openai/src/databricks_openai/utils/clients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Any, Generator
from urllib.parse import urlparse

from databricks.sdk import WorkspaceClient
from httpx import AsyncClient, Auth, Client, Request, Response
Expand Down Expand Up @@ -102,6 +103,59 @@ def _fix_empty_assistant_content_in_messages(messages: Any) -> None:
message["content"] = " "


def _get_ai_gateway_base_url(
    http_client: Client,
    host: str,
) -> str | None:
    """Check if AI Gateway V2 is enabled and return its base URL.

    Issues ``GET {host}/api/ai-gateway/v2/endpoints`` with ``http_client`` and
    reads ``ai_gateway_url`` from the first listed endpoint. The gateway
    address is discovered from this listing rather than hard-coded.

    An endpoint entry in the response looks like (abbreviated)::

        {
            "endpoints": [
                {
                    "name": "databricks-bge-large-en",
                    "ai_gateway_url": "https://<id>.ai-gateway.<...>.databricks.com"
                }
            ]
        }

    Args:
        http_client: Client used to perform the GET request.
        host: Workspace host URL; a trailing slash is tolerated.

    Returns:
        ``"<scheme>://<netloc>/mlflow/v1"`` built from the advertised gateway
        URL, or ``None`` when the gateway is not available: empty host,
        non-200 response, no endpoints, missing or malformed
        ``ai_gateway_url``, or any error while calling or parsing.
    """
    if not host:
        # Nothing to query; avoid sending a request to a bogus URL.
        return None
    try:
        # Strip a trailing slash so the request path never starts with "//".
        response = http_client.get(f"{host.rstrip('/')}/api/ai-gateway/v2/endpoints")
        if response.status_code != 200:
            return None
        data = response.json()
        endpoints = data.get("endpoints", [])
        if not endpoints:
            return None
        gateway_url = endpoints[0].get("ai_gateway_url")
        if not gateway_url:
            return None
        parsed = urlparse(gateway_url)
        if not parsed.scheme or not parsed.netloc:
            # A value without scheme/host would otherwise produce the
            # unusable base URL ":///mlflow/v1".
            return None
        # Keep only scheme and authority; any path on the advertised URL is dropped.
        return f"{parsed.scheme}://{parsed.netloc}/mlflow/v1"
    except Exception:
        # Detection is deliberately best-effort: any transport or parsing
        # failure is reported to the caller as "gateway not available".
        return None


def _resolve_base_url(
    workspace_client: WorkspaceClient,
    base_url: str | None,
    use_ai_gateway: bool,
    http_client: Client,
) -> str:
    """Pick the base URL the OpenAI client should talk to.

    Precedence:
      1. An explicit ``base_url`` is returned unchanged. If it contains the
         Databricks Apps domain marker, ``_validate_oauth_for_apps`` is run
         on the workspace client first.
      2. With ``use_ai_gateway`` set, the URL detected by
         ``_get_ai_gateway_base_url`` is returned.
      3. Otherwise ``{host}/serving-endpoints`` for the workspace host.

    Raises:
        ValueError: ``use_ai_gateway`` is True but no gateway URL was detected.
    """
    if base_url is not None:
        targets_databricks_app = _DATABRICKS_APPS_DOMAIN in base_url
        if targets_databricks_app:
            _validate_oauth_for_apps(workspace_client)
        return base_url

    workspace_host = workspace_client.config.host

    if not use_ai_gateway:
        # Default route: serving endpoints on the workspace host.
        return f"{workspace_host}/serving-endpoints"

    detected_url = _get_ai_gateway_base_url(http_client, workspace_host)
    if detected_url:
        return detected_url

    # The caller explicitly opted in to the gateway, so fail loudly instead
    # of silently falling back to serving endpoints.
    raise ValueError(
        "use_ai_gateway=True but AI Gateway V2 is not available for this workspace. "
        "Set use_ai_gateway=False to use serving endpoints directly."
    )


def _get_authorized_http_client(workspace_client: WorkspaceClient) -> Client:
    """Build a synchronous HTTP client authenticated through the workspace client.

    The workspace config's ``authenticate`` callable is wrapped in
    ``BearerAuth`` and installed as the client's ``auth`` handler.
    """
    return Client(auth=BearerAuth(workspace_client.config.authenticate))
Expand Down Expand Up @@ -262,8 +316,10 @@ class DatabricksOpenAI(OpenAI):
base_url: Optional base URL to override the default serving endpoints URL. When the URL
points to a Databricks App (contains "databricksapps"), OAuth authentication is
required.
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
requests through it. Defaults to False.

Example - Query a serving endpoint:
Example - Query a serving or AI gateway endpoint:
>>> client = DatabricksOpenAI()
>>> response = client.chat.completions.create(
... model="databricks-meta-llama-3-1-70b-instruct",
Expand Down Expand Up @@ -295,26 +351,21 @@ def __init__(
self,
workspace_client: WorkspaceClient | None = None,
base_url: str | None = None,
use_ai_gateway: bool = False,
):
if workspace_client is None:
workspace_client = WorkspaceClient()

self._workspace_client = workspace_client

if base_url is not None:
# Only validate OAuth for Databricks App URLs
if _DATABRICKS_APPS_DOMAIN in base_url:
_validate_oauth_for_apps(workspace_client)
target_base_url = base_url
else:
# Default: Serving endpoints
target_base_url = f"{workspace_client.config.host}/serving-endpoints"
http_client = _get_authorized_http_client(workspace_client)
target_base_url = _resolve_base_url(workspace_client, base_url, use_ai_gateway, http_client)

# Authentication is handled via http_client, not api_key
super().__init__(
base_url=target_base_url,
api_key=_get_openai_api_key(),
http_client=_get_authorized_http_client(workspace_client),
http_client=http_client,
)

@override
Expand Down Expand Up @@ -409,8 +460,10 @@ class AsyncDatabricksOpenAI(AsyncOpenAI):
base_url: Optional base URL to override the default serving endpoints URL. When the URL
points to a Databricks App (contains "databricksapps"), OAuth authentication is
required.
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
requests through it. Defaults to False.

Example - Query a serving endpoint:
Example - Query a serving or AI gateway endpoint:
>>> client = AsyncDatabricksOpenAI()
>>> response = await client.chat.completions.create(
... model="databricks-meta-llama-3-1-70b-instruct",
Expand Down Expand Up @@ -442,20 +495,17 @@ def __init__(
self,
workspace_client: WorkspaceClient | None = None,
base_url: str | None = None,
use_ai_gateway: bool = False,
):
if workspace_client is None:
workspace_client = WorkspaceClient()

self._workspace_client = workspace_client

if base_url is not None:
# Only validate OAuth for Databricks App URLs
if _DATABRICKS_APPS_DOMAIN in base_url:
_validate_oauth_for_apps(workspace_client)
target_base_url = base_url
else:
# Default: Serving endpoints
target_base_url = f"{workspace_client.config.host}/serving-endpoints"
sync_http_client = _get_authorized_http_client(workspace_client)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a sync HTTP client here for the gateway endpoint lookup, since `__init__` makes a blocking `http_client.get` call.

target_base_url = _resolve_base_url(
workspace_client, base_url, use_ai_gateway, sync_http_client
)

# Authentication is handled via http_client, not api_key
super().__init__(
Expand Down
133 changes: 133 additions & 0 deletions integrations/openai/tests/unit_tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from databricks_openai import AsyncDatabricksOpenAI, DatabricksOpenAI
from databricks_openai.utils.clients import (
_get_ai_gateway_base_url,
_get_app_url,
_get_authorized_async_http_client,
_get_authorized_http_client,
Expand Down Expand Up @@ -680,3 +681,135 @@ def test_falls_back_to_no_token_when_unset(self):
def test_falls_back_to_no_token_when_empty_string(self):
with patch.dict("os.environ", {"OPENAI_API_KEY": ""}):
assert _get_openai_api_key() == "no-token"


def _mock_httpx_response(status_code: int, json_data: Any = None) -> MagicMock:
    """Create a mock httpx Response.

    Args:
        status_code: Value exposed as ``response.status_code``.
        json_data: Payload returned by ``response.json()``. Only ``None``
            is replaced by an empty dict; other falsy payloads such as
            ``[]`` are returned unchanged so tests can exercise them.

    Returns:
        A ``MagicMock`` with ``status_code`` set and ``json()`` configured.
    """
    response = MagicMock()
    response.status_code = status_code
    response.json.return_value = {} if json_data is None else json_data
    return response


class TestAIGatewayV2Detection:
    """Tests for _get_ai_gateway_base_url."""

    # Workspace host passed to the function under test in every case.
    _HOST = "https://test.databricks.com"

    @staticmethod
    def _client_with_response(status_code, json_data=None):
        """Return an httpx.Client mock whose ``get`` yields a canned response."""
        client = MagicMock(spec=httpx.Client)
        client.get.return_value = _mock_httpx_response(status_code, json_data)
        return client

    def test_returns_base_url_when_endpoints_exist(self):
        payload = {
            "endpoints": [
                {
                    "name": "databricks-claude-sonnet-4-6",
                    "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
                    "created_by": "Databricks",
                    "ai_gateway_url": "https://12345.ai-gateway.cloud.databricks.com",
                }
            ]
        }
        client = self._client_with_response(200, payload)

        resolved = _get_ai_gateway_base_url(client, self._HOST)

        assert resolved == "https://12345.ai-gateway.cloud.databricks.com/mlflow/v1"
        client.get.assert_called_once_with(
            "https://test.databricks.com/api/ai-gateway/v2/endpoints"
        )

    def test_returns_none_on_404(self):
        client = self._client_with_response(404)
        assert _get_ai_gateway_base_url(client, self._HOST) is None

    def test_returns_none_on_empty_endpoints(self):
        client = self._client_with_response(200, {"endpoints": []})
        assert _get_ai_gateway_base_url(client, self._HOST) is None

    def test_returns_none_on_network_exception(self):
        # The request itself fails, so no response is ever produced.
        client = MagicMock(spec=httpx.Client)
        client.get.side_effect = Exception("Connection refused")
        assert _get_ai_gateway_base_url(client, self._HOST) is None

    def test_returns_none_on_missing_ai_gateway_url(self):
        client = self._client_with_response(
            200,
            {"endpoints": [{"name": "my-endpoint"}]},
        )
        assert _get_ai_gateway_base_url(client, self._HOST) is None

    def test_parses_base_url_from_different_workspace(self):
        payload = {
            "endpoints": [
                {
                    "name": "databricks-gpt-5-2",
                    "ai_gateway_url": "https://ws-123.ai-gateway.us-east-1.cloud.databricks.com",
                }
            ]
        }
        client = self._client_with_response(200, payload)

        resolved = _get_ai_gateway_base_url(client, self._HOST)

        assert resolved == "https://ws-123.ai-gateway.us-east-1.cloud.databricks.com/mlflow/v1"


@pytest.mark.parametrize("client_cls_name", ["DatabricksOpenAI", "AsyncDatabricksOpenAI"])
class TestDatabricksOpenAIWithGateway:
    """Tests for AI Gateway V2 integration in DatabricksOpenAI and AsyncDatabricksOpenAI.

    The class-level parametrization runs every test once per client class.
    """

    # Dotted path of the detection helper that each test patches out.
    _DETECT_TARGET = "databricks_openai.utils.clients._get_ai_gateway_base_url"

    @staticmethod
    def _client_class(client_cls_name):
        """Map the parametrized class name to the client class under test."""
        if client_cls_name == "DatabricksOpenAI":
            return DatabricksOpenAI
        return AsyncDatabricksOpenAI

    def test_gateway_available_uses_gateway_url(self, client_cls_name, mock_workspace_client):
        make_client = self._client_class(client_cls_name)
        with patch(
            self._DETECT_TARGET,
            return_value="https://12345.ai-gateway.cloud.databricks.com/mlflow/v1",
        ):
            client = make_client(workspace_client=mock_workspace_client, use_ai_gateway=True)
            resolved = str(client.base_url)
            assert "ai-gateway" in resolved
            assert "12345.ai-gateway.cloud.databricks.com" in resolved

    def test_gateway_unavailable_raises_error(self, client_cls_name, mock_workspace_client):
        make_client = self._client_class(client_cls_name)
        with patch(self._DETECT_TARGET, return_value=None):
            with pytest.raises(ValueError, match="use_ai_gateway=True but AI Gateway V2"):
                make_client(workspace_client=mock_workspace_client, use_ai_gateway=True)

    def test_gateway_disabled_no_api_call(self, client_cls_name, mock_workspace_client):
        make_client = self._client_class(client_cls_name)
        with patch(self._DETECT_TARGET) as detect_mock:
            client = make_client(workspace_client=mock_workspace_client, use_ai_gateway=False)
            # With the flag off, detection must never be attempted.
            detect_mock.assert_not_called()
            assert "/serving-endpoints/" in str(client.base_url)

    def test_explicit_base_url_skips_gateway_check(self, client_cls_name, mock_workspace_client):
        make_client = self._client_class(client_cls_name)
        with patch(self._DETECT_TARGET) as detect_mock:
            client = make_client(
                workspace_client=mock_workspace_client,
                base_url="https://custom.example.com/v1",
            )
            # An explicit base_url takes precedence over gateway detection.
            detect_mock.assert_not_called()
            assert "custom.example.com" in str(client.base_url)
Loading