-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy path_client.py
More file actions
314 lines (263 loc) · 11.4 KB
/
_client.py
File metadata and controls
314 lines (263 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# pylint: disable=too-many-instance-attributes
from __future__ import annotations
import sys
import uuid
from contextlib import asynccontextmanager
from typing import AsyncIterator, Literal, Protocol, Sequence, TypeVar, cast
import grpc
import grpc.aio
import httpx
from google.protobuf.message import Message
from yandex.cloud.endpoint.api_endpoint_service_pb2 import ListApiEndpointsRequest # pylint: disable=no-name-in-module
from yandex.cloud.endpoint.api_endpoint_service_pb2_grpc import ApiEndpointServiceStub
from ._auth import BaseAuth, get_auth_provider
from ._exceptions import AioRpcError, UnknownEndpointError
from ._retry import RETRY_KIND_METADATA_KEY, RetryKind, RetryPolicy
from ._types.misc import PathLike, coerce_path
from ._utils.lock import LazyLock
from ._utils.proto import service_for_ctor
class StubType(Protocol):
    """Structural type for gRPC stub classes.

    Any class that can be constructed from a single gRPC channel (either a
    sync ``grpc.Channel`` or an async ``grpc.aio.Channel``) matches this
    protocol, e.g. a generated stub such as ``ApiEndpointServiceStub``.
    """

    def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
        pass
# Type of a stub class handed to ``get_service_stub`` / ``_get_channel``.
_T = TypeVar("_T", bound=StubType)
# Type of a protobuf response message produced by the ``call_*`` / ``stream_*`` helpers.
_D = TypeVar("_D", bound=Message)
def _get_user_agent() -> str:
    """Return the SDK user-agent string.

    Format: ``yandex-cloud-ml-sdk/<sdk version> python/<major>.<minor>``.
    """
    # Imported inside the function; the pylint pragma marks this as a cyclic import.
    from . import __version__  # pylint: disable=import-outside-toplevel,cyclic-import

    major, minor = sys.version_info[:2]
    sdk_part = f'yandex-cloud-ml-sdk/{__version__}'
    python_part = f'python/{major}.{minor}'
    # NB: the parts must be separated by a space; grpc breaks if \t is used instead.
    return ' '.join((sdk_part, python_part))
class AsyncCloudClient:
    """Asynchronous client for Yandex Cloud gRPC services, plus an httpx helper.

    Responsibilities:

    * resolving a service address from the ``service_map`` override or, when
      ``endpoint`` is set, from ApiEndpointService discovery;
    * lazily opening and caching one gRPC channel per stub class;
    * building per-call metadata (retry kind, client request id, optional
      data-logging flag and auth);
    * re-raising ``grpc.aio.AioRpcError`` raised inside :meth:`get_service_stub`
      as the SDK's ``AioRpcError`` with endpoint/auth/stub context.
    """

    def __init__(
        self,
        *,
        endpoint: str | None,
        auth: BaseAuth | str | None,
        service_map: dict[str, str],
        interceptors: Sequence[grpc.aio.ClientInterceptor] | None,
        yc_profile: str | None,
        retry_policy: RetryPolicy,
        enable_server_data_logging: bool | None,
        verify: PathLike | bool | None,
    ):
        """Store configuration; no network I/O is performed here.

        :param endpoint: address queried for endpoint discovery and passed to the
            auth provider; ``None`` means every used service must be listed in
            ``service_map``.
        :param auth: auth object or string, resolved lazily via ``get_auth_provider``.
        :param service_map: explicit ``service name -> address`` overrides that take
            precedence over discovered endpoints.
        :param interceptors: user interceptors; the retry policy's interceptors are
            appended after them.
        :param yc_profile: profile name passed to ``get_auth_provider``.
        :param retry_policy: supplies the retry interceptors.
        :param enable_server_data_logging: when not ``None``, sent as the
            ``x-data-logging-enabled`` metadata value.
        :param verify: ``False`` for plaintext channels, ``True``/``None`` for TLS
            with default roots, or a path to a root-certificate file.
        """
        self._endpoint = endpoint
        self._auth = auth
        self._auth_provider: BaseAuth | None = None
        self._yc_profile = yc_profile
        self._service_map_override: dict[str, str] = service_map
        # Filled lazily from ApiEndpointService on the first discovery lookup.
        self._service_map: dict[str, str] = {}
        # User interceptors come first, followed by the retry-policy interceptors.
        self._interceptors = (
            (tuple(interceptors) if interceptors else ()) +
            retry_policy.get_interceptors()
        )
        # Caches keyed by stub class: the open channel and the address it targets.
        self._channels: dict[type[StubType], grpc.aio.Channel] = {}
        self._endpoints: dict[type[StubType], str] = {}
        self._auth_lock = LazyLock()
        self._channels_lock = LazyLock()
        self._user_agent = _get_user_agent()
        self._enable_server_data_logging = enable_server_data_logging
        # ``None`` is treated the same as ``True`` (TLS with default roots).
        self._verify = verify if verify is not None else True

    async def _init_service_map(self, timeout: float):
        """Populate ``self._service_map`` from ApiEndpointService at ``self._endpoint``.

        A dedicated channel is opened for this single request and closed afterwards.

        :raises RuntimeError: if ``self._endpoint`` is empty or ``None``.
        """
        # Validate first: there is no point building metadata for a call we cannot make.
        if not self._endpoint:
            raise RuntimeError('This method should be never called while endpoint=None')

        # Discovery is sent without auth metadata.
        metadata = await self._get_metadata(auth_required=False, timeout=timeout, retry_kind=RetryKind.SINGLE)
        channel = self._new_channel(self._endpoint)
        async with channel:
            stub = ApiEndpointServiceStub(channel)
            response = await stub.List(
                ListApiEndpointsRequest(),
                timeout=timeout,
                metadata=metadata,
            )  # type: ignore[misc]

        for endpoint in response.endpoints:
            self._service_map[endpoint.id] = endpoint.address

    async def _discover_service_endpoint(
        self,
        service_name: str,
        stub_class: type[StubType],
        timeout: float
    ) -> str:
        """Resolve the address for ``service_name``.

        Lookup order: the ``service_map`` override, then the discovered service map
        (fetched on first use; refetched while it is still empty).

        :raises UnknownEndpointError: if the service is not overridden and either
            ``endpoint`` is ``None`` or discovery did not return it.
        """
        endpoint: str | None
        # TODO: add a validation for unknown services in override
        if endpoint := self._service_map_override.get(service_name):
            return endpoint

        if self._endpoint is None:
            raise UnknownEndpointError(
                "due to `endpoint` SDK param explicitly set to `None` you need to define "
                f"{service_name!r} endpoint manually at `service_map` SDK param"
            )

        if not self._service_map:
            await self._init_service_map(timeout=timeout)

        if endpoint := self._service_map.get(service_name):
            return endpoint

        raise UnknownEndpointError(f'failed to find endpoint for {service_name=} and {stub_class=}')

    async def _get_metadata(
        self,
        *,
        auth_required: bool,
        timeout: float,
        retry_kind: RetryKind = RetryKind.NONE,
    ) -> tuple[tuple[str, str], ...]:
        """Build gRPC call metadata.

        Always contains the retry kind and a fresh ``x-client-request-id``. It also
        contains ``x-data-logging-enabled`` when that option is configured. When
        ``auth_required`` is true, the auth provider's header is appended if it
        returns one; the provider is created lazily on first use.
        """
        metadata: tuple[tuple[str, str], ...] = (
            (RETRY_KIND_METADATA_KEY, retry_kind.name),
            ('x-client-request-id', str(uuid.uuid4())),
        )
        if self._enable_server_data_logging is not None:
            enable_server_data_logging = "true" if self._enable_server_data_logging else "false"
            metadata += (
                ("x-data-logging-enabled", enable_server_data_logging),
            )

        if not auth_required:
            return metadata

        if self._auth_provider is None:
            # Re-check under the lock so concurrent callers create the provider only once.
            async with self._auth_lock():
                if self._auth_provider is None:
                    self._auth_provider = await get_auth_provider(
                        auth=self._auth,
                        endpoint=self._endpoint,
                        yc_profile=self._yc_profile
                    )

        # With self._auth=NoAuth() this returns None, which may be fine
        # for local installations and on-premises deployments.
        auth = await self._auth_provider.get_auth_metadata(client=self, timeout=timeout, lock=self._auth_lock())
        if auth:
            return metadata + (auth, )

        return metadata

    def _get_options(self) -> tuple[tuple[str, str], ...]:
        """Return gRPC channel options (currently just the primary user agent)."""
        return (
            ("grpc.primary_user_agent", self._user_agent),
        )

    def _new_channel(self, endpoint: str) -> grpc.aio.Channel:
        """Open a new async channel to ``endpoint`` according to ``self._verify``.

        ``False`` gives a plaintext channel. ``True`` gives TLS with gRPC's default
        root certificates. Any other value is treated as a path whose bytes are
        used as the root certificates.
        """
        if self._verify is False:
            return grpc.aio.insecure_channel(
                endpoint,
                interceptors=self._interceptors,
                options=self._get_options(),
            )

        if self._verify is True:
            credentials = grpc.ssl_channel_credentials()
        else:
            path = coerce_path(self._verify)
            cert = path.read_bytes()
            credentials = grpc.ssl_channel_credentials(cert)

        return grpc.aio.secure_channel(
            endpoint,
            credentials,
            interceptors=self._interceptors,
            options=self._get_options(),
        )

    async def _get_channel(
        self,
        stub_class: type[_T],
        timeout: float,
        service_name: str | None = None,
    ) -> grpc.aio.Channel:
        """Return the cached channel for ``stub_class``, creating it on first use.

        ``service_name`` defaults to ``service_for_ctor(stub_class)``. It is only
        consulted when the channel is created, because the cache is keyed by
        stub class alone.
        """
        if stub_class in self._channels:
            return self._channels[stub_class]

        async with self._channels_lock():
            # Another task may have created the channel while we waited for the lock.
            if stub_class in self._channels:
                return self._channels[stub_class]

            service_name = service_name if service_name else service_for_ctor(stub_class)
            endpoint = await self._discover_service_endpoint(
                service_name=service_name,
                stub_class=stub_class,
                timeout=timeout
            )
            self._endpoints[stub_class] = endpoint
            channel = self._channels[stub_class] = self._new_channel(endpoint)
            return channel

    @asynccontextmanager
    async def get_service_stub(
        self,
        stub_class: type[_T],
        timeout: float,
        service_name: str | None = None
    ) -> AsyncIterator[_T]:
        """Yield a ``stub_class`` instance bound to the cached channel for that class.

        A ``grpc.aio.AioRpcError`` raised inside the ``async with`` body is
        re-raised as the SDK's ``AioRpcError``, enriched with endpoint, auth and
        stub information.
        """
        # NB: right now get_service_stub is asynccontextmanager and it is unnecessary,
        # but in future if we will make some ChannelPool, it could be handy to know,
        # when "user" releases resource
        channel = await self._get_channel(stub_class, timeout, service_name=service_name)
        try:
            yield stub_class(channel)
        except grpc.aio.AioRpcError as original:
            # .with_traceback(...) from None allows to mimic
            # original exception without increasing traceback with an
            # extra info, like
            # "During handling of the above exception, another exception occurred"
            # or # "The above exception was the direct cause of the following exception"
            raise AioRpcError.from_base_rpc_error(
                original,
                endpoint=self._endpoints[stub_class],
                auth=self._auth_provider,
                stub_class=stub_class,
            ).with_traceback(original.__traceback__) from None

    async def call_service_stream(
        self,
        service: grpc.aio.UnaryStreamMultiCallable | grpc.UnaryStreamMultiCallable,
        request: Message,
        timeout: float,
        expected_type: type[_D],  # pylint: disable=unused-argument
        auth: bool = True,
        retry_kind: Literal[RetryKind.NONE, RetryKind.SINGLE, RetryKind.CONTINUATION] = RetryKind.SINGLE,
    ) -> AsyncIterator[_D]:
        """Issue a unary-stream call and yield its responses.

        The underlying RPC is cancelled when iteration stops early for any reason
        (``aclose()``/``GeneratorExit`` or an exception escaping the generator).
        ``expected_type`` is used only for static typing.
        """
        # NB: when you instantiate a stub class on an async or sync channel, you get
        # an "async" or "sync" stub, and it has the relevant methods like __aiter__
        # and such. But from the typing perspective there is no difference,
        # it is just a stub object.
        # Auto-generated stubs for grpc say that attribute stub.Service returns
        # grpc.Unary...Multicallable, not the async one, but in real life
        # we are using only async stubs in this project.
        # In an ideal world we would do something like
        # cast(grpc.aio.UnaryStreamMultiCallable, stub.Service) at the usage place,
        # but there are too many places to insert this cast, so it is done here.
        service = cast(grpc.aio.UnaryStreamMultiCallable, service)
        metadata = await self._get_metadata(auth_required=auth, timeout=timeout, retry_kind=retry_kind)
        call = service(request, metadata=metadata, timeout=timeout)
        try:
            async for response in call:
                yield cast(_D, response)
        finally:
            # cancel() is idempotent and has no effect once the RPC has terminated,
            # so this only stops calls abandoned before completion.
            call.cancel()

    async def call_service(
        self,
        service: grpc.aio.UnaryUnaryMultiCallable | grpc.UnaryUnaryMultiCallable,
        request: Message,
        timeout: float,
        expected_type: type[_D],  # pylint: disable=unused-argument
        auth: bool = True,
        retry_kind: Literal[RetryKind.NONE, RetryKind.SINGLE] = RetryKind.SINGLE,
    ) -> _D:
        """Issue a unary-unary call and return its response.

        The call is made with ``wait_for_ready=True``. ``expected_type`` is used
        only for static typing.
        """
        # See call_service_stream for why the sync/async cast is done here.
        service = cast(grpc.aio.UnaryUnaryMultiCallable, service)
        metadata = await self._get_metadata(auth_required=auth, timeout=timeout, retry_kind=retry_kind)
        result = await service(
            request,
            metadata=metadata,
            timeout=timeout,
            wait_for_ready=True,
        )
        return cast(_D, result)

    async def stream_service_stream(
        self,
        service: grpc.aio.StreamStreamMultiCallable | grpc.StreamStreamMultiCallable,
        requests: AsyncIterator[Message],
        timeout: float,
        expected_type: type[_D],  # pylint: disable=unused-argument
        auth: bool = True,
    ) -> AsyncIterator[_D]:
        """Issue a stream-stream call fed by ``requests`` and yield its responses.

        Metadata uses the default retry kind of ``_get_metadata``. Like
        :meth:`call_service_stream`, the RPC is cancelled if iteration stops early.
        """
        service = cast(grpc.aio.StreamStreamMultiCallable, service)
        metadata = await self._get_metadata(auth_required=auth, timeout=timeout)
        call = service(requests, metadata=metadata, timeout=timeout)
        try:
            async for response in call:
                yield cast(_D, response)
        finally:
            # Without this, an abandoned bidirectional call keeps running (and keeps
            # consuming ``requests``) until its timeout; cancel() is a no-op once
            # the RPC has already terminated.
            call.cancel()

    @asynccontextmanager
    async def httpx(self) -> AsyncIterator[httpx.AsyncClient]:
        """Yield an ``httpx.AsyncClient`` sharing this client's user agent and TLS settings.

        A bool ``verify`` is passed through as-is; a path is passed as a string.
        """
        headers = {'user-agent': self._user_agent}
        verify: str | bool
        if isinstance(self._verify, bool):
            verify = self._verify
        else:
            verify = str(coerce_path(self._verify))

        async with httpx.AsyncClient(headers=headers, verify=verify) as client:
            yield client