Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11049.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Drop `owner_id` from read/control session action dataclasses and resolve the requester via `current_user()` in the service layer. `SessionLifecycleManager` receives `UserRepository` via constructor injection.
1 change: 0 additions & 1 deletion changes/BA-5650-D.misc.md

This file was deleted.

3 changes: 0 additions & 3 deletions src/ai/backend/manager/api/adapters/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,6 @@ async def shutdown_service(
"""Shut down a service in a session."""
action = ShutdownServiceAction(
session_name=str(session_id),
owner_access_key=AccessKey(access_key),
service_name=input.service,
)
await self._processors.session.shutdown_service.wait_for_complete(action)
Expand All @@ -872,7 +871,6 @@ async def get_logs(
"""Get container logs for a session."""
action = GetContainerLogsAction(
session_name=str(session_id),
owner_access_key=AccessKey(access_key),
kernel_id=KernelId(kernel_id) if kernel_id else None,
)
result = await self._processors.session.get_container_logs.wait_for_complete(action)
Expand All @@ -894,7 +892,6 @@ async def update(
action = RenameSessionAction(
session_name=str(session_id),
new_name=input.name,
owner_access_key=AccessKey(access_key),
)
result = await self._processors.session.rename_session.wait_for_complete(action)
return UpdateSessionPayload(session=self._session_data_to_node(result.session_data))
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/dependencies/agents/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ai.backend.manager.registry import AgentRegistry
from ai.backend.manager.repositories.deployment.repository import DeploymentRepository
from ai.backend.manager.repositories.scheduler.repository import SchedulerRepository
from ai.backend.manager.repositories.user.repository import UserRepository
from ai.backend.manager.sokovan.deployment.deployment_controller import DeploymentController
from ai.backend.manager.sokovan.deployment.revision_generator.registry import (
RevisionGeneratorRegistry,
Expand Down Expand Up @@ -188,6 +189,7 @@ async def compose(
hook_plugin_ctx=setup_input.hook_plugin_ctx,
network_plugin_ctx=setup_input.network_plugin_ctx,
scheduling_controller=scheduling_controller,
user_repository=UserRepository(setup_input.db),
debug=setup_input.config_provider.config.debug.enabled,
manager_public_key=setup_input.agent_cache.manager_public_key,
manager_secret_key=setup_input.agent_cache.manager_secret_key,
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/manager/dependencies/agents/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.plugin.network import NetworkPluginContext
from ai.backend.manager.registry import AgentRegistry
from ai.backend.manager.repositories.user.repository import UserRepository
from ai.backend.manager.sokovan.scheduling_controller.scheduling_controller import (
SchedulingController,
)
Expand All @@ -41,6 +42,7 @@ class AgentRegistryInput:
hook_plugin_ctx: HookPluginContext
network_plugin_ctx: NetworkPluginContext
scheduling_controller: SchedulingController
user_repository: UserRepository
debug: bool
manager_public_key: PublicKey
manager_secret_key: SecretKey
Expand Down Expand Up @@ -82,6 +84,7 @@ async def provide(self, setup_input: AgentRegistryInput) -> AsyncIterator[AgentR
setup_input.hook_plugin_ctx,
setup_input.network_plugin_ctx,
setup_input.scheduling_controller,
setup_input.user_repository,
debug=setup_input.debug,
manager_public_key=setup_input.manager_public_key,
manager_secret_key=setup_input.manager_secret_key,
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def __init__(
event_producer,
hook_plugin_ctx,
self,
user_repository,
)
self._client_pool = ClientPool(tcp_client_session_factory)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1146,11 +1146,12 @@ async def get_terminating_sessions_by_ids(
for kernel in session_row.kernels
]

owner_main_ak = session_row.user.main_access_key if session_row.user else None
terminating_sessions.append(
TerminatingSessionData(
session_id=session_row.id,
main_access_key=AccessKey(session_row.access_key)
if session_row.access_key
main_access_key=AccessKey(owner_main_ak)
if owner_main_ak
else AccessKey(""),
creation_id=session_row.creation_id or "",
status=session_row.status,
Expand Down Expand Up @@ -1183,11 +1184,12 @@ async def get_pending_timeout_sessions_by_ids(
sa.select(
SessionRow.id,
SessionRow.creation_id,
SessionRow.access_key,
UserRow.main_access_key,
SessionRow.created_at,
ScalingGroupRow.scheduler_opts,
)
.select_from(SessionRow)
.join(UserRow, SessionRow.user_uuid == UserRow.uuid)
.join(ScalingGroupRow, SessionRow.scaling_group_name == ScalingGroupRow.name)
.where(
SessionRow.id.in_(session_ids),
Expand Down Expand Up @@ -1843,21 +1845,26 @@ async def allocate_sessions(
# First, fetch session data to get creation_id and access_key
session_ids = {alloc.session_id for alloc in allocation_batch.allocations}
if session_ids:
query = sa.select(
SessionRow.id, SessionRow.creation_id, SessionRow.access_key
).where(SessionRow.id.in_(session_ids))
query = (
sa.select(SessionRow.id, SessionRow.creation_id, UserRow.main_access_key)
.select_from(SessionRow)
.join(UserRow, SessionRow.user_uuid == UserRow.uuid)
.where(SessionRow.id.in_(session_ids))
)
result = await db_sess.execute(query)
session_data_map = {row.id: (row.creation_id, row.access_key) for row in result}
session_data_map = {
row.id: (row.creation_id, row.main_access_key) for row in result
}

# Create SessionEventData for each allocated session
for allocation in allocation_batch.allocations:
if session_data := session_data_map.get(allocation.session_id):
creation_id, access_key = session_data
creation_id, main_access_key = session_data
scheduled_sessions.append(
ScheduledSessionData(
session_id=allocation.session_id,
creation_id=creation_id,
main_access_key=access_key,
main_access_key=main_access_key,
reason="triggered-by-scheduler",
)
)
Expand Down Expand Up @@ -2917,7 +2924,9 @@ async def _get_sessions_by_statuses(
scheduled_session = ScheduledSessionData(
session_id=session.id,
creation_id=session.creation_id or "",
main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""),
main_access_key=AccessKey(session.access_key)
if session.access_key
else AccessKey(""),
reason="triggered-by-scheduler",
)
scheduled_sessions.append(scheduled_session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from typing import Any, override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +11,6 @@
@dataclass
class CommitSessionAction(SessionCommitAction):
session_name: str
owner_access_key: AccessKey
filename: str | None

@override
Expand Down
2 changes: 0 additions & 2 deletions src/ai/backend/manager/services/session/actions/complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, override

from ai.backend.common.dto.agent.response import CodeCompletionResp
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -14,7 +13,6 @@
@dataclass
class CompleteAction(SessionAction):
session_name: str
owner_access_key: AccessKey
code: str
# TODO: Add type
options: Mapping[str, Any] | None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import override

from ai.backend.common.data.session.types import CustomizedImageVisibilityScope
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -13,7 +12,6 @@
@dataclass
class ConvertSessionToImageAction(SessionAction):
session_name: str
owner_access_key: AccessKey
image_name: str
image_visibility: CustomizedImageVisibilityScope
image_owner_id: uuid.UUID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CreateClusterAction(SessionScopeAction):
domain_name: str
scaling_group_name: str
requester_access_key: AccessKey
owner_access_key: AccessKey
owner_id: uuid.UUID
tag: str
enqueue_only: bool
keypair_resource_policy: dict[str, Any] | None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CreateFromParamsActionParams:
tag: str
priority: int
is_preemptible: bool
owner_access_key: AccessKey
owner_id: uuid.UUID
enqueue_only: bool
max_wait_seconds: int
starts_at: str | None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CreateFromTemplateActionParams:
tag: str | Undefined
priority: int
is_preemptible: bool
owner_access_key: AccessKey | Undefined
owner_id: uuid.UUID | Undefined
enqueue_only: bool
max_wait_seconds: int
starts_at: str | None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Any, override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.models.user import UserRole
Expand All @@ -15,7 +14,6 @@ class DestroySessionAction(SessionAction):
session_name: str
forced: bool
recursive: bool
owner_access_key: AccessKey

@override
def entity_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -14,7 +13,6 @@ class DownloadFileAction(SessionFileAction):
user_id: uuid.UUID
session_name: str
file: str
owner_access_key: AccessKey

@override
def entity_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from typing import Any, override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -14,7 +13,6 @@ class DownloadFilesAction(SessionFileAction):
user_id: uuid.UUID
session_name: str
files: list[str]
owner_access_key: AccessKey

@override
def entity_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Any, override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -21,7 +20,6 @@ class ExecuteSessionActionParams:
class ExecuteSessionAction(SessionAction):
session_name: str
api_version: tuple[Any, ...]
owner_access_key: AccessKey
params: ExecuteSessionActionParams

@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import override

from ai.backend.common.data.permission.types import EntityType
from ai.backend.common.types import AbuseReport, AccessKey
from ai.backend.common.types import AbuseReport
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +12,6 @@
@dataclass
class GetAbusingReportAction(SessionAction):
session_name: str
owner_access_key: AccessKey

@override
@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +11,6 @@
@dataclass
class GetCommitStatusAction(SessionCommitAction):
session_name: str
owner_access_key: AccessKey

@override
def entity_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import override

from ai.backend.common.data.permission.types import EntityType
from ai.backend.common.types import AccessKey, KernelId
from ai.backend.common.types import KernelId
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +12,6 @@
@dataclass
class GetContainerLogsAction(SessionAction):
session_name: str
owner_access_key: AccessKey
kernel_id: KernelId | None

@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, override

from ai.backend.common.data.permission.types import EntityType
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +11,6 @@
@dataclass
class GetDependencyGraphAction(SessionAction):
root_session_name: str
owner_access_key: AccessKey

@override
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, override

from ai.backend.common.data.permission.types import EntityType
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +11,6 @@
@dataclass
class GetDirectAccessInfoAction(SessionAction):
session_name: str
owner_access_key: AccessKey

@override
@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.session.types import SessionData
Expand All @@ -12,7 +11,6 @@
@dataclass
class GetSessionInfoAction(SessionAction):
session_name: str
owner_access_key: AccessKey

@override
def entity_id(self) -> str | None:
Expand Down
Loading
Loading