Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/10542.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `initial_revision` optional in deployment creation and remove duplicate revision columns from `endpoints` table in favor of `deployment_revisions`
4 changes: 3 additions & 1 deletion src/ai/backend/common/dto/manager/deployment/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ class CreateDeploymentRequest(BaseRequestModel):
description="Default deployment strategy"
)
desired_replica_count: int = Field(ge=0, description="Desired number of replicas")
initial_revision: RevisionInput = Field(description="Initial revision configuration")
initial_revision: RevisionInput | None = Field(
default=None, description="Initial revision configuration"
)


# ========== Deployment Policy Requests ==========
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/api/adapters/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ async def create(
)
model_revision_creator = ModelRevisionCreator(
image_id=initial_revision.image.id,
resource_group=initial_revision.resource_config.resource_group.name,
resource_spec=ResourceSpec(
cluster_mode=initial_revision.cluster_config.mode,
cluster_size=initial_revision.cluster_config.size,
Expand Down Expand Up @@ -793,6 +794,7 @@ async def add_revision(
)
adder = ModelRevisionCreator(
image_id=input.image.id,
resource_group=input.resource_config.resource_group.name,
resource_spec=ResourceSpec(
cluster_mode=input.cluster_config.mode,
cluster_size=input.cluster_config.size,
Expand Down
8 changes: 4 additions & 4 deletions src/ai/backend/manager/api/gql_legacy/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ async def load_all(
project=project,
domain=domain_name,
user_uuid=user_uuid,
load_revisions=True,
load_current_revision=True,
load_created_user=True,
load_session_owner=True,
)
Expand All @@ -842,7 +842,7 @@ async def load_item(
domain=domain_name,
user_uuid=user_uuid,
project=project,
load_revisions=True,
load_current_revision=True,
load_routes=True,
load_created_user=True,
load_session_owner=True,
Expand Down Expand Up @@ -888,8 +888,8 @@ async def resolve_extra_mounts(self, info: graphene.ResolveInfo) -> Sequence[Vir
ctx: GraphQueryContext = info.context

async with ctx.db.begin_readonly_session() as sess:
endpoint_row = await EndpointRow.get(sess, self.endpoint_id, load_revisions=True)
current_rev = endpoint_row._find_current_revision()
endpoint_row = await EndpointRow.get(sess, self.endpoint_id, load_current_revision=True)
current_rev = endpoint_row.current_revision_row
extra_mounts = current_rev.extra_mounts if current_rev else []
extra_mount_folder_ids = [m.vfid.folder_id for m in extra_mounts]
query = (
Expand Down
14 changes: 11 additions & 3 deletions src/ai/backend/manager/api/rest/deployment/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def build_revision_creator(revision_input: RevisionInput) -> ModelRevisionCreato

return ModelRevisionCreator(
image_id=revision_input.image.id,
resource_group=revision_input.resource_config.resource_group,
resource_spec=resource_spec,
mounts=mounts,
execution=execution,
Expand Down Expand Up @@ -463,11 +464,16 @@ def build_creator(
tag = ",".join(request.metadata.tags) if request.metadata.tags else None

# Build metadata
resource_group = (
request.initial_revision.resource_config.resource_group
if request.initial_revision is not None
else ""
)
metadata = DeploymentMetadata(
name=name,
domain=request.metadata.domain_name,
project=request.metadata.project_id,
resource_group=request.initial_revision.resource_config.resource_group,
resource_group=resource_group,
created_user=user_uuid,
session_owner=user_uuid,
created_at=None,
Expand All @@ -484,8 +490,10 @@ def build_creator(
preferred_domain_name=request.network_access.preferred_domain_name,
)

# Build model revision creator
model_revision = build_revision_creator(request.initial_revision)
# Build model revision creator if initial_revision is provided
model_revision: ModelRevisionCreator | None = None
if request.initial_revision is not None:
model_revision = build_revision_creator(request.initial_revision)

# Build policy config
policy = self._build_policy_config(request.default_deployment_strategy)
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/data/deployment/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ModelRevisionCreator:
"""

image_id: UUID
resource_group: str
resource_spec: ResourceSpec
mounts: VFolderMountsCreator
execution: ExecutionSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ async def handle_route_creation(
async with self._db.begin_readonly_session() as db_sess:
log.debug("Route ID: {}", event.route_id)
route = await RoutingRow.get(db_sess, event.route_id)
endpoint = await EndpointRow.get(db_sess, route.endpoint, load_revisions=True)
endpoint = await EndpointRow.get(
db_sess, route.endpoint, load_current_revision=True
)

# Get the current revision for revision-level fields
current_rev = endpoint._find_current_revision()
current_rev = endpoint.current_revision_row
if current_rev is None:
raise ValueError(f"No current revision for endpoint {endpoint.id}")

Expand Down
13 changes: 13 additions & 0 deletions src/ai/backend/manager/models/deployment_revision/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ai.backend.manager.models.endpoint import EndpointRow
from ai.backend.manager.models.image import ImageRow
from ai.backend.manager.models.routing import RoutingRow
from ai.backend.manager.models.vfolder import VFolderRow

__all__ = ("DeploymentRevisionRow",)

Expand All @@ -62,6 +63,12 @@ def _get_image_join_condition() -> sa.sql.elements.ColumnElement[Any]:
return foreign(DeploymentRevisionRow.image) == ImageRow.id


def _get_model_join_condition() -> sa.sql.elements.ColumnElement[Any]:
from ai.backend.manager.models.vfolder import VFolderRow

return foreign(DeploymentRevisionRow.model) == VFolderRow.id


def _get_routings_join_condition() -> sa.sql.elements.ColumnElement[Any]:
from ai.backend.manager.models.routing import RoutingRow

Expand Down Expand Up @@ -181,6 +188,12 @@ class DeploymentRevisionRow(Base): # type: ignore[misc]
"ImageRow",
primaryjoin=_get_image_join_condition,
)
model_row: Mapped[VFolderRow | None] = relationship(
"VFolderRow",
primaryjoin=_get_model_join_condition,
foreign_keys="DeploymentRevisionRow.model",
viewonly=True,
)
routings: Mapped[list[RoutingRow]] = relationship(
"RoutingRow",
primaryjoin=_get_routings_join_condition,
Expand Down
123 changes: 89 additions & 34 deletions src/ai/backend/manager/models/endpoint/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ def _get_deployment_policy_join_condition() -> Any:
return EndpointRow.id == foreign(DeploymentPolicyRow.endpoint)


def _get_current_revision_row_join_condition() -> Any:
from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow

return foreign(EndpointRow.current_revision) == DeploymentRevisionRow.id


def _get_created_user_row_join_condition() -> Any:
from ai.backend.manager.models.user import UserRow

Expand Down Expand Up @@ -263,6 +269,14 @@ class EndpointRow(Base): # type: ignore[misc]
endpoint_auto_scaling_rules: Mapped[list[EndpointAutoScalingRuleRow]] = relationship(
"EndpointAutoScalingRuleRow", back_populates="endpoint_row"
)
current_revision_row = relationship(
"DeploymentRevisionRow",
primaryjoin=_get_current_revision_row_join_condition,
uselist=False,
viewonly=True,
foreign_keys="EndpointRow.current_revision",
)

created_user_row: Mapped[UserRow | None] = relationship(
"UserRow",
back_populates="created_endpoints",
Expand Down Expand Up @@ -308,7 +322,7 @@ async def get(
load_tokens: bool = False,
load_created_user: bool = False,
load_session_owner: bool = False,
load_revisions: bool = False,
load_current_revision: bool = False,
) -> Self:
"""
:raises: sqlalchemy.orm.exc.NoResultFound
Expand All @@ -324,9 +338,15 @@ async def get(
query = query.options(selectinload(EndpointRow.created_user_row))
if load_session_owner:
query = query.options(selectinload(EndpointRow.session_owner_row))
if load_revisions:
if load_current_revision:
query = query.options(
selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row)
selectinload(EndpointRow.current_revision_row).selectinload(
DeploymentRevisionRow.image_row
)
).options(
selectinload(EndpointRow.current_revision_row).selectinload(
DeploymentRevisionRow.model_row
)
)
if project:
query = query.filter(EndpointRow.project == project)
Expand All @@ -351,7 +371,7 @@ async def list_endpoint(
load_tokens: bool = False,
load_created_user: bool = False,
load_session_owner: bool = False,
load_revisions: bool = False,
load_current_revision: bool = False,
status_filter: Iterable[EndpointLifecycle] = frozenset([EndpointLifecycle.CREATED]),
) -> list[Self]:
from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow
Expand All @@ -369,9 +389,11 @@ async def list_endpoint(
query = query.options(selectinload(EndpointRow.created_user_row))
if load_session_owner:
query = query.options(selectinload(EndpointRow.session_owner_row))
if load_revisions:
if load_current_revision:
query = query.options(
selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row)
selectinload(EndpointRow.current_revision_row).selectinload(
DeploymentRevisionRow.image_row
)
)
if project:
query = query.filter(EndpointRow.project == project)
Expand Down Expand Up @@ -621,55 +643,88 @@ def _find_current_revision(self) -> DeploymentRevisionRow | None:
def to_data(self) -> EndpointData:
"""Convert to EndpointData.

Requires revisions and revisions.image_row to be eagerly loaded
via selectinload for revision field population.
Uses current_revision_row relationship to populate revision-level fields.
"""
current_rev = self._find_current_revision()
current_revision = self.current_revision_row
routings = [routing.to_data() for routing in self.routings] if self.routings else None

if current_revision is not None:
return EndpointData(
id=self.id,
name=self.name,
image=current_revision.image_row.to_dataclass()
if current_revision.image_row
else None,
domain=self.domain,
project=self.project,
resource_group=self.resource_group,
resource_slots=current_revision.resource_slots,
url=self.url or "",
model=current_revision.model or uuid.UUID(int=0),
model_definition_path=current_revision.model_definition_path,
model_mount_destination=current_revision.model_mount_destination,
created_user_id=self.created_user,
created_user_email=(
self.created_user_row.email if self.created_user_row is not None else None
),
session_owner_id=self.session_owner,
session_owner_email=self.session_owner_row.email if self.session_owner_row else "",
tag=self.tag,
startup_command=current_revision.startup_command,
bootstrap_script=current_revision.bootstrap_script,
callback_url=yarl.URL(current_revision.callback_url)
if current_revision.callback_url
else None,
environ=current_revision.environ,
resource_opts=current_revision.resource_opts,
replicas=self.replicas,
cluster_mode=ClusterMode(current_revision.cluster_mode),
cluster_size=current_revision.cluster_size,
open_to_public=self.open_to_public if self.open_to_public is not None else False,
created_at=self.created_at or datetime.now(UTC),
destroyed_at=self.destroyed_at,
retries=self.retries,
lifecycle_stage=self.lifecycle_stage,
runtime_variant=current_revision.runtime_variant,
extra_mounts=current_revision.extra_mounts,
routings=routings,
)

return EndpointData(
id=self.id,
name=self.name,
image=(
current_rev.image_row.to_dataclass()
if current_rev and current_rev.image_row
else None
),
image=None,
domain=self.domain,
project=self.project,
resource_group=self.resource_group,
resource_slots=current_rev.resource_slots if current_rev else ResourceSlot({}),
resource_slots=ResourceSlot({}),
url=self.url or "",
model=current_rev.model or uuid.UUID(int=0) if current_rev else uuid.UUID(int=0),
model_definition_path=current_rev.model_definition_path if current_rev else None,
model_mount_destination=(current_rev.model_mount_destination if current_rev else None),
model=uuid.UUID(int=0),
model_definition_path=None,
model_mount_destination="/models",
created_user_id=self.created_user,
created_user_email=(
self.created_user_row.email if self.created_user_row is not None else None
),
session_owner_id=self.session_owner,
session_owner_email=self.session_owner_row.email if self.session_owner_row else "",
tag=self.tag,
startup_command=current_rev.startup_command if current_rev else None,
bootstrap_script=current_rev.bootstrap_script if current_rev else None,
callback_url=(
yarl.URL(current_rev.callback_url)
if current_rev and current_rev.callback_url
else None
),
environ=current_rev.environ if current_rev else None,
resource_opts=current_rev.resource_opts if current_rev else None,
startup_command=None,
bootstrap_script=None,
callback_url=None,
environ=None,
resource_opts=None,
replicas=self.replicas,
cluster_mode=(
ClusterMode(current_rev.cluster_mode) if current_rev else ClusterMode.SINGLE_NODE
),
cluster_size=current_rev.cluster_size if current_rev else 1,
cluster_mode=ClusterMode.SINGLE_NODE,
cluster_size=1,
open_to_public=self.open_to_public if self.open_to_public is not None else False,
created_at=self.created_at or datetime.now(UTC),
destroyed_at=self.destroyed_at,
retries=self.retries,
lifecycle_stage=self.lifecycle_stage,
runtime_variant=(current_rev.runtime_variant if current_rev else RuntimeVariant.CUSTOM),
extra_mounts=current_rev.extra_mounts if current_rev else [],
routings=[routing.to_data() for routing in self.routings] if self.routings else None,
runtime_variant=RuntimeVariant.CUSTOM,
extra_mounts=[],
routings=routings,
)

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2533,11 +2533,10 @@ async def notify_endpoint_route_update_to_appproxy(self, endpoint_id: uuid.UUID)
endpoint_id,
load_created_user=True,
load_session_owner=True,
load_revisions=True,
load_routes=True,
)
connection_info = await endpoint.generate_route_info(db_sess)
current_rev = endpoint._find_current_revision()
current_rev = endpoint.current_revision_row
if current_rev is None or current_rev.model is None:
raise InvalidAPIParameters("Model not set for endpoint")
model = await VFolderRow.get(db_sess, current_rev.model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class ModelRevisionFields:
"""

image_id: uuid.UUID
resource_group: str
resource: DeploymentResourceFields
mounts: DeploymentMountFields
execution: DeploymentExecutionFields
Expand Down
Loading
Loading