Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11045.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rename `SessionData.user_uuid` / `SessionMetadata.user_uuid` to `owner_id` and drop the redundant `access_key` snapshot fields from those data types. `ComputeSessionNode.access_key` is now sourced from the owner's `main_access_key`, kept in step with the underlying user record.
4 changes: 2 additions & 2 deletions src/ai/backend/manager/api/adapters/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,13 +933,13 @@ def _session_data_to_node(data: SessionData) -> SessionNode:
return SessionNode(
id=data.id,
domain_name=data.domain_name,
user_id=data.user_uuid,
user_id=data.owner_id,
project_id=data.group_id,
metadata=SessionMetadataInfoGQLDTO(
creation_id=data.creation_id or "",
name=data.name or "",
session_type=data.session_type.value,
access_key=str(data.access_key) if data.access_key else "",
access_key="",
cluster_mode=data.cluster_mode.name,
cluster_size=data.cluster_size,
priority=data.priority,
Expand Down
37 changes: 28 additions & 9 deletions src/ai/backend/manager/api/gql_legacy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from ai.backend.common import validators as tx
from ai.backend.common.defs.session import SESSION_PRIORITY_MAX, SESSION_PRIORITY_MIN
from ai.backend.common.exception import SessionWithInvalidStateError
from ai.backend.common.types import (
ClusterMode,
KernelId,
Expand Down Expand Up @@ -395,19 +394,25 @@ def from_dataclass(
cls,
ctx: GraphQueryContext,
session_data: SessionData,
main_access_key: str | None,
*,
permissions: Iterable[ComputeSessionPermission] | None = None,
) -> Self:
"""Build a ``ComputeSessionNode`` from session data.

``main_access_key`` must be pre-resolved by the caller (typically
via ``UserRepository.get_main_access_key_by_id(session_data.owner_id)``
or by eagerly loading ``session_data.owner``). Keeping the helper
synchronous avoids a hidden per-session DB query and lets the
caller batch the lookup across nodes.
"""
status_history = session_data.status_history or {}
raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
if not session_data.vfolder_mounts:
vfolder_mounts = []
else:
vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts]

if session_data.owner is None:
raise SessionWithInvalidStateError()

result = cls(
# identity
id=session_data.id, # auto-converted to Relay global ID
Expand All @@ -422,9 +427,9 @@ def from_dataclass(
# ownership
domain_name=session_data.domain_name,
project_id=session_data.group_id,
user_id=session_data.user_uuid,
access_key=session_data.access_key,
owner=UserNode.from_dataclass(ctx, session_data.owner),
user_id=session_data.owner_id,
access_key=main_access_key,
owner=UserNode.from_dataclass(ctx, session_data.owner) if session_data.owner else None,
# status
status=session_data.status.name,
# status_changed=row.status_changed, # FIXME: generated attribute
Expand Down Expand Up @@ -918,8 +923,14 @@ async def mutate_and_get_payload(
)
)

session_data = result.session_data
main_access_key = (
session_data.owner.main_access_key
if session_data.owner
else await graph_ctx.user_repository.get_main_access_key_by_id(session_data.owner_id)
)
return ModifyComputeSession(
ComputeSessionNode.from_dataclass(graph_ctx, result.session_data),
ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key),
input.get("client_mutation_id"),
)

Expand Down Expand Up @@ -969,8 +980,16 @@ async def mutate(
)
)
)
session_data = action_result.session_data
main_access_key = (
session_data.owner.main_access_key
if session_data.owner
else await graph_ctx.user_repository.get_main_access_key_by_id(
session_data.owner_id
)
)
session_nodes.append(
ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data)
ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key)
)

