Skip to content

Commit 69f0572

Browse files
authored
Add docstrings for _auth.py
The first version is ready for review, thanks)
1 parent 9b86538 commit 69f0572

1 file changed

Lines changed: 153 additions & 0 deletions

File tree

src/yandex_cloud_ml_sdk/_auth.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
CreateIamTokenRequest, CreateIamTokenResponse
1919
)
2020
from yandex.cloud.iam.v1.iam_token_service_pb2_grpc import IamTokenServiceStub
21+
from yandex_cloud_ml_sdk._utils.doc import doc_from
2122

2223
if TYPE_CHECKING:
2324
from ._client import AsyncCloudClient
@@ -36,50 +37,82 @@
3637

3738

3839
class BaseAuth(ABC):
40+
"""Abstract base class for authentication methods.
41+
42+
This class defines the interface for obtaining authentication metadata
43+
and checking if the authentication method is applicable from environment
44+
variables.
45+
"""
3946
@abstractmethod
4047
async def get_auth_metadata(
4148
self,
4249
client: AsyncCloudClient,
4350
timeout: float,
4451
lock: asyncio.Lock
4552
) -> tuple[str, str] | None:
53+
"""Get authentication metadata.
54+
55+
:param client: the asynchronous cloud client to use.
56+
:param timeout: timeout, or the maximum time to wait for the request to complete in seconds.
57+
:param lock: an asyncio lock to ensure thread safety.
58+
59+
.. note::
60+
The lock is reused from the client, as it cannot be created in the Auth constructor.
61+
See the client's _lock docstring for details.
62+
"""
4663
# NB: we are can't create lock in Auth constructor, so we a reusing lock from client.
4764
# Look at client._lock doctstring for details.
4865
pass
4966

5067
@classmethod
5168
@abstractmethod
5269
async def applicable_from_env(cls, **_: Any) -> Self | None:
70+
"""Check if this authentication method is applicable from environment variables.
71+
Return an instance of the authentication class if applicable, or None.
72+
"""
5373
pass
5474

5575

5676
class NoAuth(BaseAuth):
5777
@override
78+
@doc_from(BaseAuth.get_auth_metadata)
5879
async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock: asyncio.Lock) -> None:
5980
return None
6081

6182
@override
6283
@classmethod
84+
@doc_from(BaseAuth.applicable_from_env)
6385
async def applicable_from_env(cls, **_: Any) -> None:
6486
return None
6587

6688

6789
class APIKeyAuth(BaseAuth):
90+
"""Authentication method using an API key."""
6891
env_var = 'YC_API_KEY'
6992

7093
def __init__(self, api_key: str):
94+
"""Initialize with an API key.
95+
96+
:param api_key: The API key to use for authentication.
97+
98+
.. note::
99+
If the credential contains a newline character, it may lead to
100+
a GRPC_CALL_ERROR_INVALID_METADATA error which can be difficult to debug.
101+
"""
71102
# NB: here and below:
72103
# if credential with an \n will get into the grpc metadata,
73104
# user will get very interesting GRPC_CALL_ERROR_INVALID_METADATA error
74105
# which very funny to debug
75106
self._api_key = api_key.strip()
76107

77108
@override
109+
@doc_from(BaseAuth.get_auth_metadata)
78110
async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock: asyncio.Lock) -> tuple[str, str]:
79111
return ('authorization', f'Api-Key {self._api_key}')
80112

81113
@override
82114
@classmethod
115+
@doc_from(BaseAuth.applicable_from_env)
83116
async def applicable_from_env(cls, **_: Any) -> Self | None:
84117
api_key = os.getenv(cls.env_var)
85118
if api_key:
@@ -89,22 +122,34 @@ async def applicable_from_env(cls, **_: Any) -> Self | None:
89122

90123

