Skip to content

Commit c98d740

Browse files
kristapraticostainless-app[bot]
authored andcommitted
fix(azure): azure_deployment use with realtime + non-deployment-based APIs (#2154)
* support realtime with azure_deployment * lint * use rsplit * switch approach: save copy of the original url * save azure_endpoint as it was given * docstring * format * remove unnecessary check + add test * fix for websocket_base_url * add another test
1 parent ba2a8a0 commit c98d740

File tree

3 files changed

+637
-29
lines changed

3 files changed

+637
-29
lines changed

src/openai/lib/azure.py

+56-11
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def __init__(self) -> None:
4949

5050

5151
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
52+
_azure_endpoint: httpx.URL | None
53+
_azure_deployment: str | None
54+
5255
@override
5356
def _build_request(
5457
self,
@@ -58,11 +61,29 @@ def _build_request(
5861
) -> httpx.Request:
5962
if options.url in _deployments_endpoints and is_mapping(options.json_data):
6063
model = options.json_data.get("model")
61-
if model is not None and not "/deployments" in str(self.base_url):
64+
if model is not None and "/deployments" not in str(self.base_url.path):
6265
options.url = f"/deployments/{model}{options.url}"
6366

6467
return super()._build_request(options, retries_taken=retries_taken)
6568

69+
@override
70+
def _prepare_url(self, url: str) -> httpx.URL:
71+
"""Adjust the URL if the client was configured with an Azure endpoint + deployment
72+
and the API feature being called is **not** a deployments-based endpoint
73+
(i.e. requires /deployments/deployment-name in the URL path).
74+
"""
75+
if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
76+
merge_url = httpx.URL(url)
77+
if merge_url.is_relative_url:
78+
merge_raw_path = (
79+
self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
80+
)
81+
return self._azure_endpoint.copy_with(raw_path=merge_raw_path)
82+
83+
return merge_url
84+
85+
return super()._prepare_url(url)
86+
6687

6788
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
6889
@overload
@@ -160,8 +181,8 @@ def __init__(
160181
161182
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
162183
163-
azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
164-
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
184+
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
185+
Not supported with Assistants APIs.
165186
"""
166187
if api_key is None:
167188
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
@@ -224,6 +245,8 @@ def __init__(
224245
self._api_version = api_version
225246
self._azure_ad_token = azure_ad_token
226247
self._azure_ad_token_provider = azure_ad_token_provider
248+
self._azure_deployment = azure_deployment if azure_endpoint else None
249+
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
227250

228251
@override
229252
def copy(
@@ -307,20 +330,30 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
307330

308331
return options
309332

310-
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
333+
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
311334
auth_headers = {}
312335
query = {
313336
**extra_query,
314337
"api-version": self._api_version,
315-
"deployment": model,
338+
"deployment": self._azure_deployment or model,
316339
}
317340
if self.api_key != "<missing API key>":
318341
auth_headers = {"api-key": self.api_key}
319342
else:
320343
token = self._get_azure_ad_token()
321344
if token:
322345
auth_headers = {"Authorization": f"Bearer {token}"}
323-
return query, auth_headers
346+
347+
if self.websocket_base_url is not None:
348+
base_url = httpx.URL(self.websocket_base_url)
349+
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
350+
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
351+
else:
352+
base_url = self._prepare_url("/realtime")
353+
realtime_url = base_url.copy_with(scheme="wss")
354+
355+
url = realtime_url.copy_with(params={**query})
356+
return url, auth_headers
324357

325358

326359
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@@ -422,8 +455,8 @@ def __init__(
422455
423456
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
424457
425-
azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
426-
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
458+
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
459+
Not supported with Assistants APIs.
427460
"""
428461
if api_key is None:
429462
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
@@ -486,6 +519,8 @@ def __init__(
486519
self._api_version = api_version
487520
self._azure_ad_token = azure_ad_token
488521
self._azure_ad_token_provider = azure_ad_token_provider
522+
self._azure_deployment = azure_deployment if azure_endpoint else None
523+
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
489524

490525
@override
491526
def copy(
@@ -571,17 +606,27 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp
571606

572607
return options
573608

574-
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
609+
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
575610
auth_headers = {}
576611
query = {
577612
**extra_query,
578613
"api-version": self._api_version,
579-
"deployment": model,
614+
"deployment": self._azure_deployment or model,
580615
}
581616
if self.api_key != "<missing API key>":
582617
auth_headers = {"api-key": self.api_key}
583618
else:
584619
token = await self._get_azure_ad_token()
585620
if token:
586621
auth_headers = {"Authorization": f"Bearer {token}"}
587-
return query, auth_headers
622+
623+
if self.websocket_base_url is not None:
624+
base_url = httpx.URL(self.websocket_base_url)
625+
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
626+
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
627+
else:
628+
base_url = self._prepare_url("/realtime")
629+
realtime_url = base_url.copy_with(scheme="wss")
630+
631+
url = realtime_url.copy_with(params={**query})
632+
return url, auth_headers

src/openai/resources/beta/realtime/realtime.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,15 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
324324
extra_query = self.__extra_query
325325
auth_headers = self.__client.auth_headers
326326
if is_async_azure_client(self.__client):
327-
extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
328-
329-
url = self._prepare_url().copy_with(
330-
params={
331-
**self.__client.base_url.params,
332-
"model": self.__model,
333-
**extra_query,
334-
},
335-
)
327+
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
328+
else:
329+
url = self._prepare_url().copy_with(
330+
params={
331+
**self.__client.base_url.params,
332+
"model": self.__model,
333+
**extra_query,
334+
},
335+
)
336336
log.debug("Connecting to %s", url)
337337
if self.__websocket_connection_options:
338338
log.debug("Connection options: %s", self.__websocket_connection_options)
@@ -506,15 +506,15 @@ def __enter__(self) -> RealtimeConnection:
506506
extra_query = self.__extra_query
507507
auth_headers = self.__client.auth_headers
508508
if is_azure_client(self.__client):
509-
extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
510-
511-
url = self._prepare_url().copy_with(
512-
params={
513-
**self.__client.base_url.params,
514-
"model": self.__model,
515-
**extra_query,
516-
},
517-
)
509+
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
510+
else:
511+
url = self._prepare_url().copy_with(
512+
params={
513+
**self.__client.base_url.params,
514+
"model": self.__model,
515+
**extra_query,
516+
},
517+
)
518518
log.debug("Connecting to %s", url)
519519
if self.__websocket_connection_options:
520520
log.debug("Connection options: %s", self.__websocket_connection_options)

0 commit comments

Comments
 (0)