Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/10013.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Apply RBAC Creator pattern to auto sub-entity (BA-5069)
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
)
from ai.backend.manager.repositories.base.creator import BulkCreator
from ai.backend.manager.repositories.base.rbac.entity_creator import (
RBACBulkEntityCreator,
RBACEntityCreator,
execute_rbac_bulk_entity_creator,
execute_rbac_entity_creator,
)
from ai.backend.manager.repositories.base.updater import BatchUpdater, execute_batch_updater
Expand Down Expand Up @@ -130,6 +132,7 @@
SessionSchedulingHistoryCreatorSpec,
)
from ai.backend.manager.repositories.session.creators import (
KernelRowCreatorSpec,
SessionRowCreatorSpec,
)
from ai.backend.manager.sokovan.data import (
Expand Down Expand Up @@ -1295,7 +1298,7 @@ async def enqueue_session(
)

# Create kernel rows
kernels = []
kernel_specs = []
for kernel in session_data.kernels:
kernel_row = KernelRow(
id=kernel.id,
Expand Down Expand Up @@ -1342,7 +1345,7 @@ async def enqueue_session(
main_gid=kernel.main_gid,
gids=kernel.gids,
)
kernels.append(kernel_row)
kernel_specs.append(KernelRowCreatorSpec(row=kernel_row))

# Use RBACEntityCreator to create session with RBAC scope association
rbac_creator = RBACEntityCreator(
Expand All @@ -1361,11 +1364,19 @@ async def enqueue_session(
)
await execute_rbac_entity_creator(db_sess, rbac_creator)

db_sess.add_all(kernels)
await db_sess.flush()
# Use RBACBulkEntityCreator to create kernels with RBAC scope association
kernel_rbac_creator = RBACBulkEntityCreator(
specs=kernel_specs,
element_type=RBACElementType.KERNEL,
scope_ref=RBACElementRef(
element_type=RBACElementType.SESSION,
element_id=str(session_data.id),
),
Comment thread
fregataa marked this conversation as resolved.
)
kernel_result = await execute_rbac_bulk_entity_creator(db_sess, kernel_rbac_creator)

# Record requested resources in normalized resource_allocations table
for kernel_row in kernels:
for kernel_row in kernel_result.rows:
quantities = resource_slot_to_quantities(kernel_row.requested_slots)
if quantities:
await db_sess.execute(
Expand Down
19 changes: 19 additions & 0 deletions src/ai/backend/manager/repositories/session/creators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from typing import override

from ai.backend.manager.models.kernel import KernelRow
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.repositories.base.creator import CreatorSpec

Expand All @@ -25,3 +26,21 @@ class SessionRowCreatorSpec(CreatorSpec[SessionRow]):
@override
def build_row(self) -> SessionRow:
return self.row


@dataclass
class KernelRowCreatorSpec(CreatorSpec[KernelRow]):
"""CreatorSpec that wraps a pre-built KernelRow.

This spec is designed for retrofitting existing code that already builds
KernelRow instances. It simply returns the provided row in build_row().

Kernel is a sub-entity of Session. Use RBACElementType.SESSION as
the scope_ref element_type when creating the RBAC association.
"""

row: KernelRow

@override
def build_row(self) -> KernelRow:
return self.row
Loading