Skip to content

Commit 124f50b

Browse files
authored
Improve service_map SDK parameter (#130)
1 parent cae2d41 commit 124f50b

33 files changed

+4353
-2084
lines changed

src/yandex_cloud_ml_sdk/_auth.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,11 @@ async def get_auth_metadata(
5252
""":meta private:"""
5353
# NB: we are can't create lock in Auth constructor, so we a reusing lock from client.
5454
# Look at client._lock doctstring for details.
55-
pass
5655

5756
@classmethod
5857
@abstractmethod
5958
async def applicable_from_env(cls, **_: Any) -> Self | None:
6059
""":meta private:"""
61-
pass
6260

6361

6462
class NoAuth(BaseAuth):
@@ -402,7 +400,7 @@ async def _request_token(cls, timeout: float, metadata_url: str) -> str:
402400
async def get_auth_provider(
403401
*,
404402
auth: str | BaseAuth | None,
405-
endpoint: str,
403+
endpoint: str | None,
406404
yc_profile: str | None,
407405
) -> BaseAuth:
408406
"""

src/yandex_cloud_ml_sdk/_client.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from yandex.cloud.endpoint.api_endpoint_service_pb2_grpc import ApiEndpointServiceStub
1515

1616
from ._auth import BaseAuth, get_auth_provider
17-
from ._exceptions import AioRpcError
17+
from ._exceptions import AioRpcError, UnknownEndpointError
1818
from ._retry import RETRY_KIND_METADATA_KEY, RetryKind, RetryPolicy
1919
from ._types.misc import PathLike, coerce_path
2020
from ._utils.lock import LazyLock
@@ -44,7 +44,7 @@ class AsyncCloudClient:
4444
def __init__(
4545
self,
4646
*,
47-
endpoint: str,
47+
endpoint: str | None,
4848
auth: BaseAuth | str | None,
4949
service_map: dict[str, str],
5050
interceptors: Sequence[grpc.aio.ClientInterceptor] | None,
@@ -78,6 +78,10 @@ def __init__(
7878

7979
async def _init_service_map(self, timeout: float):
8080
metadata = await self._get_metadata(auth_required=False, timeout=timeout, retry_kind=RetryKind.SINGLE)
81+
82+
if not self._endpoint:
83+
raise RuntimeError('This method should be never called while endpoint=None')
84+
8185
channel = self._new_channel(self._endpoint)
8286
async with channel:
8387
stub = ApiEndpointServiceStub(channel)
@@ -89,8 +93,30 @@ async def _init_service_map(self, timeout: float):
8993
for endpoint in response.endpoints:
9094
self._service_map[endpoint.id] = endpoint.address
9195

96+
async def _discover_service_endpoint(
97+
self,
98+
service_name: str,
99+
stub_class: type[StubType],
100+
timeout: float
101+
) -> str:
102+
endpoint: str | None
92103
# TODO: add a validation for unknown services in override
93-
self._service_map.update(self._service_map_override)
104+
if endpoint := self._service_map_override.get(service_name):
105+
return endpoint
106+
107+
if self._endpoint is None:
108+
raise UnknownEndpointError(
109+
"due to `endpoint` SDK param explicitly set to `None` you need to define "
110+
f"{service_name!r} endpoint manually at `service_map` SDK param"
111+
)
112+
113+
if not self._service_map:
114+
await self._init_service_map(timeout=timeout)
115+
116+
if endpoint := self._service_map.get(service_name):
117+
return endpoint
118+
119+
raise UnknownEndpointError(f'failed to find endpoint for {service_name=} and {stub_class=}')
94120

95121
async def _get_metadata(
96122
self,
@@ -172,19 +198,12 @@ async def _get_channel(
172198
return self._channels[stub_class]
173199

174200
service_name = service_name if service_name else service_for_ctor(stub_class)
175-
if not self._service_map:
176-
await self._init_service_map(timeout=timeout)
177-
178-
if not (endpoint := self._service_map.get(service_name)):
179-
# NB: this fix will work if service_map will change ai-assistant to ai-assistants
180-
# (and retrospectively if user will stuck with this version)
181-
# and if _service_for_ctor will change ai-assistants to ai-assistant
182-
if service_name in ('ai-assistant', 'ai-assistants'):
183-
service_name = 'ai-assistant' if service_name == 'ai-assistants' else 'ai-assistants'
184-
if not (endpoint := self._service_map.get(service_name)):
185-
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')
186-
else:
187-
raise ValueError(f'failed to find endpoint for {service_name=} and {stub_class=}')
201+
202+
endpoint = await self._discover_service_endpoint(
203+
service_name=service_name,
204+
stub_class=stub_class,
205+
timeout=timeout
206+
)
188207

189208
self._endpoints[stub_class] = endpoint
190209
channel = self._channels[stub_class] = self._new_channel(endpoint)

src/yandex_cloud_ml_sdk/_exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
class YCloudMLError(Exception):
1919
pass
2020

21+
class UnknownEndpointError(YCloudMLError):
22+
pass
23+
2124

2225
class RunError(YCloudMLError):
2326
def __init__(self, code: int, message: str, details: list[Any] | None, operation_id: str):

src/yandex_cloud_ml_sdk/_sdk.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
self,
6161
*,
6262
folder_id: str,
63-
endpoint: UndefinedOr[str] = UNDEFINED,
63+
endpoint: UndefinedOr[str] | None = UNDEFINED,
6464
auth: UndefinedOr[str | BaseAuth] = UNDEFINED,
6565
retry_policy: UndefinedOr[RetryPolicy] = UNDEFINED,
6666
yc_profile: UndefinedOr[str] = UNDEFINED,
@@ -76,13 +76,15 @@ def __init__(
7676
:type folder_id: str
7777
:param endpoint: domain:port pair for Yandex Cloud API or any other
7878
grpc compatible target.
79+
In case of ``None`` passed it turns off service endpoint discovery mechanism
80+
and requires ``service_map`` to be passed.
7981
:type endpoint: str
8082
:param auth: string with API Key, IAM token or one of yandex_cloud_ml_sdk.auth objects;
8183
in case of default Undefined value, there will be a mechanism to get token
8284
from environment
8385
:type api_key | BaseAuth: str
8486
:param service_map: a way to redefine endpoints for one or more cloud subservices
85-
with a format of dict {service_name: service_address}.
87+
with a format of dict ``{"service_name": "service_address"}``.
8688
:type service_map: Dict[str, str]
8789
:param enable_server_data_logging: when passed bool, we will add
8890
`x-data-logging-enabled: <value>` to all of requests, which will
@@ -145,7 +147,7 @@ def _init_domains(self) -> None:
145147
resource = member(name=member_name, sdk=self)
146148
setattr(self, member_name, resource)
147149

148-
def _get_endpoint(self, endpoint: UndefinedOr[str]) -> str:
150+
def _get_endpoint(self, endpoint: UndefinedOr[str] | None) -> str | None:
149151
"""Retrieves the API endpoint.
150152
151153
If the endpoint is defined, it will be returned. Otherwise, it checks for

src/yandex_cloud_ml_sdk/exceptions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from ._exceptions import (
4-
AioRpcError, AsyncOperationError, DatasetValidationError, RunError, TuningError, WrongAsyncOperationStatusError,
5-
YCloudMLError
4+
AioRpcError, AsyncOperationError, DatasetValidationError, RunError, TuningError, UnknownEndpointError,
5+
WrongAsyncOperationStatusError, YCloudMLError
66
)
77

88
__all__ = [
@@ -13,4 +13,5 @@
1313
'TuningError',
1414
'WrongAsyncOperationStatusError',
1515
'YCloudMLError',
16+
'UnknownEndpointError',
1617
]

0 commit comments

Comments
 (0)