-
Notifications
You must be signed in to change notification settings - Fork 175
Expand file tree
/
Copy pathservice.py
More file actions
527 lines (486 loc) · 21 KB
/
service.py
File metadata and controls
527 lines (486 loc) · 21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
import logging
from collections import ChainMap
from collections.abc import Mapping
from datetime import datetime
from typing import Any, cast
from aiohttp import web
from sqlalchemy import RowMapping
from ai.backend.common.dto.manager.auth.types import AuthTokenType
from ai.backend.common.exception import InvalidAPIParameters
from ai.backend.common.plugin.hook import ALL_COMPLETED, FIRST_COMPLETED, PASSED, HookPluginContext
from ai.backend.common.types import AccessKey
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.config.unified import AuthConfig
from ai.backend.manager.data.auth.types import AuthorizationResult, SSHKeypair
from ai.backend.manager.errors.auth import (
AuthorizationFailed,
EmailAlreadyExistsError,
GroupMembershipNotFoundError,
PasswordExpired,
UserCreationError,
)
from ai.backend.manager.errors.common import (
GenericBadRequest,
GenericForbidden,
InternalServerError,
ObjectNotFound,
RejectedByHook,
)
from ai.backend.manager.models.hasher.types import PasswordInfo
from ai.backend.manager.models.keypair import (
generate_keypair,
generate_ssh_keypair,
validate_ssh_keypair,
)
from ai.backend.manager.models.user import (
INACTIVE_USER_STATUSES,
UserRole,
UserStatus,
compare_to_hashed_password,
)
from ai.backend.manager.repositories.auth.repository import AuthRepository
from ai.backend.manager.services.auth.actions.authorize import (
AuthorizeAction,
AuthorizeActionResult,
)
from ai.backend.manager.services.auth.actions.generate_ssh_keypair import (
GenerateSSHKeypairAction,
GenerateSSHKeypairActionResult,
)
from ai.backend.manager.services.auth.actions.get_role import GetRoleAction, GetRoleActionResult
from ai.backend.manager.services.auth.actions.get_ssh_keypair import (
GetSSHKeypairAction,
GetSSHKeypairActionResult,
)
from ai.backend.manager.services.auth.actions.resolve_access_key_scope import (
ResolveAccessKeyScopeAction,
ResolveAccessKeyScopeResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_scope import (
ResolveUserScopeAction,
ResolveUserScopeResult,
)
from ai.backend.manager.services.auth.actions.signout import SignoutAction, SignoutActionResult
from ai.backend.manager.services.auth.actions.signup import SignupAction, SignupActionResult
from ai.backend.manager.services.auth.actions.update_full_name import (
UpdateFullNameAction,
UpdateFullNameActionResult,
)
from ai.backend.manager.services.auth.actions.update_password import (
UpdatePasswordAction,
UpdatePasswordActionResult,
)
from ai.backend.manager.services.auth.actions.update_password_no_auth import (
UpdatePasswordNoAuthAction,
UpdatePasswordNoAuthActionResult,
)
from ai.backend.manager.services.auth.actions.upload_ssh_keypair import (
UploadSSHKeypairAction,
UploadSSHKeypairActionResult,
)
from ai.backend.manager.utils import check_if_requester_is_eligible_to_act_as_target_user
# Module-level logger. Call sites in this file pass ``{}``-placeholder format
# strings with positional args (e.g. the AUTH.UPDATE_PASSWORD messages); the
# BraceStyleAdapter wrapper is presumably what enables that brace style.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
class AuthService:
    """Application service implementing the manager's authentication actions.

    Each public coroutine receives an action object and returns the matching
    result type. User/keypair persistence and credential checks are delegated to
    ``AuthRepository``; password-hashing parameters are read from the ``auth``
    section of the manager config; and several steps dispatch hook-plugin events
    (AUTHORIZE, POST_AUTHORIZE, PRE_SIGNUP, VERIFY_PASSWORD_FORMAT, POST_SIGNUP)
    so plugins can veto or extend the default behavior.
    """

    _hook_plugin_ctx: HookPluginContext
    _auth_repository: AuthRepository
    _config_provider: ManagerConfigProvider

    def __init__(
        self,
        hook_plugin_ctx: HookPluginContext,
        auth_repository: AuthRepository,
        config_provider: ManagerConfigProvider,
    ) -> None:
        self._hook_plugin_ctx = hook_plugin_ctx
        self._auth_repository = auth_repository
        self._config_provider = config_provider

    async def get_role(self, action: GetRoleAction) -> GetRoleActionResult:
        """Report the requester's global, domain and (optionally) project role.

        Raises:
            ObjectNotFound: ``action.group_id`` is given but the user is not a
                member of that project.
        """
        group_role = None
        if action.group_id is not None:
            try:
                await self._auth_repository.get_group_membership(action.group_id, action.user_id)
            except GroupMembershipNotFoundError as e:
                raise ObjectNotFound(
                    extra_msg="No such project or you are not the member of it.",
                    object_name="project (user group)",
                ) from e
            # TODO: per-group role is not yet implemented; every member is "user".
            group_role = "user"
        return GetRoleActionResult(
            global_role="superadmin" if action.is_superadmin else "user",
            domain_role="admin" if action.is_admin else "user",
            group_role=group_role,
        )

    async def authorize(self, action: AuthorizeAction) -> AuthorizeActionResult:
        """Authenticate a user and return the credentials of their main keypair.

        Flow: dispatch AUTHORIZE (the first plugin result wins); when no plugin
        returns a user, fall back to the repository's password check. Then reject
        unverified/inactive accounts and expired passwords, look up the main
        keypair, and dispatch POST_AUTHORIZE. A POST_AUTHORIZE plugin may take
        over the response by returning a ``web.StreamResponse``, in which case
        ``authorization_result`` is ``None``.

        Raises:
            InvalidAPIParameters: ``action.type`` is not ``AuthTokenType.KEYPAIR``.
            RejectedByHook: an AUTHORIZE or POST_AUTHORIZE hook did not pass.
            AuthorizationFailed: the account awaits verification, is inactive, or
                has no main keypair.
            PasswordExpired: the password is older than ``max_password_age``.
        """
        if action.type != AuthTokenType.KEYPAIR:
            # other types are not implemented yet.
            raise InvalidAPIParameters("Unsupported authorization type")
        params = action.hook_params
        hook_result = await self._hook_plugin_ctx.dispatch(
            "AUTHORIZE",
            (action.request, params),
            return_when=FIRST_COMPLETED,
        )
        auth_config = self._config_provider.config.auth
        if hook_result.status != PASSED:
            raise RejectedByHook.from_hook_result(hook_result)
        if hook_result.result:
            # One of the AUTHORIZE hooks authenticated the user.
            user = hook_result.result
        else:
            # No AUTHORIZE hook produced a user: proceed with the normal login.
            # The "with_migration" variant presumably re-hashes the stored password
            # with the configured algorithm on success — confirm in the repository.
            user = await self._auth_repository.check_credential_with_migration(
                action.domain_name,
                action.email,
                target_password_info=self._make_password_info(action.password, auth_config),
            )
        if user.status == UserStatus.BEFORE_VERIFICATION:
            raise AuthorizationFailed("This account needs email verification.")
        if user.status in INACTIVE_USER_STATUSES:
            raise AuthorizationFailed("User credential mismatch.")
        await self._check_password_age(user, auth_config)
        user_row = await self._auth_repository.get_user_row_by_uuid(user.uuid)
        main_keypair_row = user_row.get_main_keypair_row()
        if main_keypair_row is None:
            raise AuthorizationFailed("No API keypairs found.")
        # [Hooking point for POST_AUTHORIZE]
        # The hook handlers should accept a tuple of the request, params, user, and keypair.
        hook_result = await self._hook_plugin_ctx.dispatch(
            "POST_AUTHORIZE",
            (action.request, params, user, main_keypair_row.mapping),
            return_when=FIRST_COMPLETED,
        )
        if hook_result.status != PASSED:
            raise RejectedByHook.from_hook_result(hook_result)
        if isinstance(hook_result.result, web.StreamResponse):
            # A plugin supplied its own response; hand it back verbatim.
            return AuthorizeActionResult(
                stream_response=hook_result.result,
                authorization_result=None,
            )
        return AuthorizeActionResult(
            stream_response=None,
            authorization_result=AuthorizationResult(
                access_key=main_keypair_row.access_key,
                secret_key=main_keypair_row.secret_key or "",
                user_id=user.uuid,
                role=user.role,
                status=user.status,
            ),
        )

    async def signup(self, action: SignupAction) -> SignupActionResult:
        """Register a new user and create their first keypair.

        PRE_SIGNUP hook results are merged with ``ChainMap`` semantics (earlier
        results take precedence) and may override the default user fields, except
        ``resource_policy``, which is applied to the keypair instead. The optional
        ``group`` key picks the group to join (default ``"default"``). Users are
        created INACTIVE unless a hook overrides ``status``; the keypair is active
        only if the resulting status is ACTIVE. POST_SIGNUP is sent as a one-way
        notification after creation.

        Raises:
            RejectedByHook: PRE_SIGNUP or VERIFY_PASSWORD_FORMAT did not pass.
            EmailAlreadyExistsError: the email is already registered.
            InternalServerError: the repository failed to create the user.
        """
        params = action.hook_params
        hook_result = await self._hook_plugin_ctx.dispatch(
            "PRE_SIGNUP",
            (params,),
            return_when=ALL_COMPLETED,
        )
        if hook_result.status != PASSED:
            raise RejectedByHook.from_hook_result(hook_result)
        # Merge the hook results as a single layered map. Each Mapping is copied
        # into a dict because ChainMap is typed over MutableMapping.
        hook_results = cast(list[Mapping[str, Any]], hook_result.result or [])
        user_data_overridden: ChainMap[str, Any] = ChainMap(*[
            dict(result) for result in hook_results
        ])
        await self._verify_password_format(action.request, params)

        if await self._auth_repository.check_email_exists(action.email):
            raise EmailAlreadyExistsError("Email already exists")

        auth_config = self._config_provider.config.auth
        data: dict[str, Any] = {
            "domain_name": action.domain_name,
            "username": action.username if action.username is not None else action.email,
            "email": action.email,
            # Hashing is configured via PasswordInfo rather than a pre-hashed string.
            "password": self._make_password_info(action.password, auth_config),
            "need_password_change": False,
            "full_name": action.full_name if action.full_name is not None else "",
            "description": action.description if action.description is not None else "",
            "status": UserStatus.INACTIVE,
            "status_info": "user-signup",
            "role": UserRole.USER,
            "integration_id": None,
            "resource_policy": "default",
            "sudo_session_enabled": False,
        }
        for key, val in user_data_overridden.items():
            # Take only known user fields; a hook's resource_policy is for the keypair.
            if key in data and key != "resource_policy":
                data[key] = val

        # Create the user's first access_key and secret_key.
        ak, sk = generate_keypair()
        kp_data = {
            "user_id": action.email,
            "access_key": ak,
            "secret_key": sk,
            "is_active": data.get("status") == UserStatus.ACTIVE,
            "is_admin": False,
            "resource_policy": user_data_overridden.get("resource_policy", "default"),
            "rate_limit": 1000,
            "num_queries": 0,
        }
        group_name = user_data_overridden.get("group", "default")
        try:
            user = await self._auth_repository.create_user_with_keypair(
                user_data=data,
                keypair_data=kp_data,
                group_name=group_name,
                domain_name=action.domain_name,
            )
        except UserCreationError as e:
            raise InternalServerError("Error creating user account") from e

        # [Hooking point for POST_SIGNUP as one-way notification]
        # The hook handlers should accept a tuple of the user email,
        # the new user's UUID, and a dict with initial user's preferences.
        initial_user_prefs = {
            "lang": self._preferred_language(action.request.headers.get("Accept-Language")),
        }
        await self._hook_plugin_ctx.notify(
            "POST_SIGNUP",
            (action.email, user.uuid, initial_user_prefs),
        )
        return SignupActionResult(
            user_id=user.uuid,
            access_key=ak,
            secret_key=sk,
        )

    async def signout(self, action: SignoutAction) -> SignoutActionResult:
        """Deactivate the requester's own user and keypairs after re-checking the password.

        Raises:
            GenericForbidden: ``action.email`` is not the requester's email.
        """
        if action.email != action.requester_email:
            raise GenericForbidden("Not the account owner")
        email = action.email
        await self._auth_repository.check_credential_without_migration(
            action.domain_name,
            email,
            action.password,
        )
        await self._auth_repository.deactivate_user_and_keypairs(email)
        return SignoutActionResult(success=True)

    async def update_full_name(self, action: UpdateFullNameAction) -> UpdateFullNameActionResult:
        """Set the full name of the user identified by email and domain."""
        await self._auth_repository.update_user_full_name(
            action.email, action.domain_name, action.full_name
        )
        return UpdateFullNameActionResult(success=True)

    async def update_password(self, action: UpdatePasswordAction) -> UpdatePasswordActionResult:
        """Change a user's password after verifying the old one.

        Returns ``success=False`` (without raising) when the new password and its
        confirmation differ.

        Raises:
            AuthorizationFailed: the old password does not match.
            RejectedByHook: VERIFY_PASSWORD_FORMAT rejected the new password.
        """
        domain_name = action.domain_name
        email = action.email
        log_fmt = "AUTH.UPDATE_PASSWORD(d:{}, email:{})"
        log_args = (domain_name, email)

        if action.new_password != action.new_password_confirm:
            log.info(log_fmt + ": new password mismatch", *log_args)
            return UpdatePasswordActionResult(
                success=False,
                message="new password mismatch",
            )
        try:
            await self._auth_repository.check_credential_without_migration(
                domain_name,
                email,
                action.old_password,
            )
        except AuthorizationFailed as e:
            log.info(log_fmt + ": old password mismatch", *log_args)
            raise AuthorizationFailed("Old password mismatch") from e

        await self._verify_password_format(action.request, action.hook_params)

        auth_config = self._config_provider.config.auth
        await self._auth_repository.update_user_password(
            email, self._make_password_info(action.new_password, auth_config)
        )
        return UpdatePasswordActionResult(
            success=True,
            message="Password updated successfully",
        )

    async def update_password_no_auth(
        self, action: UpdatePasswordNoAuthAction
    ) -> UpdatePasswordNoAuthActionResult:
        """Change a password using only the current password as the credential.

        Only available when ``max_password_age`` is configured (presumably so users
        blocked by password expiry can rotate it — confirm with the API handler).

        Raises:
            GenericBadRequest: ``max_password_age`` is not configured.
            AuthorizationFailed: the new password equals the current one.
            RejectedByHook: VERIFY_PASSWORD_FORMAT rejected the new password.
        """
        auth_config = self._config_provider.config.auth
        if auth_config.max_password_age is None:
            raise GenericBadRequest("Unsupported function.")
        checked_user = await self._auth_repository.check_credential_without_migration(
            action.domain_name,
            action.email,
            password=action.current_password,
        )
        new_password = action.new_password
        if compare_to_hashed_password(new_password, checked_user["password"]):
            raise AuthorizationFailed("Cannot update to the same password as an existing password.")

        await self._verify_password_format(action.request, action.hook_params)

        changed_at = await self._auth_repository.update_user_password_by_uuid(
            checked_user["uuid"], self._make_password_info(new_password, auth_config)
        )
        return UpdatePasswordNoAuthActionResult(
            user_id=checked_user["uuid"],
            password_changed_at=changed_at,
        )

    async def get_ssh_keypair(self, action: GetSSHKeypairAction) -> GetSSHKeypairActionResult:
        """Return the stored SSH public key for the access key ("" when none is stored)."""
        pubkey = await self._auth_repository.get_ssh_public_key(action.access_key)
        return GetSSHKeypairActionResult(public_key=pubkey or "", access_key=action.access_key)

    async def generate_ssh_keypair(
        self, action: GenerateSSHKeypairAction
    ) -> GenerateSSHKeypairActionResult:
        """Generate a fresh SSH keypair, store it for the access key, and return it."""
        pubkey, privkey = generate_ssh_keypair()
        await self._auth_repository.update_ssh_keypair(action.access_key, pubkey, privkey)
        return GenerateSSHKeypairActionResult(
            ssh_keypair=SSHKeypair(
                ssh_public_key=pubkey,
                ssh_private_key=privkey,
            ),
            user_id=action.user_id,
        )

    async def upload_ssh_keypair(
        self, action: UploadSSHKeypairAction
    ) -> UploadSSHKeypairActionResult:
        """Validate a user-supplied SSH keypair and store it for the access key.

        Raises:
            InvalidAPIParameters: the keypair fails ``validate_ssh_keypair``.
        """
        privkey = action.private_key
        pubkey = action.public_key
        is_valid, err_msg = validate_ssh_keypair(privkey, pubkey)
        if not is_valid:
            raise InvalidAPIParameters(err_msg)
        await self._auth_repository.update_ssh_keypair(action.access_key, pubkey, privkey)
        return UploadSSHKeypairActionResult(
            ssh_keypair=SSHKeypair(
                ssh_public_key=pubkey,
                ssh_private_key=privkey,
            ),
            user_id=action.user_id,
        )

    async def resolve_access_key_scope(
        self, action: ResolveAccessKeyScopeAction
    ) -> ResolveAccessKeyScopeResult:
        """Resolve the access key a request acts on, checking delegation rights.

        Without an owner access key (or when it equals the requester's), the
        requester acts as itself.

        Raises:
            InvalidAPIParameters: the repository could not resolve the owner key.
            GenericForbidden: the requester may not act as the owner.
        """
        requester_ak = AccessKey(action.requester_access_key)
        if (
            action.owner_access_key is None
            or action.owner_access_key == action.requester_access_key
        ):
            return ResolveAccessKeyScopeResult(
                requester_access_key=requester_ak,
                owner_access_key=requester_ak,
            )
        owner_ak = AccessKey(action.owner_access_key)
        try:
            (
                owner_domain,
                owner_role,
            ) = await self._auth_repository.get_delegation_target_by_access_key(
                action.owner_access_key,
            )
        except ValueError as e:
            raise InvalidAPIParameters(str(e)) from e
        try:
            check_if_requester_is_eligible_to_act_as_target_user(
                action.requester_role,
                action.requester_domain,
                owner_role,
                owner_domain,
            )
        except RuntimeError as e:
            raise GenericForbidden(str(e)) from e
        return ResolveAccessKeyScopeResult(
            requester_access_key=requester_ak,
            owner_access_key=owner_ak,
        )

    async def resolve_user_scope(self, action: ResolveUserScopeAction) -> ResolveUserScopeResult:
        """Resolve the user a request acts on; only superadmins may name another user.

        Raises:
            InvalidAPIParameters: a non-superadmin named an owner, or the owner
                email could not be resolved.
            GenericForbidden: the requester may not act as the owner.
        """
        if action.owner_user_email is None:
            return ResolveUserScopeResult(
                owner_uuid=action.requester_uuid,
                owner_role=action.requester_role,
            )
        if not action.is_superadmin:
            raise InvalidAPIParameters("Only superadmins may have user scopes.")
        try:
            (
                owner_uuid,
                owner_role,
                owner_domain,
            ) = await self._auth_repository.get_delegation_target_by_email(
                action.owner_user_email,
            )
        except ValueError as e:
            raise InvalidAPIParameters(str(e)) from e
        try:
            check_if_requester_is_eligible_to_act_as_target_user(
                action.requester_role,
                action.requester_domain,
                owner_role,
                owner_domain,
            )
        except RuntimeError as e:
            raise GenericForbidden(str(e)) from e
        return ResolveUserScopeResult(
            owner_uuid=owner_uuid,
            owner_role=owner_role,
        )

    async def _check_password_age(self, user: RowMapping, auth_config: AuthConfig | None) -> None:
        """Raise ``PasswordExpired`` when the user's password is older than allowed.

        No-op when there is no config, no ``max_password_age``, or the user has no
        ``password_changed_at`` timestamp. The current time comes from the
        repository (presumably the DB clock — confirm).
        """
        if auth_config is None or (max_password_age := auth_config.max_password_age) is None:
            return
        password_changed_at: datetime | None = user.password_changed_at
        if password_changed_at is None:
            return  # Skip check if password_changed_at is not set
        current_dt: datetime = await self._auth_repository.get_current_time()
        expires_at = password_changed_at + max_password_age
        if expires_at < current_dt:
            # Force user to update password
            raise PasswordExpired(extra_msg=f"Password expired on {expires_at}.")

    @staticmethod
    def _make_password_info(password: str, auth_config: AuthConfig) -> PasswordInfo:
        """Wrap a plaintext password with the configured hashing parameters."""
        return PasswordInfo(
            password=password,
            algorithm=auth_config.password_hash_algorithm,
            rounds=auth_config.password_hash_rounds,
            salt_size=auth_config.password_hash_salt_size,
        )

    async def _verify_password_format(self, request: Any, hook_params: Any) -> None:
        """Run the VERIFY_PASSWORD_FORMAT hook (ALL_COMPLETED) and reject on failure.

        The hook handlers receive ``(request, hook_params)``; they should return
        None when the password is acceptable and raise a Reject error otherwise.

        Raises:
            RejectedByHook: some handler did not pass; the reason defaults to
                "invalid password format".
        """
        hook_result = await self._hook_plugin_ctx.dispatch(
            "VERIFY_PASSWORD_FORMAT",
            (request, hook_params),
            return_when=ALL_COMPLETED,
        )
        if hook_result.status != PASSED:
            hook_result.reason = hook_result.reason or "invalid password format"
            raise RejectedByHook.from_hook_result(hook_result)

    @staticmethod
    def _preferred_language(accept_language: str | None, default: str = "en-us") -> str:
        """Return the first language tag of an Accept-Language header, lowercased.

        Surrounding whitespace and any ``;q=`` weight parameters are stripped.
        *default* is returned when the header is missing or yields an empty tag.
        """
        if not accept_language:
            return default
        first_entry = accept_language.split(",", 1)[0]
        tag = first_entry.split(";", 1)[0].strip().lower()
        return tag or default