91124
class BaseIAMTokenAuth(BaseAuth):
125+
"""Base class for IAM token-based authentication."""
92126
def __init__(self, token: str | None):
127+
"""Initialize with an IAM token.
128+
129+
:param token: The IAM token to use for authentication. If None, it will be set to None.
130+
"""
93131
self._token = token.strip() if token else token
94132

95133
@override
134+
@doc_from(BaseAuth.get_auth_metadata)
96135
async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock: asyncio.Lock) -> tuple[str, str]:
97136
return ('authorization', f'Bearer {self._token}')
98137

99138

100139
class IAMTokenAuth(BaseIAMTokenAuth):
140+
"""Authentication method using an IAM token."""
101141
env_var = 'YC_IAM_TOKEN'
102142

103143
def __init__(self, token: str):
144+
"""Initialize with an IAM token.
145+
146+
:param token: The IAM token to use for authentication.
147+
"""
104148
super().__init__(token)
105149

106150
@override
107151
@classmethod
152+
@doc_from(BaseAuth.applicable_from_env)
108153
async def applicable_from_env(cls, **_: Any) -> Self | None:
109154
token = os.getenv(cls.env_var)
110155
if token:
@@ -129,16 +174,24 @@ class EnvIAMTokenAuth(BaseIAMTokenAuth):
129174
default_env_var = 'YC_TOKEN'
130175

131176
def __init__(self, env_var_name: str | None = None):
177+
"""
178+
Initializes the authentication method with the specified environment variable name.
179+
180+
If no environment variable name is provided, the default environment variable
181+
(YC_TOKEN) is used.
182+
"""
132183
self._env_var = env_var_name or self.default_env_var
133184
super().__init__(token=None)
134185

135186
@override
187+
@doc_from(BaseAuth.get_auth_metadata)
136188
async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock: asyncio.Lock) -> tuple[str, str]:
137189
self._token = os.environ[self._env_var].strip()
138190
return await super().get_auth_metadata(client=client, timeout=timeout, lock=lock)
139191

140192
@override
141193
@classmethod
194+
@doc_from(BaseAuth.applicable_from_env)
142195
async def applicable_from_env(cls, **_: Any) -> Self | None:
143196
token = os.getenv(cls.default_env_var)
144197
if token:
@@ -148,22 +201,43 @@ async def applicable_from_env(cls, **_: Any) -> Self | None:
148201

149202

150203
class RefresheableIAMTokenAuth(BaseIAMTokenAuth):
204+
"""
205+
Auth method that supports refreshing the IAM token based on a defined refresh period.
206+
207+
This class manages an IAM token that can be refreshed automatically if it has expired,
208+
based on the specified refresh period.
209+
"""
151210
_token_refresh_period = 60 * 60
152211

153212
def __init__(self, token: str | None) -> None:
213+
"""
214+
Initializes the authentication method with the provided token.
215+
216+
Records the issue time of the token if it is provided.
217+
"""
154218
super().__init__(token)
155219
self._issue_time: float | None = None
156220
if self._token is not None:
157221
self._issue_time = time.time()
158222

159223
def _need_for_token(self):
224+
"""
225+
Determines whether a new token is needed based on the current token's status.
226+
227+
A new token is required if:
228+
229+
- the current token is None;
230+
- the issue time is None;
231+
- the time elapsed since the issue time exceeds the token refresh period.
232+
"""
160233
return (
161234
self._token is None or
162235
self._issue_time is None or
163236
time.time() - self._issue_time > self._token_refresh_period
164237
)
165238

