diff --git a/changes/11049.enhance.md b/changes/11049.enhance.md new file mode 100644 index 00000000000..213c6e6e322 --- /dev/null +++ b/changes/11049.enhance.md @@ -0,0 +1 @@ +Drop `owner_id` from read/control session action dataclasses and resolve the requester via `current_user()` in the service layer. `SessionLifecycleManager` receives `UserRepository` via constructor injection. diff --git a/changes/BA-5650-D.misc.md b/changes/BA-5650-D.misc.md deleted file mode 100644 index fac69564e40..00000000000 --- a/changes/BA-5650-D.misc.md +++ /dev/null @@ -1 +0,0 @@ -Collapse `SessionRepository` / `SessionDBSource` signatures to take `owner_id: UUID` instead of `owner_access_key: AccessKey`. No external behavior change; downstream service callers are updated in a later slice. diff --git a/src/ai/backend/manager/api/adapters/session.py b/src/ai/backend/manager/api/adapters/session.py index e04aa90abb9..ddb432c2d87 100644 --- a/src/ai/backend/manager/api/adapters/session.py +++ b/src/ai/backend/manager/api/adapters/session.py @@ -854,7 +854,6 @@ async def shutdown_service( """Shut down a service in a session.""" action = ShutdownServiceAction( session_name=str(session_id), - owner_access_key=AccessKey(access_key), service_name=input.service, ) await self._processors.session.shutdown_service.wait_for_complete(action) @@ -872,7 +871,6 @@ async def get_logs( """Get container logs for a session.""" action = GetContainerLogsAction( session_name=str(session_id), - owner_access_key=AccessKey(access_key), kernel_id=KernelId(kernel_id) if kernel_id else None, ) result = await self._processors.session.get_container_logs.wait_for_complete(action) @@ -894,7 +892,6 @@ async def update( action = RenameSessionAction( session_name=str(session_id), new_name=input.name, - owner_access_key=AccessKey(access_key), ) result = await self._processors.session.rename_session.wait_for_complete(action) return UpdateSessionPayload(session=self._session_data_to_node(result.session_data)) diff --git a/src/ai/backend/manager/dependencies/agents/composer.py b/src/ai/backend/manager/dependencies/agents/composer.py index 5228e1c837b..ff88bd0afd7 100644 --- a/src/ai/backend/manager/dependencies/agents/composer.py +++ b/src/ai/backend/manager/dependencies/agents/composer.py @@ -20,6 +20,7 @@ from ai.backend.manager.registry import AgentRegistry from ai.backend.manager.repositories.deployment.repository import DeploymentRepository from ai.backend.manager.repositories.scheduler.repository import SchedulerRepository +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.deployment.deployment_controller import DeploymentController from ai.backend.manager.sokovan.deployment.revision_generator.registry import ( RevisionGeneratorRegistry, @@ -188,6 +189,7 @@ async def compose( hook_plugin_ctx=setup_input.hook_plugin_ctx, network_plugin_ctx=setup_input.network_plugin_ctx, scheduling_controller=scheduling_controller, + user_repository=UserRepository(setup_input.db), debug=setup_input.config_provider.config.debug.enabled, manager_public_key=setup_input.agent_cache.manager_public_key, manager_secret_key=setup_input.agent_cache.manager_secret_key, diff --git a/src/ai/backend/manager/dependencies/agents/registry.py b/src/ai/backend/manager/dependencies/agents/registry.py index 57ce148dfdf..ecfd9066717 100644 --- a/src/ai/backend/manager/dependencies/agents/registry.py +++ b/src/ai/backend/manager/dependencies/agents/registry.py @@ -19,6 +19,7 @@ from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.plugin.network import NetworkPluginContext from ai.backend.manager.registry import AgentRegistry +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.scheduling_controller.scheduling_controller import ( SchedulingController, ) @@ -41,6 +42,7 @@ class AgentRegistryInput: hook_plugin_ctx: HookPluginContext network_plugin_ctx: NetworkPluginContext scheduling_controller: SchedulingController + user_repository: UserRepository debug: bool manager_public_key: PublicKey manager_secret_key: SecretKey @@ -82,6 +84,7 @@ async def provide(self, setup_input: AgentRegistryInput) -> AsyncIterator[AgentR setup_input.hook_plugin_ctx, setup_input.network_plugin_ctx, setup_input.scheduling_controller, + setup_input.user_repository, debug=setup_input.debug, manager_public_key=setup_input.manager_public_key, manager_secret_key=setup_input.manager_secret_key, diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index d2840577bad..72663586670 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -254,6 +254,7 @@ def __init__( event_producer, hook_plugin_ctx, self, + user_repository, ) self._client_pool = ClientPool(tcp_client_session_factory) diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index 76174977268..c1a2a9e8ea5 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -1146,11 +1146,12 @@ async def get_terminating_sessions_by_ids( for kernel in session_row.kernels ] + owner_main_ak = session_row.user.main_access_key if session_row.user else None terminating_sessions.append( TerminatingSessionData( session_id=session_row.id, - main_access_key=AccessKey(session_row.access_key) - if session_row.access_key + main_access_key=AccessKey(owner_main_ak) + if owner_main_ak else AccessKey(""), creation_id=session_row.creation_id or "", status=session_row.status, @@ -1183,11 +1184,12 @@ async def get_pending_timeout_sessions_by_ids( sa.select( SessionRow.id, SessionRow.creation_id, - SessionRow.access_key, + UserRow.main_access_key, SessionRow.created_at, ScalingGroupRow.scheduler_opts, ) .select_from(SessionRow) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) .join(ScalingGroupRow, SessionRow.scaling_group_name == ScalingGroupRow.name) .where( SessionRow.id.in_(session_ids), @@ -1843,21 +1845,26 @@ async def allocate_sessions( # First, fetch session data to get creation_id and access_key session_ids = {alloc.session_id for alloc in allocation_batch.allocations} if session_ids: - query = sa.select( - SessionRow.id, SessionRow.creation_id, SessionRow.access_key - ).where(SessionRow.id.in_(session_ids)) + query = ( + sa.select(SessionRow.id, SessionRow.creation_id, UserRow.main_access_key) + .select_from(SessionRow) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) + .where(SessionRow.id.in_(session_ids)) + ) result = await db_sess.execute(query) - session_data_map = {row.id: (row.creation_id, row.access_key) for row in result} + session_data_map = { + row.id: (row.creation_id, row.main_access_key) for row in result + } # Create SessionEventData for each allocated session for allocation in allocation_batch.allocations: if session_data := session_data_map.get(allocation.session_id): - creation_id, access_key = session_data + creation_id, main_access_key = session_data scheduled_sessions.append( ScheduledSessionData( session_id=allocation.session_id, creation_id=creation_id, - main_access_key=access_key, + main_access_key=main_access_key, reason="triggered-by-scheduler", ) ) @@ -2917,7 +2924,9 @@ async def _get_sessions_by_statuses( scheduled_session = ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), + main_access_key=AccessKey(session.access_key) + if session.access_key + else AccessKey(""), reason="triggered-by-scheduler", ) scheduled_sessions.append(scheduled_session) diff --git a/src/ai/backend/manager/services/session/actions/commit_session.py b/src/ai/backend/manager/services/session/actions/commit_session.py index ad268fb6070..9a63a53accd 100644 --- a/src/ai/backend/manager/services/session/actions/commit_session.py +++ b/src/ai/backend/manager/services/session/actions/commit_session.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class CommitSessionAction(SessionCommitAction): session_name: str - owner_access_key: AccessKey filename: str | None @override diff --git a/src/ai/backend/manager/services/session/actions/complete.py b/src/ai/backend/manager/services/session/actions/complete.py index a67eab86e11..65c67d4df38 100644 --- a/src/ai/backend/manager/services/session/actions/complete.py +++ b/src/ai/backend/manager/services/session/actions/complete.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.dto.agent.response import CodeCompletionResp -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ @dataclass class CompleteAction(SessionAction): session_name: str - owner_access_key: AccessKey code: str # TODO: Add type options: Mapping[str, Any] | None diff --git a/src/ai/backend/manager/services/session/actions/convert_session_to_image.py b/src/ai/backend/manager/services/session/actions/convert_session_to_image.py index 9c56d624815..d6892d5ba7c 100644 --- a/src/ai/backend/manager/services/session/actions/convert_session_to_image.py +++ b/src/ai/backend/manager/services/session/actions/convert_session_to_image.py @@ -3,7 +3,6 @@ from typing import override from ai.backend.common.data.session.types import CustomizedImageVisibilityScope -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -13,7 +12,6 @@ @dataclass class ConvertSessionToImageAction(SessionAction): session_name: str - owner_access_key: AccessKey image_name: str image_visibility: CustomizedImageVisibilityScope image_owner_id: uuid.UUID diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 10e843b1df7..9756bdae519 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -30,7 +30,7 @@ class CreateClusterAction(SessionScopeAction): domain_name: str scaling_group_name: str requester_access_key: AccessKey - owner_access_key: AccessKey + owner_id: uuid.UUID tag: str enqueue_only: bool keypair_resource_policy: dict[str, Any] | None diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index 68a293cf4c9..248380ef0f3 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -30,7 +30,7 @@ class CreateFromParamsActionParams: tag: str priority: int is_preemptible: bool - owner_access_key: AccessKey + owner_id: uuid.UUID enqueue_only: bool max_wait_seconds: int starts_at: str | None diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 2b721d36d3a..9eea9834b73 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -33,7 +33,7 @@ class CreateFromTemplateActionParams: tag: str | Undefined priority: int is_preemptible: bool - owner_access_key: AccessKey | Undefined + owner_id: uuid.UUID | Undefined enqueue_only: bool max_wait_seconds: int starts_at: str | None diff --git a/src/ai/backend/manager/services/session/actions/destroy_session.py b/src/ai/backend/manager/services/session/actions/destroy_session.py index 7a8d3703165..519d5efbf54 100644 --- a/src/ai/backend/manager/services/session/actions/destroy_session.py +++ b/src/ai/backend/manager/services/session/actions/destroy_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.models.user import UserRole @@ -15,7 +14,6 @@ class DestroySessionAction(SessionAction): session_name: str forced: bool recursive: bool - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/download_file.py b/src/ai/backend/manager/services/session/actions/download_file.py index 1b030460dc1..86b97843ea3 100644 --- a/src/ai/backend/manager/services/session/actions/download_file.py +++ b/src/ai/backend/manager/services/session/actions/download_file.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ class DownloadFileAction(SessionFileAction): user_id: uuid.UUID session_name: str file: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/download_files.py b/src/ai/backend/manager/services/session/actions/download_files.py index 2a86031e1f9..fec5b246eb6 100644 --- a/src/ai/backend/manager/services/session/actions/download_files.py +++ b/src/ai/backend/manager/services/session/actions/download_files.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ class DownloadFilesAction(SessionFileAction): user_id: uuid.UUID session_name: str files: list[str] - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/execute_session.py b/src/ai/backend/manager/services/session/actions/execute_session.py index 5cd34477adc..2fb685005c8 100644 --- a/src/ai/backend/manager/services/session/actions/execute_session.py +++ b/src/ai/backend/manager/services/session/actions/execute_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -21,7 +20,6 @@ class ExecuteSessionActionParams: class ExecuteSessionAction(SessionAction): session_name: str api_version: tuple[Any, ...] - owner_access_key: AccessKey params: ExecuteSessionActionParams @override diff --git a/src/ai/backend/manager/services/session/actions/get_abusing_report.py b/src/ai/backend/manager/services/session/actions/get_abusing_report.py index c6be8be7fb2..a88273e1268 100644 --- a/src/ai/backend/manager/services/session/actions/get_abusing_report.py +++ b/src/ai/backend/manager/services/session/actions/get_abusing_report.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AbuseReport, AccessKey +from ai.backend.common.types import AbuseReport from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +12,6 @@ @dataclass class GetAbusingReportAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_commit_status.py b/src/ai/backend/manager/services/session/actions/get_commit_status.py index 6f1818b803e..beeef9669cb 100644 --- a/src/ai/backend/manager/services/session/actions/get_commit_status.py +++ b/src/ai/backend/manager/services/session/actions/get_commit_status.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetCommitStatusAction(SessionCommitAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/get_container_logs.py b/src/ai/backend/manager/services/session/actions/get_container_logs.py index fd0bbc1bdc8..33780232cbc 100644 --- a/src/ai/backend/manager/services/session/actions/get_container_logs.py +++ b/src/ai/backend/manager/services/session/actions/get_container_logs.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey, KernelId +from ai.backend.common.types import KernelId from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +12,6 @@ @dataclass class GetContainerLogsAction(SessionAction): session_name: str - owner_access_key: AccessKey kernel_id: KernelId | None @override diff --git a/src/ai/backend/manager/services/session/actions/get_dependency_graph.py b/src/ai/backend/manager/services/session/actions/get_dependency_graph.py index 67276f5f8d0..993f31cd230 100644 --- a/src/ai/backend/manager/services/session/actions/get_dependency_graph.py +++ b/src/ai/backend/manager/services/session/actions/get_dependency_graph.py @@ -2,7 +2,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetDependencyGraphAction(SessionAction): root_session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_direct_access_info.py b/src/ai/backend/manager/services/session/actions/get_direct_access_info.py index 8d34063e881..b58990d3390 100644 --- a/src/ai/backend/manager/services/session/actions/get_direct_access_info.py +++ b/src/ai/backend/manager/services/session/actions/get_direct_access_info.py @@ -2,7 +2,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetDirectAccessInfoAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_session_info.py b/src/ai/backend/manager/services/session/actions/get_session_info.py index 9e39453f5c1..3c7b197f342 100644 --- a/src/ai/backend/manager/services/session/actions/get_session_info.py +++ b/src/ai/backend/manager/services/session/actions/get_session_info.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetSessionInfoAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/get_status_history.py b/src/ai/backend/manager/services/session/actions/get_status_history.py index bdd1067d8a0..3de466918b6 100644 --- a/src/ai/backend/manager/services/session/actions/get_status_history.py +++ b/src/ai/backend/manager/services/session/actions/get_status_history.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.services.session.base import SessionAction @@ -12,7 +11,6 @@ @dataclass class GetStatusHistoryAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/interrupt_session.py b/src/ai/backend/manager/services/session/actions/interrupt_session.py index e92303e7a59..721c3893623 100644 --- a/src/ai/backend/manager/services/session/actions/interrupt_session.py +++ b/src/ai/backend/manager/services/session/actions/interrupt_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class InterruptSessionAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/list_files.py b/src/ai/backend/manager/services/session/actions/list_files.py index 0db43c50183..6ea242ff943 100644 --- a/src/ai/backend/manager/services/session/actions/list_files.py +++ b/src/ai/backend/manager/services/session/actions/list_files.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -15,7 +14,6 @@ class ListFilesAction(SessionFileAction): user_id: uuid.UUID path: str session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index ca91ae8bab5..a9c54f62360 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import RBACElementType, ScopeType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.permission.types import RBACElementRef @@ -20,7 +19,6 @@ class MatchSessionsAction(SessionScopeAction): """ id_or_name_prefix: str - owner_access_key: AccessKey user_id: uuid.UUID @override diff --git a/src/ai/backend/manager/services/session/actions/rename_session.py b/src/ai/backend/manager/services/session/actions/rename_session.py index a2b12968f52..5f616f7ba10 100644 --- a/src/ai/backend/manager/services/session/actions/rename_session.py +++ b/src/ai/backend/manager/services/session/actions/rename_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ class RenameSessionAction(SessionAction): session_name: str new_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/restart_session.py b/src/ai/backend/manager/services/session/actions/restart_session.py index 10b65b3e910..af4ac516bdb 100644 --- a/src/ai/backend/manager/services/session/actions/restart_session.py +++ b/src/ai/backend/manager/services/session/actions/restart_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class RestartSessionAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/shutdown_service.py b/src/ai/backend/manager/services/session/actions/shutdown_service.py index a07de63baf2..4a7f1810e5a 100644 --- a/src/ai/backend/manager/services/session/actions/shutdown_service.py +++ b/src/ai/backend/manager/services/session/actions/shutdown_service.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class ShutdownServiceAction(SessionAppServiceAction): session_name: str - owner_access_key: AccessKey service_name: str @override diff --git a/src/ai/backend/manager/services/session/actions/upload_files.py b/src/ai/backend/manager/services/session/actions/upload_files.py index 28f76fef60f..1935ad0c3c8 100644 --- a/src/ai/backend/manager/services/session/actions/upload_files.py +++ b/src/ai/backend/manager/services/session/actions/upload_files.py @@ -3,7 +3,6 @@ from aiohttp import MultipartReader -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -13,7 +12,6 @@ @dataclass class UploadFilesAction(SessionFileAction): session_name: str - owner_access_key: AccessKey # TODO: Refactor this. reader: MultipartReader diff --git a/src/ai/backend/manager/services/session/lifecycle.py b/src/ai/backend/manager/services/session/lifecycle.py index 1849edfe4ef..28d65740ccc 100644 --- a/src/ai/backend/manager/services/session/lifecycle.py +++ b/src/ai/backend/manager/services/session/lifecycle.py @@ -23,6 +23,7 @@ ) from ai.backend.common.plugin.hook import HookPluginContext from ai.backend.common.types import ( + AccessKey, ResourceSlot, SessionId, SessionTypes, @@ -36,6 +37,7 @@ ExtendedAsyncSAEngine, execute_with_txn_retry, ) +from ai.backend.manager.repositories.user.repository import UserRepository if TYPE_CHECKING: from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient @@ -55,6 +57,7 @@ def __init__( event_producer: EventProducer, hook_plugin_ctx: HookPluginContext, registry: AgentRegistry, + user_repository: UserRepository, ) -> None: self.db = db self.valkey_stat = valkey_stat @@ -62,6 +65,7 @@ def __init__( self.event_producer = event_producer self.hook_plugin_ctx = hook_plugin_ctx self.registry = registry + self._user_repository = user_repository def _encode(sid: SessionId) -> bytes: return sid.bytes @@ -132,14 +136,29 @@ async def _post_status_transition( await self.event_producer.anycast_event( SessionStartedAnycastEvent(session_row.id, creation_id) ) - await self.hook_plugin_ctx.notify( - "POST_START_SESSION", - ( - session_row.id, - session_row.name, - session_row.access_key, - ), + # BA-5609: resolve main_access_key from owner_id; external + # hook plugins still receive the resolved access key. If the + # owner has no main_access_key configured, skip the hook — + # calling it with ``None`` would likely break plugins that + # assume a non-null keypair identifier. + session_main_access_key = await self._user_repository.get_main_access_key_by_id( + session_row.user_uuid ) + if session_main_access_key is not None: + await self.hook_plugin_ctx.notify( + "POST_START_SESSION", + ( + session_row.id, + session_row.name, + AccessKey(session_main_access_key), + ), + ) + else: + log.warning( + "POST_START_SESSION skipped: owner {} has no main_access_key (session {})", + session_row.user_uuid, + session_row.id, + ) match session_row.session_type: case SessionTypes.BATCH: await self.registry.trigger_batch_execution(session_row) diff --git a/src/ai/backend/manager/services/session/service.py b/src/ai/backend/manager/services/session/service.py index de43d25dec2..b93fccaa8d3 100644 --- a/src/ai/backend/manager/services/session/service.py +++ b/src/ai/backend/manager/services/session/service.py @@ -19,6 +19,7 @@ from dateutil.tz import tzutc from ai.backend.common.bgtask.bgtask import BackgroundTaskManager +from ai.backend.common.contexts.user import current_user from ai.backend.common.data.session.types import CustomizedImageVisibilityScope from ai.backend.common.events.event_types.kernel.types import KernelLifecycleEventReason from ai.backend.common.events.fetcher import EventFetcher @@ -269,9 +270,39 @@ def __init__( self._rpc_ptask_group = aiotools.PersistentTaskGroup() self._webhook_ptask_group = aiotools.PersistentTaskGroup() + @staticmethod + def _requester_user_id() -> uuid.UUID: + """Return the authenticated caller's user UUID from context. + + Raises ``InternalServerError`` if no user is in context (should never + happen after the auth middleware). + """ + user = current_user() + if user is None: + raise InternalServerError("No authenticated user in request context") + return user.user_id + + async def _resolve_owner_main_access_key( + self, + owner_id: uuid.UUID, + ) -> AccessKey: + """Resolve a delegated owner UUID to that user's main access key. + + Uses the narrower ``UserRepository.get_main_access_key_by_id`` helper + so we only fetch the single scalar column we need. Raises + ``InternalServerError`` if the target user has no main access key + configured. + """ + main_access_key = await self._user_repository.get_main_access_key_by_id(owner_id) + if main_access_key is None: + raise InternalServerError( + f"Delegated owner {owner_id} has no main access key configured" + ) + return AccessKey(main_access_key) + async def commit_session(self, action: CommitSessionAction) -> CommitSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() filename = action.filename myself = asyncio.current_task() @@ -280,7 +311,7 @@ async def commit_session(self, action: CommitSessionAction) -> CommitSessionActi session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -297,13 +328,13 @@ async def commit_session(self, action: CommitSessionAction) -> CommitSessionActi async def complete(self, action: CompleteAction) -> CompleteActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() code = action.code options = action.options or {} session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -320,7 +351,7 @@ async def convert_session_to_image( self, action: ConvertSessionToImageAction ) -> ConvertSessionToImageActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() image_name = action.image_name image_visibility = action.image_visibility image_owner_id = action.image_owner_id @@ -351,7 +382,7 @@ async def convert_session_to_image( session = await self._session_repository.get_session_with_group( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -406,7 +437,7 @@ async def create_cluster(self, action: CreateClusterAction) -> CreateClusterActi sudo_session_enabled = action.sudo_session_enabled keypair_resource_policy = action.keypair_resource_policy requester_access_key = action.requester_access_key - owner_access_key = action.owner_access_key + owner_access_key = await self._resolve_owner_main_access_key(action.owner_id) domain_name = action.domain_name group_name = action.group_name scaling_group_name = action.scaling_group_name @@ -475,7 +506,7 @@ async def create_from_params( keypair_resource_policy = action.keypair_resource_policy requester_access_key = action.requester_access_key - owner_access_key = action.params.owner_access_key + owner_access_key = await self._resolve_owner_main_access_key(action.params.owner_id) domain_name = action.params.domain_name group_name = action.params.group_name config = action.params.config @@ -527,7 +558,7 @@ async def create_from_params( user_uuid=user_info.owner_uuid, user_role=user_info.owner_role.value, ), - owner_access_key, + owner_access_key if owner_access_key is not None else requester_access_key, user_info.resource_policy, session_type, config, @@ -680,7 +711,10 @@ async def create_from_template( bootstrap += cmd_builder params["bootstrap_script"] = base64.b64encode(bootstrap.encode()).decode() - owner_access_key = params["owner_access_key"] + owner_id_param = params["owner_id"] + owner_access_key: AccessKey | None = None + if owner_id_param is not None and owner_id_param is not undefined: + owner_access_key = await self._resolve_owner_main_access_key(owner_id_param) config = params["config"] cluster_size = params["cluster_size"] cluster_mode = params["cluster_mode"] @@ -710,7 +744,7 @@ async def create_from_template( keypair_resource_policy, domain_name, params["group_name"], - query_on_behalf_of=(None if owner_access_key is undefined else owner_access_key), + query_on_behalf_of=owner_access_key, ) try: @@ -731,7 +765,7 @@ async def create_from_template( user_uuid=user_info.owner_uuid, user_role=user_info.owner_role.value, ), - owner_access_key, + owner_access_key if owner_access_key is not None else requester_access_key, user_info.resource_policy, session_type, config, @@ -763,14 +797,14 @@ async def create_from_template( async def destroy_session(self, action: DestroySessionAction) -> DestroySessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() forced = action.forced recursive = action.recursive # Get session IDs to terminate (based on recursive flag) session_ids = await self._session_repository.get_target_session_ids( session_name, - owner_access_key, + owner_id, recursive=recursive, ) @@ -847,13 +881,14 @@ async def terminate_sessions_in_project( async def download_file(self, action: DownloadFileAction) -> DownloadFileActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() + owner_access_key = await self._resolve_owner_main_access_key(owner_id) user_id = action.user_id file = action.file try: session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(session) @@ -873,12 +908,12 @@ async def download_file(self, action: DownloadFileAction) -> DownloadFileActionR async def download_files(self, action: DownloadFilesAction) -> DownloadFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() user_id = action.user_id files = action.files session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -916,13 +951,13 @@ async def download_files(self, action: DownloadFilesAction) -> DownloadFilesActi async def execute_session(self, action: ExecuteSessionAction) -> ExecuteSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() api_version = action.api_version resp = {} session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -999,10 +1034,10 @@ async def get_abusing_report( self, action: GetAbusingReportAction ) -> GetAbusingReportActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) kernel = session.main_kernel @@ -1013,11 +1048,11 @@ async def get_abusing_report( async def get_commit_status(self, action: GetCommitStatusAction) -> GetCommitStatusActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) statuses = await self._agent_registry.get_commit_status([session.main_kernel.id]) @@ -1034,12 +1069,12 @@ async def get_container_logs( ) -> GetContainerLogsActionResult: resp = {"result": {"logs": ""}} session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() kernel_id = action.kernel_id compute_session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, allow_stale=True, kernel_loading_strategy=( KernelLoadingStrategy.MAIN_KERNEL_ONLY @@ -1080,10 +1115,10 @@ async def get_dependency_graph( self, action: GetDependencyGraphAction ) -> GetDependencyGraphActionResult: root_session_name = action.root_session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() dependency_graph = await self._session_repository.find_dependency_sessions( - root_session_name, owner_access_key + root_session_name, owner_id ) session_id = ( @@ -1106,11 +1141,11 @@ async def get_direct_access_info( self, action: GetDirectAccessInfoAction ) -> GetDirectAccessInfoActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() sess = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) resp = {} @@ -1143,11 +1178,11 @@ async def get_direct_access_info( async def get_session_info(self, action: GetSessionInfoAction) -> GetSessionInfoActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() sess = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(sess) @@ -1157,7 +1192,7 @@ async def get_session_info(self, action: GetSessionInfoAction) -> GetSessionInfo session_info = LegacySessionInfo( domain_name=sess.domain_name, group_id=sess.group_id, - user_id=sess.user_uuid, + user_id=sess.owner_id, lang=sess.main_kernel.image or "", # legacy image=sess.main_kernel.image or "", architecture=sess.main_kernel.architecture or "", @@ -1193,11 +1228,11 @@ async def get_status_history( self, action: GetStatusHistoryAction ) -> GetStatusHistoryActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session_row = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, ) result = session_row.status_history or {} @@ -1206,11 +1241,11 @@ async def get_status_history( async def interrupt(self, action: InterruptSessionAction) -> InterruptSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(session) @@ -1220,13 +1255,13 @@ async def interrupt(self, action: InterruptSessionAction) -> InterruptSessionAct async def list_files(self, action: ListFilesAction) -> ListFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() user_id = action.user_id path = action.path session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -1249,12 +1284,12 @@ async def list_files(self, action: ListFilesAction) -> ListFilesActionResult: async def match_sessions(self, action: MatchSessionsAction) -> MatchSessionsActionResult: id_or_name_prefix = action.id_or_name_prefix - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() matches: list[dict[str, Any]] = [] sessions = await self._session_repository.match_sessions( id_or_name_prefix, - owner_access_key, + owner_id, ) if sessions: matches.extend( @@ -1269,12 +1304,12 @@ async def match_sessions(self, action: MatchSessionsAction) -> MatchSessionsActi async def rename_session(self, action: RenameSessionAction) -> RenameSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() new_name = action.new_name try: compute_session = await self._session_repository.update_session_name( - session_name, new_name, owner_access_key + session_name, new_name, owner_id ) if compute_session.status != SessionStatus.RUNNING: raise InvalidAPIParameters("Can't change name of not running session") @@ -1287,11 +1322,11 @@ async def rename_session(self, action: RenameSessionAction) -> RenameSessionActi async def restart_session(self, action: RestartSessionAction) -> RestartSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) await self._agent_registry.increment_session_usage(session) @@ -1300,12 +1335,12 @@ async def restart_session(self, action: RestartSessionAction) -> RestartSessionA async def shutdown_service(self, action: ShutdownServiceAction) -> ShutdownServiceActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() service_name = action.service_name session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.shutdown_service(session, service_name) @@ -1321,11 +1356,12 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR envs = action.envs login_session_token = action.login_session_token + keypair = await self._user_repository.admin_get_keypair(access_key) session = await asyncio.shield( self._database_ptask_group.create_task( self._session_repository.get_session_with_routing_minimal( session_name, - access_key, + keypair.user_id, ) ) ) @@ -1399,15 +1435,16 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR "Failed to launch the app service", extra_data=result["error"] ) + resolved_access_key = await self._resolve_owner_main_access_key(session.owner_id) body = { "login_session_token": login_session_token, "kernel_host": kernel_host, "kernel_port": host_port, "session": { "id": str(session.id), - "user_uuid": str(session.user_uuid), + "user_uuid": str(session.owner_id), "group_id": str(session.group_id), - "access_key": session.access_key, + "access_key": resolved_access_key, "domain_name": session.domain_name, }, } @@ -1430,12 +1467,12 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR async def upload_files(self, action: UploadFilesAction) -> UploadFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() reader = action.reader session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) diff --git a/src/ai/backend/manager/services/session/types.py b/src/ai/backend/manager/services/session/types.py index cf806a2a1b0..929000084f8 100644 --- a/src/ai/backend/manager/services/session/types.py +++ b/src/ai/backend/manager/services/session/types.py @@ -26,7 +26,7 @@ t.Key("reuse", default=True): t.ToBool, t.Key("startup_command", default=None): t.Null | t.String, t.Key("bootstrap_script", default=None): t.Null | t.String, - t.Key("owner_access_key", default=None): t.Null | t.String, + t.Key("owner_id", default=None): t.Null | tx.UUID, tx.AliasedKey(["scaling_group", "scalingGroup"], default=None): t.Null | t.String, tx.AliasedKey(["cluster_size", "clusterSize"], default=None): t.Null | t.Int[1:], tx.AliasedKey(["cluster_mode", "clusterMode"], default="SINGLE_NODE"): tx.Enum(ClusterMode), diff --git a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py index 3ce2c514987..3625eb5a81e 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py +++ b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py @@ -370,15 +370,6 @@ async def _handle_status_transitions( batch_updaters, BulkCreator(specs=all_history_specs) ) - # Record running_at in Valkey for routes that just transitioned to RUNNING - if ( - transitions.success is not None - and transitions.success.status == RouteStatus.RUNNING - and result.successes - ): - for route in result.successes: - await self._valkey_schedule.mark_route_running_at(str(route.route_id)) - async def process_if_needed(self, lifecycle_type: RouteLifecycleType) -> None: """ Process route lifecycle operation if needed (based on internal state). diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index 8e9d5058f3c..4285a9b308c 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -193,11 +193,6 @@ async def check_running_routes(self, routes: Sequence[RouteData]) -> RouteExecut with RouteRecorderContext.shared_step("fetch_kernel_connection_info"): await self._populate_replica_info(routes_missing_replica) - # Phase 4: Ensure RouteHealthRecords exist in Valkey for routes with replica info - routes_with_replica = [r for r in successes if r.replica_host and r.replica_port] - if routes_with_replica: - await self._ensure_health_records(routes_with_replica) - return RouteExecutionResult( successes=successes, errors=errors, @@ -225,29 +220,6 @@ async def _populate_replica_info(self, routes: Sequence[RouteData]) -> None: if populated_routes: await self._initialize_health_records(populated_routes, updates) - async def _ensure_health_records(self, routes: Sequence[RouteData]) -> None: - """Ensure RouteHealthRecords exist in Valkey for routes that already have replica info. - - Routes may already have replica_host/port in DB (set by a previous cycle or legacy code) - but lack a RouteHealthRecord in Valkey. This method checks and initializes missing records. - """ - route_id_strs = [str(r.route_id) for r in routes] - existing = await self._valkey_schedule.get_route_health_records_batch(route_id_strs) - missing = [r for r in routes if existing.get(str(r.route_id)) is None] - if not missing: - return - log.warning( - "RouteHealthRecord missing in Valkey for {} routes, re-initializing: {}", - len(missing), - [str(r.route_id)[:8] for r in missing], - ) - replica_info = { - r.route_id: (r.replica_host, r.replica_port) - for r in missing - if r.replica_host and r.replica_port - } - await self._initialize_health_records(missing, replica_info) - async def _initialize_health_records( self, routes: Sequence[RouteData], @@ -258,14 +230,6 @@ async def _initialize_health_records( health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( revision_ids ) - redis_time = await self._valkey_schedule.get_redis_time() - - # Read existing running_at values that were set when routes transitioned to RUNNING - # These may be in partial hashes (only running_at field), so read raw field directly - running_at_map = await self._valkey_schedule.get_route_running_at_batch([ - str(r.route_id) for r in routes - ]) - records: list[RouteHealthRecord] = [] for route in routes: host, port = replica_info[route.route_id] @@ -274,21 +238,16 @@ async def _initialize_health_records( health_path = health_config.path if health_config else "/" initial_delay = health_config.initial_delay if health_config else 60.0 created_at = int(route.created_at.timestamp()) - - # Use running_at from Valkey (set at RUNNING transition), fallback to redis_time - route_id_str = str(route.route_id) - running_at = running_at_map.get(route_id_str) or redis_time - initial_delay_until = running_at + int(initial_delay) + initial_delay_until = created_at + int(initial_delay) records.append( RouteHealthRecord( - route_id=route_id_str, + route_id=str(route.route_id), created_at=created_at, initial_delay_until=initial_delay_until, health_path=health_path, inference_port=port, replica_host=host, - running_at=running_at, ) ) diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py index 6a8af7221d0..3f0049e1675 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py @@ -73,11 +73,6 @@ async def observe(self, routes: Sequence[RouteData]) -> RouteObservationResult: targets.append((route_id_str, record)) if not targets: - if checkable: - log.warning( - "Health observer: {} checkable routes but 0 have records in Valkey", - len(checkable), - ) return RouteObservationResult(observed_count=0) # Perform HTTP health checks in parallel diff --git a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py index f8e96ec47e7..687bd85c648 100644 --- a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py +++ b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py @@ -316,8 +316,8 @@ async def test_enqueue_session_creates_scheduling_history( id=session_id, creation_id=creation_id, name=f"test-session-{uuid.uuid4().hex[:8]}", - access_key=test_access_key, - user_uuid=test_user_uuid, + main_access_key=test_access_key, + owner_id=test_user_uuid, group_id=test_group_id, domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, @@ -354,8 +354,8 @@ async def test_enqueue_session_creates_scheduling_history( scaling_group=test_scaling_group_name, domain_name=test_domain_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, + main_access_key=test_access_key, image="python:3.8", architecture="x86_64", registry="docker.io", diff --git a/tests/unit/manager/sokovan/scheduler/conftest.py b/tests/unit/manager/sokovan/scheduler/conftest.py index 5f8e59f0c2a..5396824f2a1 100644 --- a/tests/unit/manager/sokovan/scheduler/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/conftest.py @@ -112,9 +112,9 @@ def basic_session_workload() -> SessionWorkload: """Basic SessionWorkload instance with default values.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -134,9 +134,9 @@ def batch_session_workload() -> SessionWorkload: """Batch SessionWorkload instance.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -156,9 +156,9 @@ def inference_session_workload() -> SessionWorkload: """Inference SessionWorkload instance.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -178,9 +178,9 @@ def minimal_resource_workload() -> SessionWorkload: """SessionWorkload with minimal resource requirements (1 CPU, 1 mem).""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -200,9 +200,9 @@ def small_resource_workload() -> SessionWorkload: """SessionWorkload with small resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -222,9 +222,9 @@ def medium_resource_workload() -> SessionWorkload: """SessionWorkload with medium resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -244,9 +244,9 @@ def large_resource_workload() -> SessionWorkload: """SessionWorkload with large resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(100), "mem": Decimal(100)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -272,9 +272,9 @@ def test_domain_small_resource_workload(test_domain_name: str) -> SessionWorkloa """SessionWorkload with small resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -294,9 +294,9 @@ def test_domain_medium_resource_workload(test_domain_name: str) -> SessionWorklo """SessionWorkload with medium resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -316,9 +316,9 @@ def test_domain_large_resource_workload(test_domain_name: str) -> SessionWorkloa """SessionWorkload with large resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(100), "mem": Decimal(100)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -338,9 +338,9 @@ def user1_minimal_workload() -> SessionWorkload: """Minimal workload for user1.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -366,9 +366,9 @@ def user_specific_small_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Small workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -388,9 +388,9 @@ def user_specific_medium_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Medium workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -410,9 +410,9 @@ def user_specific_minimal_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Minimal workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -433,9 +433,9 @@ def batch_session_past_start_time() -> SessionWorkload: past_time = datetime.now(tzutc()) - timedelta(hours=1) return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -456,9 +456,9 @@ def batch_session_future_start_time() -> SessionWorkload: future_time = datetime.now(tzutc()) + timedelta(hours=1) return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py index e33c6b66917..1d8c11fd5f0 100644 --- a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py @@ -137,7 +137,7 @@ def _create_terminating_session_data( return TerminatingSessionData( session_id=session_id or SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id=str(uuid4()), status=SessionStatus.TERMINATING, status_info=status_info, diff --git a/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py b/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py index d267c2d0a9b..21cb03ba5af 100644 --- a/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py +++ b/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py @@ -126,7 +126,7 @@ async def test_terminate_sessions_single_success( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="USER_REQUESTED", @@ -181,7 +181,7 @@ async def test_terminate_sessions_multiple_kernels( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="FORCED_TERMINATION", @@ -233,7 +233,7 @@ async def test_terminate_sessions_partial_failure( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="TEST_PARTIAL", @@ -310,7 +310,7 @@ async def test_terminate_sessions_concurrent_execution( sessions.append( TerminatingSessionData( session_id=session_id, - access_key=AccessKey(f"key-{i}"), + main_access_key=AccessKey(f"key-{i}"), creation_id=f"creation-{i}", status=SessionStatus.TERMINATING, status_info="BATCH_TERMINATION", @@ -364,7 +364,7 @@ async def test_terminate_sessions_empty_kernel_list( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="NO_KERNELS", diff --git a/tests/unit/manager/test_reconcile_agent_resources.py b/tests/unit/manager/test_reconcile_agent_resources.py index 14df353e04d..2fcf6e2a8db 100644 --- a/tests/unit/manager/test_reconcile_agent_resources.py +++ b/tests/unit/manager/test_reconcile_agent_resources.py @@ -104,6 +104,7 @@ async def registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), manager_public_key=PublicKey(b"GqK]ZYY#h*9jAQbGxSwkeZX3Y*%b+DiY$7ju6sh{"), manager_secret_key=SecretKey(b"37KX6]ac^&hcnSaVo=-%eVO9M]ENe8v=BOWF(Sw$"), ) @@ -498,6 +499,7 @@ async def registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), manager_public_key=PublicKey(b"GqK]ZYY#h*9jAQbGxSwkeZX3Y*%b+DiY$7ju6sh{"), manager_secret_key=SecretKey(b"37KX6]ac^&hcnSaVo=-%eVO9M]ENe8v=BOWF(Sw$"), )