return CheckAndTransitStatus(session_nodes, input.get("client_mutation_id"))
Expand Down
7 changes: 2 additions & 5 deletions src/ai/backend/manager/data/session/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from ai.backend.common.data.vfolder.types import VFolderMountData
from ai.backend.common.types import (
AccessKey,
CIStrEnum,
ClusterMode,
ResourceSlot,
Expand Down Expand Up @@ -155,7 +154,7 @@ class SessionData:
cluster_size: int
domain_name: str
group_id: UUID
user_uuid: UUID
owner_id: UUID
occupying_slots: Any # TODO: ResourceSlot?
requested_slots: Any
use_host_network: bool
Expand All @@ -165,7 +164,6 @@ class SessionData:
num_queries: int
creation_id: str | None
name: str | None
access_key: AccessKey | None
agent_ids: list[str] | None
images: list[str] | None
tag: str | None
Expand Down Expand Up @@ -206,8 +204,7 @@ class SessionMetadata:
name: str
domain_name: str
group_id: UUID
user_uuid: UUID
access_key: str
owner_id: UUID
session_type: SessionTypes
priority: int
created_at: datetime | None
Expand Down
10 changes: 3 additions & 7 deletions src/ai/backend/manager/models/endpoint/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,12 @@ class EndpointRow(Base): # type: ignore[misc]

__table_args__ = (
sa.Index(
"ix_endpoints_unique_name_when_active",
"ix_endpoints_unique_name_when_not_destroyed",
"name",
"domain",
"project",
unique=True,
postgresql_where=sa.column("lifecycle_stage").notin_([
EndpointLifecycle.DESTROYING.value,
EndpointLifecycle.DESTROYED.value,
]),
postgresql_where=(sa.column("lifecycle_stage") != EndpointLifecycle.DESTROYED.value),
),
sa.Index(
"ix_endpoints_lifecycle_sub_step",
Expand Down Expand Up @@ -530,7 +527,6 @@ async def delegate_endpoint_ownership(
db_session: AsyncSession,
owner_user_uuid: UUID,
target_user_uuid: UUID,
target_access_key: AccessKey,
) -> None:
from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow

Expand All @@ -554,7 +550,7 @@ async def delegate_endpoint_ownership(
db_session, session_ids, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS
)
for session_row in session_rows:
session_row.delegate_ownership(target_user_uuid, target_access_key)
session_row.delegate_ownership(target_user_uuid)

async def generate_route_info(
self, db_sess: AsyncSession
Expand Down
16 changes: 1 addition & 15 deletions src/ai/backend/manager/models/keypair/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA256
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql.expression import false

from ai.backend.common import msgpack
Expand All @@ -31,7 +31,6 @@
if TYPE_CHECKING:
from ai.backend.manager.models.resource_policy import KeyPairResourcePolicyRow
from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.models.user import UserRow

__all__: Sequence[str] = (
Expand All @@ -48,13 +47,6 @@
MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB


# Defined for avoiding circular import
def _get_session_row_join_condition() -> sa.ColumnElement[bool]:
from ai.backend.manager.models.session import SessionRow

return KeyPairRow.access_key == foreign(SessionRow.access_key)


class KeyPairRow(Base): # type: ignore[misc]
__tablename__ = "keypairs"

Expand Down Expand Up @@ -100,12 +92,6 @@ class KeyPairRow(Base): # type: ignore[misc]
)

# Relationships
sessions: Mapped[list[SessionRow]] = relationship(
"SessionRow",
primaryjoin=_get_session_row_join_condition,
foreign_keys="SessionRow.access_key",
back_populates="access_key_row",
)
resource_policy_row: Mapped[KeyPairResourcePolicyRow] = relationship(
"KeyPairResourcePolicyRow", back_populates="keypairs"
)
Expand Down
18 changes: 15 additions & 3 deletions src/ai/backend/manager/models/resource_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,12 @@ async def parse_resource_usage_groups(
last_stat=stat_map.get(kern.id),
user_id=kern.session.user_uuid,
user_email=kern.session.user.email if kern.session.user is not None else None,
access_key=kern.session.access_key,
# The old ``SessionRow.access_key`` column is being dropped in a
# later slice; source the keypair access_key from the owner's
# ``main_access_key`` instead.
access_key=(
kern.session.user.main_access_key if kern.session.user is not None else None
Comment thread
jopemachine marked this conversation as resolved.
),
project_id=kern.session.group.id if kern.session.group is not None else None,
project_name=kern.session.group.name if kern.session.group is not None else None,
kernel_id=kern.id,
Expand Down Expand Up @@ -553,7 +558,9 @@ async def parse_resource_usage_groups(
SessionRow.domain_name,
SessionRow.id,
SessionRow.group_id,
SessionRow.access_key,
# SessionRow.access_key is deprecated (removed in a later slice); callers
# that need the keypair access_key should join UserRow and read
# users.main_access_key instead.
SessionRow.images,
SessionRow.cluster_mode,
SessionRow.status_history,
Expand Down Expand Up @@ -606,7 +613,12 @@ def _parse_query(
session_load.options(
load_only(*SESSION_RESOURCE_SELECT_COLS),
joinedload(SessionRow.user).options(
load_only(UserRow.email, UserRow.username, UserRow.full_name)
load_only(
UserRow.email,
UserRow.username,
UserRow.full_name,
UserRow.main_access_key,
)
),
project_load.options(load_only(*PROJECT_RESOURCE_SELECT_COLS)),
),
Expand Down
Loading
Loading