166239
@override
240+
@doc_from(BaseAuth.get_auth_metadata)
167241
async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock: asyncio.Lock) -> tuple[str, str]:
168242
if self._need_for_token():
169243
async with lock:
@@ -175,13 +249,30 @@ async def get_auth_metadata(self, client: AsyncCloudClient, timeout: float, lock
175249

176250
@abstractmethod
177251
async def _get_token(self, client: AsyncCloudClient, timeout: float) -> str:
252+
"""
253+
Abstract method to retrieve an OAuth token.
254+
255+
This method must be implemented by subclasses to define how to obtain
256+
an OAuth token asynchronously.
257+
"""
178258
pass
179259

180260

181261
class OAuthTokenAuth(RefresheableIAMTokenAuth):
262+
"""
263+
Auth method that uses an OAuth token for authentication.
264+
265+
This class extends the RefresheableIAMTokenAuth to provide functionality
266+
for managing and using an OAuth token for authentication purposes.
267+
"""
182268
env_var = 'YC_OAUTH_TOKEN'
183269

184270
def __init__(self, token: str):
271+
"""
272+
Initializes the OAuthTokenAuth with the provided OAuth token.
273+
274+
This method also issues a warning regarding the use of OAuth tokens.
275+
"""
185276
warnings.warn(
186277
OAUTH_WARNING,
187278
UserWarning,
@@ -191,6 +282,7 @@ def __init__(self, token: str):
191282

192283
@override
193284
@classmethod
285+
@doc_from(BaseAuth.applicable_from_env)
194286
async def applicable_from_env(cls, **_: Any) -> Self | None:
195287
token = os.getenv(cls.env_var)
196288
if token:
@@ -200,6 +292,12 @@ async def applicable_from_env(cls, **_: Any) -> Self | None:
200292

201293
@override
202294
async def _get_token(self, client: AsyncCloudClient, timeout: float) -> str:
295+
"""
296+
Retrieve an IAM token asynchronously using the specified client.
297+
298+
:param client: an instance of AsyncCloudClient used to make the request.
299+
:param timeout: timeout, or the maximum time to wait for the request to complete in seconds.
300+
"""
203301
request = CreateIamTokenRequest(yandex_passport_oauth_token=self._oauth_token)
204302
async with client.get_service_stub(IamTokenServiceStub, timeout=timeout) as stub:
205303
result = await client.call_service(
@@ -213,15 +311,34 @@ async def _get_token(self, client: AsyncCloudClient, timeout: float) -> str:
213311

214312

215313
class YandexCloudCLIAuth(RefresheableIAMTokenAuth):
314+
"""
315+
Authentication class for Yandex Cloud CLI using IAM tokens.
316+
317+
It handles the initialization and retrieval of IAM tokens
318+
via the Yandex Cloud CLI.
319+
"""
216320
env_var = 'YC_PROFILE'
217321

218322
def __init__(self, token: str | None = None, endpoint: str | None = None, yc_profile: str | None = None):
323+
"""
324+
Initialize the YandexCloudCLIAuth instance.
325+
326+
:param token: the initial IAM token.
327+
:param endpoint: an endpoint for the Yandex Cloud service.
328+
:param yc_profile: a Yandex Cloud profile name.
329+
"""
219330
super().__init__(token)
220331
self._endpoint = endpoint
221332
self._yc_profile = yc_profile
222333

223334
@classmethod
224335
def _build_command(cls, yc_profile: str | None, endpoint: str | None) -> list[str]:
336+
"""
337+
Build the command to create an IAM token using the Yandex Cloud CLI.
338+
339+
:param yc_profile: the Yandex Cloud profile name.
340+
:param endpoint: the endpoint for the Yandex Cloud service.
341+
"""
225342
cmd = ['yc', 'iam', 'create-token', '--no-user-output']
226343
if endpoint:
227344
cmd.extend(['--endpoint', endpoint])
@@ -233,6 +350,11 @@ def _build_command(cls, yc_profile: str | None, endpoint: str | None) -> list[st
233350

234351
@classmethod
235352
async def _check_output(cls, command: list[str]) -> str | None:
353+
"""
354+
Execute a command and check its output.
355+
356+
:param command: a list of command arguments to execute.
357+
"""
236358
process = await asyncio.create_subprocess_exec(
237359
*command,
238360
stdout=subprocess.PIPE,
@@ -249,6 +371,7 @@ async def _check_output(cls, command: list[str]) -> str | None:
249371
return result[-1].decode('utf-8')
250372

251373
@classmethod
374+
@doc_from(BaseAuth.applicable_from_env)
252375
async def applicable_from_env(cls, yc_profile: str | None = None, endpoint: str | None = None, **_: Any) -> Self | None:
253376
if yc_profile is None:
254377
yc_profile = os.getenv(cls.env_var)
@@ -280,6 +403,7 @@ async def applicable_from_env(cls, yc_profile: str | None = None, endpoint: str
280403
)
281404

282405
@override
406+
@doc_from(OAuthTokenAuth._get_token)
283407
async def _get_token(self, client: AsyncCloudClient, timeout: float) -> str:
284408
cmd = self._build_command(self._yc_profile, self._endpoint)
285409
if not (token := await self._check_output(cmd)):
@@ -289,16 +413,28 @@ async def _get_token(self, client: AsyncCloudClient, timeout: float) -> str:
289413

290414

291415
class MetadataAuth(RefresheableIAMTokenAuth):
416+
"""
417+
Authentication class for retrieving IAM tokens from metadata service.
418+
419+
This class retrieves IAM tokens from the Google Cloud metadata service.
420+
"""
292421
env_var = 'YC_METADATA_ADDR'
293422
_headers = {'Metadata-Flavor': 'Google'}
294423
_default_addr = '169.254.169.254'
295424

296425
def __init__(self, token: str | None = None, metadata_url: str | None = None):
426+
"""
427+
Initialize the MetadataAuth instance.
428+
429+
:param token: the initial IAM token.
430+
:param metadata_url: URL for the metadata service.
431+
"""
297432
self._metadata_url: str = metadata_url or self._default_addr
298433
super().__init__(token)
299434

300435
@override
301436
@classmethod
437+
@doc_from(BaseAuth.applicable_from_env)
302438
async def applicable_from_env(cls, **_: Any) -> Self | None:
303439
addr = os.getenv(cls.env_var, cls._default_addr)
304440
url = f'http://{addr}/computeMetadata/v1/instance/service-accounts/default/token'
@@ -314,11 +450,18 @@ async def applicable_from_env(cls, **_: Any) -> Self | None:
314450
return cls(token, url)
315451

316452
@override
453+
@doc_from(OAuthTokenAuth._get_token)
317454
async def _get_token(self, client: AsyncCloudClient | None, timeout: float) -> str:
318455
return await self._request_token(timeout, self._metadata_url)
319456

320457
@classmethod
321458
async def _request_token(cls, timeout: float, metadata_url: str) -> str:
459+
"""
460+
Asynchronously request an IAM access token from the metadata service.
461+
462+
:param timeout: timeout, or the maximum time to wait for the request to complete in seconds.
463+
:param metadata_url: the URL of the metadata service to request the token from.
464+
"""
322465
async with httpx.AsyncClient() as client:
323466
response = await client.get(
324467
metadata_url,
@@ -337,6 +480,16 @@ async def get_auth_provider(
337480
endpoint: str,
338481
yc_profile: str | None,
339482
) -> BaseAuth:
483+
"""
484+
Retrieve an appropriate authentication provider based on the provided auth parameter.
485+
486+
It determines the type of authentication to use based on the input
487+
and returns an instance of a corresponding authentication class.
488+
489+
:param auth: a string representing the authentication token, an instance of BaseAuth, or None.
490+
:param endpoint: the endpoint for the Yandex Cloud service.
491+
:param yc_profile: a Yandex Cloud profile name.
492+
"""
340493
simple_iam_regexp = re.compile(r'^t\d\.')
341494
iam_regexp = re.compile(r't1\.[A-Z0-9a-z_-]+[=]{0,2}\.[A-Z0-9a-z_-]{86}[=]{0,2}')
342495
simple_oauth_regexp = re.compile(r'y[0123]_[-\w]')

0 commit comments

Comments
 (0)