Skip to content

Commit f4fbdd7

Browse files
committed
refactor(BA-5715): address slice G review feedback
- Revert stray changes in repositories/model_serving/repository.py (the removal of the DeploymentStorageSource import and of the model-definition fetch was unrelated to BA-5650).
- services/session/service.py: _resolve_owner_main_access_key now uses the narrower UserRepository.get_main_access_key_by_id helper instead of loading the entire UserData record.
- services/session/lifecycle.py: guard the POST_START_SESSION hook payload against a None main_access_key — log a warning and skip the hook rather than pass None into plugins that expect an AccessKey.
1 parent db5d501 commit f4fbdd7

3 files changed

Lines changed: 88 additions & 16 deletions

File tree

src/ai/backend/manager/repositories/model_serving/repository.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import sqlalchemy as sa
77
from pydantic import HttpUrl
8+
from ruamel.yaml import YAML
89
from sqlalchemy.exc import IntegrityError, NoResultFound, StatementError
910
from sqlalchemy.ext.asyncio import AsyncSession as SASession
1011
from sqlalchemy.orm import selectinload
@@ -39,7 +40,7 @@
3940
UserData,
4041
)
4142
from ai.backend.manager.data.permission.types import RBACElementRef
42-
from ai.backend.manager.data.vfolder.types import VFolderOwnershipType
43+
from ai.backend.manager.data.vfolder.types import VFolderLocation, VFolderOwnershipType
4344
from ai.backend.manager.errors.common import ObjectNotFound
4445
from ai.backend.manager.errors.resource import DatabaseConnectionUnavailable
4546
from ai.backend.manager.errors.service import EndpointNotFound
@@ -80,6 +81,9 @@
8081
execute_rbac_entity_creator,
8182
)
8283
from ai.backend.manager.repositories.deployment.creators import DeploymentPolicyCreatorSpec
84+
from ai.backend.manager.repositories.deployment.storage_source.storage_source import (
85+
DeploymentStorageSource,
86+
)
8387
from ai.backend.manager.repositories.model_serving.updaters import EndpointUpdaterSpec
8488
from ai.backend.manager.services.model_serving.actions.modify_endpoint import ModifyEndpointAction
8589
from ai.backend.manager.services.model_serving.exceptions import (
@@ -734,7 +738,7 @@ async def get_session_by_id(
734738
async with self._db.begin_readonly_session_read_committed() as session:
735739
try:
736740
return await SessionRow.get_session(
737-
session, session_id, kernel_loading_strategy=kernel_loading_strategy
741+
session, session_id, None, kernel_loading_strategy=kernel_loading_strategy
738742
)
739743
except NoResultFound:
740744
return None
@@ -828,6 +832,15 @@ async def _do_mutate() -> MutationResult:
828832
if current_rev is None:
829833
raise InvalidAPIParameters("Endpoint has no current revision")
830834

835+
# Re-read model definition from vfolder to pick up file changes
836+
refreshed_model_definition = await self._fetch_model_definition_from_vfolder(
837+
db_session,
838+
storage_manager,
839+
current_rev.model,
840+
spec.model_definition_path.optional_value()
841+
or current_rev.model_definition_path,
842+
)
843+
831844
# Resolve image if changed
832845
image_id = current_rev.image
833846
image_ref = spec.image.optional_value()
@@ -860,6 +873,7 @@ async def _do_mutate() -> MutationResult:
860873
if spec.model_definition_path.optional_value() is not None
861874
else current_rev.model_definition_path
862875
),
876+
model_definition=refreshed_model_definition or current_rev.model_definition,
863877
resource_group=endpoint_row.resource_group,
864878
resource_opts=(
865879
spec.resource_opts.optional_value()
@@ -941,6 +955,51 @@ async def _do_mutate() -> MutationResult:
941955
except Exception:
942956
raise
943957

958+
async def _fetch_model_definition_from_vfolder(
    self,
    db_session: SASession,
    storage_manager: StorageSessionManager,
    vfolder_id: uuid.UUID | None,
    model_definition_path: str | None,
) -> dict[str, Any] | None:
    """Re-read the model definition file from the vfolder storage.

    This is a best-effort refresh: every failure is logged and reported as
    ``None`` instead of being raised, so a caller can keep using the model
    definition it already has.

    Args:
        db_session: Open database session used to look up the vfolder row.
        storage_manager: Storage session manager handed to
            ``DeploymentStorageSource`` to fetch the file.
        vfolder_id: ID of the model vfolder. ``None`` returns ``None``
            immediately without querying the database.
        model_definition_path: Explicit definition file path. When it is
            empty or ``None``, ``model-definition.yaml`` and
            ``model-definition.yml`` are passed as the candidates.

    Returns:
        The parsed YAML mapping, or ``None`` when the vfolder row does not
        exist, the file cannot be fetched or parsed, or the document's
        top-level node is not a mapping.
    """
    # Function-scope import keeps this block self-contained regardless of
    # which logger helpers the module defines.
    import logging

    logger = logging.getLogger(__name__)

    if vfolder_id is None:
        return None
    try:
        vf_query = sa.select(
            VFolderRow.id,
            VFolderRow.host,
            VFolderRow.quota_scope_id,
            VFolderRow.ownership_type,
            VFolderRow.usage_mode,
        ).where(VFolderRow.id == vfolder_id)
        vf_result = await db_session.execute(vf_query)
        vf_row = vf_result.one_or_none()
        if vf_row is None:
            return None

        vfolder_location = VFolderLocation(
            id=vf_row.id,
            quota_scope_id=vf_row.quota_scope_id,
            host=vf_row.host,
            ownership_type=vf_row.ownership_type,
            usage_mode=vf_row.usage_mode,
        )
        candidates = (
            [model_definition_path]
            if model_definition_path
            else ["model-definition.yaml", "model-definition.yml"]
        )
        storage_source = DeploymentStorageSource(storage_manager)
        content = await storage_source.fetch_definition_file(vfolder_location, candidates)
        parsed = YAML().load(content)
    except Exception:
        # Deliberately broad: the refresh must never abort the caller.
        # Log with the traceback so the failure is still diagnosable.
        logger.warning(
            "Failed to re-read model definition from vfolder %s (path: %s)",
            vfolder_id,
            model_definition_path,
            exc_info=True,
        )
        return None

    # An empty document parses to None and a malformed one may parse to a
    # list or scalar; only a mapping satisfies the declared return type.
    if not isinstance(parsed, dict):
        logger.warning(
            "Model definition in vfolder %s (path: %s) is not a YAML mapping; ignoring it",
            vfolder_id,
            model_definition_path,
        )
        return None
    return cast(dict[str, Any], parsed)
1002+
9441003
@model_serving_repository_resilience.apply()
9451004
async def search_auto_scaling_rules(
9461005
self,

src/ai/backend/manager/services/session/lifecycle.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from ai.backend.common.plugin.hook import HookPluginContext
2525
from ai.backend.common.types import (
26+
AccessKey,
2627
ResourceSlot,
2728
SessionId,
2829
SessionTypes,
@@ -136,18 +137,29 @@ async def _post_status_transition(
136137
SessionStartedAnycastEvent(session_row.id, creation_id)
137138
)
138139
# BA-5609: resolve main_access_key from owner_id; external
139-
# hook plugins still receive the resolved access key.
140+
# hook plugins still receive the resolved access key. If the
141+
# owner has no main_access_key configured, skip the hook —
142+
# calling it with ``None`` would likely break plugins that
143+
# assume a non-null keypair identifier.
140144
session_main_access_key = await self._user_repository.get_main_access_key_by_id(
141145
session_row.user_uuid
142146
)
143-
await self.hook_plugin_ctx.notify(
144-
"POST_START_SESSION",
145-
(
147+
if session_main_access_key is not None:
148+
await self.hook_plugin_ctx.notify(
149+
"POST_START_SESSION",
150+
(
151+
session_row.id,
152+
session_row.name,
153+
AccessKey(session_main_access_key),
154+
),
155+
)
156+
else:
157+
log.warning(
158+
"POST_START_SESSION skipped: owner {} has no main_access_key"
159+
" (session {})",
160+
session_row.user_uuid,
146161
session_row.id,
147-
session_row.name,
148-
session_main_access_key,
149-
),
150-
)
162+
)
151163
match session_row.session_type:
152164
case SessionTypes.BATCH:
153165
await self.registry.trigger_batch_execution(session_row)

src/ai/backend/manager/services/session/service.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,16 +288,17 @@ async def _resolve_owner_main_access_key(
288288
) -> AccessKey:
289289
"""Resolve a delegated owner UUID to that user's main access key.
290290
291-
Loads the target user via the user repository and returns the main
292-
access key. Raises ``InternalServerError`` if the target user has no
293-
main access key configured.
291+
Uses the narrower ``UserRepository.get_main_access_key_by_id`` helper
292+
so we only fetch the single scalar column we need. Raises
293+
``InternalServerError`` if the target user has no main access key
294+
configured.
294295
"""
295-
user_data = await self._user_repository.get_user_by_uuid(owner_id)
296-
if user_data.main_access_key is None:
296+
main_access_key = await self._user_repository.get_main_access_key_by_id(owner_id)
297+
if main_access_key is None:
297298
raise InternalServerError(
298299
f"Delegated owner {owner_id} has no main access key configured"
299300
)
300-
return AccessKey(user_data.main_access_key)
301+
return AccessKey(main_access_key)
301302

302303
async def commit_session(self, action: CommitSessionAction) -> CommitSessionActionResult:
303304
session_name = action.session_name

0 commit comments

Comments
 (0)