Skip to content
Merged
1 change: 1 addition & 0 deletions changes/10271.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use deploying-revision image for new route session creation
11 changes: 7 additions & 4 deletions src/ai/backend/manager/data/deployment/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ class ExecutionSpec(ConfiguredModel):


class ModelRevisionSpec(ConfiguredModel):
revision_id: UUID | None = None
image_identifier: ImageIdentifier
resource_spec: ResourceSpec
mounts: MountMetadata
Expand Down Expand Up @@ -379,10 +380,12 @@ class DeploymentInfo:
deploying_revision_id: UUID | None = None
sub_step: DeploymentSubStep | None = None

def target_revision(self) -> ModelRevisionSpec | None:
if self.model_revisions:
return self.model_revisions[0]
return None
def resolve_revision_spec(self, revision_id: UUID) -> ModelRevisionSpec | None:
"""Find a ModelRevisionSpec by revision_id from model_revisions."""
return next(
(r for r in self.model_revisions if r.revision_id == revision_id),
None,
)


@dataclass
Expand Down
82 changes: 44 additions & 38 deletions src/ai/backend/manager/models/endpoint/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
relationship,
selectinload,
)
from sqlalchemy.orm.attributes import instance_state

from ai.backend.common.config import model_definition_iv
from ai.backend.common.types import (
Expand Down Expand Up @@ -779,14 +780,16 @@ def to_deployment_info(self) -> DeploymentInfo:
if self.deployment_policy is not None:
policy_data = self.deployment_policy.to_data()

# Try to use current revision if available
if self.current_revision and hasattr(self, "revisions") and self.revisions:
current_rev = next(
(r for r in self.revisions if r.id == self.current_revision),
None,
)
if current_rev:
info = self._to_deployment_info_from_revision(current_rev)
# Build model_revisions list from loaded revision rows
if "revisions" in instance_state(self).dict and self.revisions:
model_revisions: list[ModelRevisionSpec] = []
for rev_row in self.revisions:
if rev_row.image_row is None:
continue
if rev_row.id == self.current_revision or rev_row.id == self.deploying_revision:
model_revisions.append(self._build_revision_spec(rev_row))
if model_revisions:
info = self._to_deployment_info_with_revisions(model_revisions)
Comment on lines +783 to +792
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code for checking the string key doesn't look great, but since it seems to be an existing issue, I'll leave it for now.

info.policy = policy_data
return info

Expand All @@ -795,16 +798,44 @@ def to_deployment_info(self) -> DeploymentInfo:
info.policy = policy_data
return info

def _to_deployment_info_from_revision(
def _build_revision_spec(
self,
revision: DeploymentRevisionRow,
) -> DeploymentInfo:
"""Build DeploymentInfo using revision data."""
# Get image identifier from revision's image_row
) -> ModelRevisionSpec:
"""Build a ModelRevisionSpec from a revision row."""
image_identifier = ImageIdentifier(
canonical=revision.image_row.name,
architecture=revision.image_row.architecture,
)
return ModelRevisionSpec(
revision_id=revision.id,
image_identifier=image_identifier,
resource_spec=ResourceSpec(
cluster_mode=ClusterMode(revision.cluster_mode),
cluster_size=revision.cluster_size,
resource_slots=revision.resource_slots,
resource_opts=revision.resource_opts,
),
mounts=MountMetadata(
model_vfolder_id=revision.model or uuid.UUID(int=0),
model_definition_path=revision.model_definition_path,
model_mount_destination=revision.model_mount_destination,
extra_mounts=revision.extra_mounts or [],
),
execution=ExecutionSpec(
startup_command=revision.startup_command,
bootstrap_script=revision.bootstrap_script,
environ=revision.environ,
runtime_variant=revision.runtime_variant,
callback_url=yarl.URL(revision.callback_url) if revision.callback_url else None,
),
)

def _to_deployment_info_with_revisions(
self,
model_revisions: Sequence[ModelRevisionSpec],
) -> DeploymentInfo:
"""Build DeploymentInfo with pre-built model_revisions dict."""
return DeploymentInfo(
id=self.id,
metadata=DeploymentMetadata(
Expand All @@ -830,32 +861,7 @@ def _to_deployment_info_from_revision(
open_to_public=self.open_to_public if self.open_to_public is not None else False,
url=self.url,
),
model_revisions=[
ModelRevisionSpec(
image_identifier=image_identifier,
resource_spec=ResourceSpec(
cluster_mode=ClusterMode(revision.cluster_mode),
cluster_size=revision.cluster_size,
resource_slots=revision.resource_slots,
resource_opts=revision.resource_opts,
),
mounts=MountMetadata(
model_vfolder_id=revision.model or uuid.UUID(int=0),
model_definition_path=revision.model_definition_path,
model_mount_destination=revision.model_mount_destination,
extra_mounts=revision.extra_mounts or [],
),
execution=ExecutionSpec(
startup_command=revision.startup_command,
bootstrap_script=revision.bootstrap_script,
environ=revision.environ,
runtime_variant=revision.runtime_variant,
callback_url=yarl.URL(revision.callback_url)
if revision.callback_url
else None,
),
),
],
model_revisions=list(model_revisions),
current_revision_id=self.current_revision,
deploying_revision_id=self.deploying_revision,
sub_step=self.sub_step,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,7 @@ async def get_routes_by_endpoint(
status=row.status,
traffic_ratio=row.traffic_ratio,
created_at=row.created_at,
revision_id=row.revision,
error_data=row.error_data or {},
)
for row in rows
Expand Down Expand Up @@ -1557,6 +1558,7 @@ async def get_routes_by_statuses(
status=row.status,
traffic_ratio=row.traffic_ratio,
created_at=row.created_at,
revision_id=row.revision,
error_data=row.error_data or {},
)
route_data_list.append(route_data)
Expand Down Expand Up @@ -1788,11 +1790,13 @@ async def delete_routes_by_route_ids(
async def fetch_deployment_context(
self,
deployment_info: DeploymentInfo,
revision_id: uuid.UUID,
) -> DeploymentContext:
"""Fetch all context data needed for session creation from deployment info.

Args:
deployment_info: Deployment information
revision_id: Revision to use for image resolution.

Returns:
DeploymentContext: Context data needed for session creation
Expand Down Expand Up @@ -1843,15 +1847,22 @@ async def fetch_deployment_context(
else None,
)

# Resolve image
target_revision = deployment_info.target_revision()
if not target_revision:
raise DeploymentHasNoTargetRevision("Deployment has no target revision")

image_row = await ImageRow.resolve(
db_sess,
[target_revision.image_identifier],
revision_query = (
sa.select(DeploymentRevisionRow)
.where(DeploymentRevisionRow.id == revision_id)
.options(selectinload(DeploymentRevisionRow.image_row))
)
revision_result = await db_sess.execute(revision_query)
revision_row = revision_result.scalar_one_or_none()
if revision_row is None or revision_row.image_row is None:
raise DeploymentHasNoTargetRevision(
f"Revision {revision_id} not found or has no image"
)
image_identifier = ImageIdentifier(
canonical=revision_row.image_row.name,
architecture=revision_row.image_row.architecture,
)
image_row = await ImageRow.resolve(db_sess, [image_identifier])

# Build DeploymentContext
return DeploymentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,16 +707,18 @@ async def delete_routes_by_route_ids(
async def fetch_deployment_context(
self,
deployment_info: DeploymentInfo,
revision_id: UUID,
) -> DeploymentContext:
"""Fetch all context data needed for session creation from deployment info.

Args:
deployment_info: Deployment information
revision_id: Revision to use for image resolution.

Returns:
DeploymentContext: Context data needed for session creation
"""
return await self._db_source.fetch_deployment_context(deployment_info)
return await self._db_source.fetch_deployment_context(deployment_info, revision_id)

# Auto-scaling operations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class RouteData:
status: RouteStatus
traffic_ratio: float
created_at: datetime
revision_id: uuid.UUID | None = None
updated_at: datetime | None = None
error_data: dict[str, Any] = field(default_factory=dict)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,17 @@ class SessionCreationSpec:

@classmethod
def from_deployment_info(
cls, deployment_info: DeploymentInfo, context: DeploymentContext, route_id: UUID
cls,
deployment_info: DeploymentInfo,
context: DeploymentContext,
route_id: UUID,
revision_id: UUID,
) -> Self:
session_creation_id = secrets.token_urlsafe(16)
target_revision = deployment_info.target_revision()
target_revision = deployment_info.resolve_revision_spec(revision_id)
if target_revision is None:
raise DeploymentHasNoTargetRevision(
"Deployment has no target revision for session creation"
f"Revision {revision_id} not found in model_revisions"
)

# Prepare mount spec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,14 @@ async def update_deployment(
modified_endpoint = await self._deployment_repository.get_modified_endpoint(
endpoint_id=endpoint_id, updater=updater
)
target_revision = modified_endpoint.target_revision()
if target_revision:
await self._scheduling_controller.validate_session_spec(
SessionValidationSpec.from_revision(model_revision=target_revision)
if modified_endpoint.current_revision_id is not None:
current_revision = modified_endpoint.resolve_revision_spec(
modified_endpoint.current_revision_id
)
if current_revision:
await self._scheduling_controller.validate_session_spec(
SessionValidationSpec.from_revision(model_revision=current_revision)
)
res = await self._deployment_repository.update_endpoint_with_spec(updater)
try:
await self.mark_lifecycle_needed(DeploymentLifecycleType.CHECK_REPLICA)
Expand Down
27 changes: 19 additions & 8 deletions src/ai/backend/manager/sokovan/deployment/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,20 @@ async def check_pending_deployments(
valid_deployments: list[DeploymentWithHistory] = []
for deployment in deployments:
info = deployment.deployment_info
target_revision = info.target_revision()
if not target_revision:
if info.current_revision_id is None:
log.warning(
"Deployment {} has no target revision, skipping",
"Deployment {} has no current revision, skipping",
info.id,
)
continue
current_revision = info.resolve_revision_spec(info.current_revision_id)
if not current_revision:
log.warning(
"Deployment {} current revision {} not found in model_revisions, skipping",
info.id,
info.current_revision_id,
)
continue
targets = scaling_group_targets[info.metadata.resource_group]
if not targets:
log.warning(
Expand Down Expand Up @@ -443,16 +450,20 @@ async def _register_endpoint(

with recorder.phase("register_endpoint"):
with recorder.step("check_target_revision"):
target_revision = deployment.target_revision()
if not target_revision:
if deployment.current_revision_id is None:
raise ModelDefinitionNotFound(
f"No current revision for deployment {deployment.id}"
)
current_revision = deployment.resolve_revision_spec(deployment.current_revision_id)
if not current_revision:
raise ModelDefinitionNotFound(
f"No target revision for deployment {deployment.id}"
f"Current revision {deployment.current_revision_id} not found for deployment {deployment.id}"
)

with recorder.step("generate_model_definition"):
model_definition = (
await self._model_definition_generator_registry.generate_model_definition(
target_revision
current_revision
)
)
health_check_config = model_definition.health_check_config()
Expand All @@ -469,7 +480,7 @@ async def _register_endpoint(
session_owner_id=deployment.metadata.session_owner,
project_id=deployment.metadata.project,
domain_name=deployment.metadata.domain,
runtime_variant=target_revision.execution.runtime_variant,
runtime_variant=current_revision.execution.runtime_variant,
existing_url=deployment.network.url,
open_to_public=deployment.network.open_to_public,
health_check_config=health_check_config,
Expand Down
15 changes: 14 additions & 1 deletion src/ai/backend/manager/sokovan/deployment/route/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.data.deployment.types import DeploymentInfo, RouteStatus
from ai.backend.manager.errors.deployment import (
DeploymentHasNoTargetRevision,
EndpointNotFound,
RouteSessionNotFound,
RouteSessionTerminated,
Expand Down Expand Up @@ -398,9 +399,20 @@ async def _provision_route(
if deployment is None:
raise EndpointNotFound(f"Deployment not found for endpoint {route.endpoint_id}")

target_revision_id = (
route.revision_id
or deployment.deploying_revision_id
or deployment.current_revision_id
)
if target_revision_id is None:
raise DeploymentHasNoTargetRevision(
"Deployment has no revision for image resolution"
)

# Fetch deployment context with all necessary data
deployment_context = await self._deployment_repo.fetch_deployment_context(
deployment
deployment,
revision_id=target_revision_id,
)

# Create session with full context
Expand All @@ -409,6 +421,7 @@ async def _provision_route(
deployment_info=deployment,
context=deployment_context,
route_id=route.route_id,
revision_id=target_revision_id,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def _create_deployment_info(
) -> DeploymentInfo:
"""Create DeploymentInfo for tests."""
dep_id = deployment_id or uuid4()
rev_id = uuid4()
revision = MagicMock() if has_revision else None
if revision is not None:
revision.revision_id = rev_id

return DeploymentInfo(
id=dep_id,
Expand All @@ -133,7 +136,7 @@ def _create_deployment_info(
url=None,
),
model_revisions=[revision] if has_revision else [], # type: ignore[list-item]
current_revision_id=uuid4() if has_revision else None,
current_revision_id=rev_id if has_revision else None,
)


Expand Down
Loading
Loading