8686)
8787from ai .backend .manager .repositories .base .creator import BulkCreator
8888from ai .backend .manager .repositories .base .rbac .entity_creator import (
89+ RBACBulkEntityCreator ,
8990 RBACEntityCreator ,
91+ execute_rbac_bulk_entity_creator ,
9092 execute_rbac_entity_creator ,
9193)
9294from ai .backend .manager .repositories .base .updater import BatchUpdater , execute_batch_updater
130132 SessionSchedulingHistoryCreatorSpec ,
131133)
132134from ai .backend .manager .repositories .session .creators import (
135+ KernelRowCreatorSpec ,
133136 SessionRowCreatorSpec ,
134137)
135138from ai .backend .manager .sokovan .data import (
@@ -1295,7 +1298,7 @@ async def enqueue_session(
12951298 )
12961299
12971300 # Create kernel rows
1298- kernels = []
1301+ kernel_specs = []
12991302 for kernel in session_data .kernels :
13001303 kernel_row = KernelRow (
13011304 id = kernel .id ,
@@ -1342,7 +1345,7 @@ async def enqueue_session(
13421345 main_gid = kernel .main_gid ,
13431346 gids = kernel .gids ,
13441347 )
1345- kernels .append (kernel_row )
1348+ kernel_specs .append (KernelRowCreatorSpec ( row = kernel_row ) )
13461349
13471350 # Use RBACEntityCreator to create session with RBAC scope association
13481351 rbac_creator = RBACEntityCreator (
@@ -1361,11 +1364,19 @@ async def enqueue_session(
13611364 )
13621365 await execute_rbac_entity_creator (db_sess , rbac_creator )
13631366
1364- db_sess .add_all (kernels )
1365- await db_sess .flush ()
1367+ # Use RBACBulkEntityCreator to create kernels with RBAC scope association
1368+ kernel_rbac_creator = RBACBulkEntityCreator (
1369+ specs = kernel_specs ,
1370+ element_type = RBACElementType .KERNEL ,
1371+ scope_ref = RBACElementRef (
1372+ element_type = RBACElementType .SESSION ,
1373+ element_id = str (session_data .id ),
1374+ ),
1375+ )
1376+ kernel_result = await execute_rbac_bulk_entity_creator (db_sess , kernel_rbac_creator )
13661377
13671378 # Record requested resources in normalized resource_allocations table
1368- for kernel_row in kernels :
1379+ for kernel_row in kernel_result . rows :
13691380 quantities = resource_slot_to_quantities (kernel_row .requested_slots )
13701381 if quantities :
13711382 await db_sess .execute (
0 commit comments