Skip to content

Commit 017b410

Browse files
committed
refactor(BA-5650-C): rename SessionData user_uuid to owner_id
Data-layer rename of SessionData / SessionMetadata user_uuid to owner_id, plus the matching SessionRow adapters (to_dataclass/from_dataclass/to_session_info/from_session_info). ComputeSessionNode.from_dataclass becomes async and resolves main_access_key from the owner via UserRepository when owner is not eagerly loaded. models/resource_usage.py sources user_id from SessionRow.user_uuid through the session relationship. SessionRow._build_session_fetch_query / _match_sessions_by_* also rename the filter parameter to owner_id; repository and db_source callers are updated in the next slice so intermediate builds may fail in isolation.
1 parent d88a14b commit 017b410

5 files changed

Lines changed: 49 additions & 74 deletions

File tree

changes/BA-5650-C.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Rename `SessionData.user_uuid` / `SessionMetadata.user_uuid` to `owner_id` and drop the redundant `access_key` snapshot fields. `SessionRow.to_dataclass` / `from_dataclass`, `ComputeSessionNode.from_dataclass`, and `resource_usage` queries now resolve the owner's `main_access_key` from the eager-loaded `user` relationship.

src/ai/backend/manager/api/gql_legacy/session.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
from ai.backend.common import validators as tx
2525
from ai.backend.common.defs.session import SESSION_PRIORITY_MAX, SESSION_PRIORITY_MIN
26-
from ai.backend.common.exception import SessionWithInvalidStateError
2726
from ai.backend.common.types import (
2827
ClusterMode,
2928
KernelId,
@@ -391,7 +390,7 @@ def from_row(
391390
return result
392391

393392
@classmethod
394-
def from_dataclass(
393+
async def from_dataclass(
395394
cls,
396395
ctx: GraphQueryContext,
397396
session_data: SessionData,
@@ -405,8 +404,16 @@ def from_dataclass(
405404
else:
406405
vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts]
407406

408-
if session_data.owner is None:
409-
raise SessionWithInvalidStateError()
407+
# access_key on this node is the owner's main_access_key. When the
408+
# caller hasn't eagerly loaded session_data.owner, fall back to the
409+
# user repository so the required field is always populated from the
410+
# always-present owner_id.
411+
if session_data.owner is not None:
412+
main_access_key = session_data.owner.main_access_key
413+
else:
414+
main_access_key = await ctx.user_repository.get_main_access_key_by_id(
415+
session_data.owner_id
416+
)
410417

411418
result = cls(
412419
# identity
@@ -422,9 +429,9 @@ def from_dataclass(
422429
# ownership
423430
domain_name=session_data.domain_name,
424431
project_id=session_data.group_id,
425-
user_id=session_data.user_uuid,
426-
access_key=session_data.access_key,
427-
owner=UserNode.from_dataclass(ctx, session_data.owner),
432+
user_id=session_data.owner_id,
433+
access_key=main_access_key,
434+
owner=UserNode.from_dataclass(ctx, session_data.owner) if session_data.owner else None,
428435
# status
429436
status=session_data.status.name,
430437
# status_changed=row.status_changed, # FIXME: generated attribute
@@ -919,7 +926,7 @@ async def mutate_and_get_payload(
919926
)
920927

921928
return ModifyComputeSession(
922-
ComputeSessionNode.from_dataclass(graph_ctx, result.session_data),
929+
await ComputeSessionNode.from_dataclass(graph_ctx, result.session_data),
923930
input.get("client_mutation_id"),
924931
)
925932

@@ -970,7 +977,7 @@ async def mutate(
970977
)
971978
)
972979
session_nodes.append(
973-
ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data)
980+
await ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data)
974981
)
975982

976983
return CheckAndTransitStatus(session_nodes, input.get("client_mutation_id"))

src/ai/backend/manager/data/session/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from ai.backend.common.data.vfolder.types import VFolderMountData
1414
from ai.backend.common.types import (
15-
AccessKey,
1615
CIStrEnum,
1716
ClusterMode,
1817
ResourceSlot,
@@ -155,7 +154,7 @@ class SessionData:
155154
cluster_size: int
156155
domain_name: str
157156
group_id: UUID
158-
user_uuid: UUID
157+
owner_id: UUID
159158
occupying_slots: Any # TODO: ResourceSlot?
160159
requested_slots: Any
161160
use_host_network: bool
@@ -165,7 +164,6 @@ class SessionData:
165164
num_queries: int
166165
creation_id: str | None
167166
name: str | None
168-
access_key: AccessKey | None
169167
agent_ids: list[str] | None
170168
images: list[str] | None
171169
tag: str | None
@@ -206,8 +204,7 @@ class SessionMetadata:
206204
name: str
207205
domain_name: str
208206
group_id: UUID
209-
user_uuid: UUID
210-
access_key: str
207+
owner_id: UUID
211208
session_type: SessionTypes
212209
priority: int
213210
created_at: datetime | None

src/ai/backend/manager/models/resource_usage.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,11 @@ async def parse_resource_usage_groups(
524524
last_stat=stat_map.get(kern.id),
525525
user_id=kern.session.user_uuid,
526526
user_email=kern.session.user.email if kern.session.user is not None else None,
527-
access_key=kern.session.access_key,
527+
# TODO(BA-5609 phase D): resolve access_key from owner via
528+
# users.main_access_key. SessionRow.access_key has been removed.
529+
access_key=(
530+
kern.session.user.main_access_key if kern.session.user is not None else None
531+
),
528532
project_id=kern.session.group.id if kern.session.group is not None else None,
529533
project_name=kern.session.group.name if kern.session.group is not None else None,
530534
kernel_id=kern.id,
@@ -553,7 +557,8 @@ async def parse_resource_usage_groups(
553557
SessionRow.domain_name,
554558
SessionRow.id,
555559
SessionRow.group_id,
556-
SessionRow.access_key,
560+
# TODO(BA-5609 phase D): SessionRow.access_key removed. Callers should
561+
# join UserRow and read users.main_access_key when an access_key is needed.
557562
SessionRow.images,
558563
SessionRow.cluster_mode,
559564
SessionRow.status_history,

src/ai/backend/manager/models/session/row.py

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@
123123

124124
if TYPE_CHECKING:
125125
from ai.backend.manager.models.domain import DomainRow
126-
from ai.backend.manager.models.keypair import KeyPairRow
127126
from ai.backend.manager.models.scaling_group import ScalingGroupRow
128127
from ai.backend.manager.models.user import UserRow
129128

@@ -494,17 +493,17 @@ async def handle_session_exception(
494493

495494
def _build_session_fetch_query(
496495
base_cond: Any,
497-
access_key: AccessKey | None = None,
498496
*,
497+
owner_id: UUID | None = None,
499498
allow_stale: bool = True,
500499
for_update: bool = False,
501500
do_ordering: bool = False,
502501
max_matches: int | None = None,
503502
eager_loading_op: Sequence[_AbstractLoad] | None = None,
504503
) -> sa.sql.Select[Any]:
505504
cond = base_cond
506-
if access_key:
507-
cond = cond & (SessionRow.access_key == access_key)
505+
if owner_id is not None:
506+
cond = cond & (SessionRow.user_uuid == owner_id)
508507
if not allow_stale:
509508
cond = cond & (~SessionRow.status.in_(DEAD_SESSION_STATUSES))
510509
query = (
@@ -528,8 +527,8 @@ def _build_session_fetch_query(
528527
async def _match_sessions_by_id(
529528
db_session: SASession,
530529
session_id_or_list: SessionId | list[SessionId],
531-
access_key: AccessKey | None = None,
532530
*,
531+
owner_id: UUID | None = None,
533532
allow_prefix: bool = False,
534533
allow_stale: bool = True,
535534
for_update: bool = False,
@@ -546,7 +545,7 @@ async def _match_sessions_by_id(
546545
cond = SessionRow.id == session_id_or_list
547546
query = _build_session_fetch_query(
548547
cond,
549-
access_key,
548+
owner_id=owner_id,
550549
max_matches=max_matches,
551550
allow_stale=allow_stale,
552551
for_update=for_update,
@@ -560,8 +559,8 @@ async def _match_sessions_by_id(
560559
async def _match_sessions_by_name(
561560
db_session: SASession,
562561
session_name: str,
563-
access_key: AccessKey,
564562
*,
563+
owner_id: UUID | None = None,
565564
allow_prefix: bool = False,
566565
allow_stale: bool = True,
567566
for_update: bool = False,
@@ -575,7 +574,7 @@ async def _match_sessions_by_name(
575574
cond = SessionRow.name == session_name
576575
query = _build_session_fetch_query(
577576
cond,
578-
access_key,
577+
owner_id=owner_id,
579578
max_matches=max_matches,
580579
allow_stale=allow_stale,
581580
for_update=for_update,
@@ -595,20 +594,6 @@ class ConcurrencyUsed:
595594
compute_session_ids: set[SessionId] = field(default_factory=set)
596595
system_session_ids: set[SessionId] = field(default_factory=set)
597596

598-
@property
599-
def compute_concurrency_used_key(self) -> str:
600-
return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
601-
602-
@property
603-
def system_concurrency_used_key(self) -> str:
604-
return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"
605-
606-
def to_cnt_map(self) -> Mapping[str, int]:
607-
return {
608-
self.compute_concurrency_used_key: len(self.compute_session_ids),
609-
self.system_concurrency_used_key: len(self.system_session_ids),
610-
}
611-
612597

613598
class SessionOp(enum.StrEnum):
614599
CREATE = "create_session"
@@ -637,13 +622,6 @@ class KernelLoadingStrategy(enum.StrEnum):
637622
}
638623

639624

640-
# Defined for avoiding circular import
641-
def _get_keypair_row_join_condition() -> sa.sql.elements.ColumnElement[Any]:
642-
from ai.backend.manager.models.keypair import KeyPairRow
643-
644-
return KeyPairRow.access_key == foreign(SessionRow.access_key)
645-
646-
647625
def _get_user_row_join_condition() -> sa.sql.elements.ColumnElement[Any]:
648626
from ai.backend.manager.models.user import UserRow
649627

@@ -731,14 +709,7 @@ class SessionRow(Base): # type: ignore[misc]
731709
back_populates="sessions",
732710
foreign_keys=[user_uuid],
733711
)
734-
735712
access_key: Mapped[str | None] = mapped_column("access_key", sa.String(length=20))
736-
access_key_row: Mapped[KeyPairRow | None] = relationship(
737-
"KeyPairRow",
738-
primaryjoin=_get_keypair_row_join_condition,
739-
back_populates="sessions",
740-
foreign_keys=[access_key],
741-
)
742713

743714
# `image` column is identical to kernels `image` column.
744715
images: Mapped[list[str] | None] = mapped_column("images", sa.ARRAY(sa.String), nullable=True)
@@ -884,7 +855,7 @@ class SessionRow(Base): # type: ignore[misc]
884855
sa.Index("ix_session_status_with_priority", "status", "priority"),
885856
# Unique index for session names per user excluding terminal statuses
886857
sa.Index(
887-
"ix_sessions_unique_name_per_user_nonterminal",
858+
"ix_sessions_unique_name_per_owner_nonterminal",
888859
"name",
889860
"user_uuid",
890861
unique=True,
@@ -923,8 +894,7 @@ def from_dataclass(cls, session_data: SessionData) -> SessionRow:
923894
target_sgroup_names=session_data.target_sgroup_names,
924895
domain_name=session_data.domain_name,
925896
group_id=session_data.group_id,
926-
user_uuid=session_data.user_uuid,
927-
access_key=session_data.access_key,
897+
user_uuid=session_data.owner_id,
928898
images=session_data.images,
929899
tag=session_data.tag,
930900
occupying_slots=session_data.occupying_slots,
@@ -968,8 +938,7 @@ def to_dataclass(self, owner: UserData | None = None) -> SessionData:
968938
target_sgroup_names=self.target_sgroup_names,
969939
domain_name=self.domain_name,
970940
group_id=self.group_id,
971-
user_uuid=self.user_uuid,
972-
access_key=AccessKey(self.access_key) if self.access_key else None,
941+
owner_id=self.user_uuid,
973942
images=self.images,
974943
tag=self.tag,
975944
occupying_slots=self.occupying_slots,
@@ -1017,8 +986,7 @@ def from_session_info(cls, info: SessionInfo) -> Self:
1017986
target_sgroup_names=info.resource.target_sgroup_names,
1018987
domain_name=info.metadata.domain_name,
1019988
group_id=info.metadata.group_id,
1020-
user_uuid=info.metadata.user_uuid,
1021-
access_key=info.metadata.access_key,
989+
user_uuid=info.metadata.owner_id,
1022990
images=info.image.images,
1023991
tag=info.image.tag or info.metadata.tag,
1024992
occupying_slots=info.resource.occupying_slots,
@@ -1059,8 +1027,7 @@ def to_session_info(self) -> SessionInfo:
10591027
name=self.name or "",
10601028
domain_name=self.domain_name,
10611029
group_id=self.group_id,
1062-
user_uuid=self.user_uuid,
1063-
access_key=self.access_key or "",
1030+
owner_id=self.user_uuid,
10641031
session_type=self.session_type,
10651032
priority=self.priority,
10661033
created_at=self.created_at,
@@ -1284,11 +1251,10 @@ def set_status(
12841251
if _status_info is not None:
12851252
self.status_info = _status_info
12861253

1287-
def delegate_ownership(self, user_uuid: UUID, access_key: AccessKey) -> None:
1288-
self.user_uuid = user_uuid
1289-
self.access_key = access_key
1254+
def delegate_ownership(self, owner_id: UUID) -> None:
1255+
self.user_uuid = owner_id
12901256
for kernel_row in self.kernels:
1291-
kernel_row.delegate_ownership(user_uuid, access_key)
1257+
kernel_row.delegate_ownership(owner_id)
12921258

12931259
@staticmethod
12941260
async def delete_by_user_id(user_uuid: UUID, *, db_session: SASession) -> None:
@@ -1357,8 +1323,8 @@ async def match_sessions(
13571323
cls,
13581324
db_session: SASession,
13591325
session_reference: str | UUID | list[UUID],
1360-
access_key: AccessKey | None,
13611326
*,
1327+
owner_id: UUID | None = None,
13621328
allow_prefix: bool = False,
13631329
allow_stale: bool = True,
13641330
for_update: bool = False,
@@ -1412,7 +1378,7 @@ async def match_sessions(
14121378
for fetch_func in query_list:
14131379
rows = await fetch_func(
14141380
db_session,
1415-
access_key=access_key,
1381+
owner_id=owner_id,
14161382
allow_stale=allow_stale,
14171383
for_update=for_update,
14181384
max_matches=max_matches,
@@ -1428,8 +1394,8 @@ async def get_session(
14281394
cls,
14291395
db_session: SASession,
14301396
session_name_or_id: str | UUID,
1431-
access_key: AccessKey | None = None,
14321397
*,
1398+
owner_id: UUID | None = None,
14331399
allow_stale: bool = False,
14341400
for_update: bool = False,
14351401
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE,
@@ -1474,7 +1440,7 @@ async def get_session(
14741440
session_list = await cls.match_sessions(
14751441
db_session,
14761442
session_name_or_id,
1477-
access_key,
1443+
owner_id=owner_id,
14781444
allow_stale=allow_stale,
14791445
for_update=for_update,
14801446
eager_loading_op=_eager_loading_op,
@@ -1499,8 +1465,8 @@ async def list_sessions(
14991465
cls,
15001466
db_session: SASession,
15011467
session_ids: list[UUID],
1502-
access_key: AccessKey | None = None,
15031468
*,
1469+
owner_id: UUID | None = None,
15041470
allow_stale: bool = False,
15051471
for_update: bool = False,
15061472
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE,
@@ -1531,7 +1497,7 @@ async def list_sessions(
15311497
session_list = await cls.match_sessions(
15321498
db_session,
15331499
session_ids,
1534-
access_key,
1500+
owner_id=owner_id,
15351501
allow_stale=allow_stale,
15361502
for_update=for_update,
15371503
eager_loading_op=_eager_loading_op,
@@ -1547,8 +1513,8 @@ async def get_session_by_id(
15471513
cls,
15481514
db_session: SASession,
15491515
session_id: SessionId,
1550-
access_key: AccessKey | None = None,
15511516
*,
1517+
owner_id: UUID | None = None,
15521518
max_matches: int | None = None,
15531519
allow_stale: bool = True,
15541520
for_update: bool = False,
@@ -1557,7 +1523,7 @@ async def get_session_by_id(
15571523
sessions = await _match_sessions_by_id(
15581524
db_session,
15591525
session_id,
1560-
access_key,
1526+
owner_id=owner_id,
15611527
max_matches=max_matches,
15621528
allow_stale=allow_stale,
15631529
for_update=for_update,
@@ -1586,7 +1552,6 @@ async def get_sgroup_managed_sessions(
15861552
noload("*"),
15871553
selectinload(SessionRow.group).options(noload("*")),
15881554
selectinload(SessionRow.domain).options(noload("*")),
1589-
selectinload(SessionRow.access_key_row).options(noload("*")),
15901555
selectinload(SessionRow.kernels).options(noload("*")),
15911556
)
15921557
)

0 commit comments

Comments
 (0)