@@ -700,7 +700,7 @@ async def _fetch_user_policies(
700700 """Fetch user resource policies for users in pending sessions."""
701701 user_policies : dict [UUID , UserResourcePolicy ] = {}
702702
703- if not pending_sessions .user_uuids :
703+ if not pending_sessions .owner_ids :
704704 return user_policies
705705
706706 user_policy_result = await db_sess .execute (
@@ -716,7 +716,7 @@ async def _fetch_user_policies(
716716 KeyPairResourcePolicyRow ,
717717 KeyPairRow .resource_policy == KeyPairResourcePolicyRow .name ,
718718 )
719- .where (UserRow .uuid .in_ (pending_sessions .user_uuids ))
719+ .where (UserRow .uuid .in_ (pending_sessions .owner_ids ))
720720 )
721721
722722 for row in user_policy_result :
@@ -1146,11 +1146,14 @@ async def get_terminating_sessions_by_ids(
11461146 for kernel in session_row .kernels
11471147 ]
11481148
1149+ owner_main_ak = (
1150+ session_row .user .main_access_key if session_row .user else None
1151+ )
11491152 terminating_sessions .append (
11501153 TerminatingSessionData (
11511154 session_id = session_row .id ,
1152- access_key = AccessKey (session_row . access_key )
1153- if session_row . access_key
1155+ main_access_key = AccessKey (owner_main_ak )
1156+ if owner_main_ak
11541157 else AccessKey ("" ),
11551158 creation_id = session_row .creation_id or "" ,
11561159 status = session_row .status ,
@@ -1183,11 +1186,12 @@ async def get_pending_timeout_sessions_by_ids(
11831186 sa .select (
11841187 SessionRow .id ,
11851188 SessionRow .creation_id ,
1186- SessionRow . access_key ,
1189+ UserRow . main_access_key ,
11871190 SessionRow .created_at ,
11881191 ScalingGroupRow .scheduler_opts ,
11891192 )
11901193 .select_from (SessionRow )
1194+ .join (UserRow , SessionRow .user_uuid == UserRow .uuid )
11911195 .join (ScalingGroupRow , SessionRow .scaling_group_name == ScalingGroupRow .name )
11921196 .where (
11931197 SessionRow .id .in_ (session_ids ),
@@ -1213,7 +1217,7 @@ async def get_pending_timeout_sessions_by_ids(
12131217 SweptSessionInfo (
12141218 session_id = row .id ,
12151219 creation_id = row .creation_id ,
1216- access_key = row .access_key ,
1220+ main_access_key = row .main_access_key ,
12171221 )
12181222 )
12191223
@@ -1302,8 +1306,8 @@ async def enqueue_session(
13021306 id = session_data .id ,
13031307 creation_id = session_data .creation_id ,
13041308 name = session_data .name ,
1305- access_key = session_data .access_key ,
1306- user_uuid = session_data .user_uuid ,
1309+ access_key = session_data .main_access_key ,
1310+ user_uuid = session_data .owner_id ,
13071311 group_id = session_data .group_id ,
13081312 domain_name = session_data .domain_name ,
13091313 scaling_group_name = session_data .scaling_group_name ,
@@ -1349,8 +1353,8 @@ async def enqueue_session(
13491353 scaling_group = kernel .scaling_group ,
13501354 domain_name = kernel .domain_name ,
13511355 group_id = kernel .group_id ,
1352- user_uuid = kernel .user_uuid ,
1353- access_key = kernel .access_key ,
1356+ user_uuid = kernel .owner_id ,
1357+ access_key = kernel .main_access_key ,
13541358 image = kernel .image ,
13551359 architecture = kernel .architecture ,
13561360 registry = kernel .registry ,
@@ -1387,7 +1391,7 @@ async def enqueue_session(
13871391 element_type = RBACElementType .SESSION ,
13881392 scope_ref = RBACElementRef (
13891393 element_type = RBACElementType .USER ,
1390- element_id = str (session_data .user_uuid ),
1394+ element_id = str (session_data .owner_id ),
13911395 ),
13921396 additional_scope_refs = [
13931397 RBACElementRef (
@@ -1843,21 +1847,28 @@ async def allocate_sessions(
18431847 # First, fetch session data to get creation_id and access_key
18441848 session_ids = {alloc .session_id for alloc in allocation_batch .allocations }
18451849 if session_ids :
1846- query = sa .select (
1847- SessionRow .id , SessionRow .creation_id , SessionRow .access_key
1848- ).where (SessionRow .id .in_ (session_ids ))
1850+ query = (
1851+ sa .select (
1852+ SessionRow .id , SessionRow .creation_id , UserRow .main_access_key
1853+ )
1854+ .select_from (SessionRow )
1855+ .join (UserRow , SessionRow .user_uuid == UserRow .uuid )
1856+ .where (SessionRow .id .in_ (session_ids ))
1857+ )
18491858 result = await db_sess .execute (query )
1850- session_data_map = {row .id : (row .creation_id , row .access_key ) for row in result }
1859+ session_data_map = {
1860+ row .id : (row .creation_id , row .main_access_key ) for row in result
1861+ }
18511862
18521863 # Create SessionEventData for each allocated session
18531864 for allocation in allocation_batch .allocations :
18541865 if session_data := session_data_map .get (allocation .session_id ):
1855- creation_id , access_key = session_data
1866+ creation_id , main_access_key = session_data
18561867 scheduled_sessions .append (
18571868 ScheduledSessionData (
18581869 session_id = allocation .session_id ,
18591870 creation_id = creation_id ,
1860- access_key = access_key ,
1871+ main_access_key = main_access_key ,
18611872 reason = "triggered-by-scheduler" ,
18621873 )
18631874 )
@@ -2917,7 +2928,9 @@ async def _get_sessions_by_statuses(
29172928 scheduled_session = ScheduledSessionData (
29182929 session_id = session .id ,
29192930 creation_id = session .creation_id or "" ,
2920- access_key = AccessKey (session .access_key ) if session .access_key else AccessKey ("" ),
2931+ main_access_key = AccessKey (session .access_key )
2932+ if session .access_key
2933+ else AccessKey ("" ),
29212934 reason = "triggered-by-scheduler" ,
29222935 )
29232936 scheduled_sessions .append (scheduled_session )
@@ -2962,7 +2975,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes
29622975 ScheduledSessionData (
29632976 session_id = session .id ,
29642977 creation_id = session .creation_id or "" ,
2965- access_key = AccessKey (session .access_key )
2978+ main_access_key = AccessKey (session .access_key )
29662979 if session .access_key
29672980 else AccessKey ("" ),
29682981 reason = "triggered-by-scheduler" ,
@@ -3102,7 +3115,7 @@ async def _get_sessions_for_pull(
31023115 sessions_map [session_id ] = SessionDataForPull (
31033116 session_id = session_id ,
31043117 creation_id = row .creation_id ,
3105- access_key = row .access_key ,
3118+ main_access_key = row .access_key ,
31063119 kernels = [],
31073120 )
31083121
@@ -3294,13 +3307,13 @@ async def _get_sessions_for_start(
32943307 SessionDataForStart (
32953308 session_id = session_info ["id" ],
32963309 creation_id = session_info ["creation_id" ],
3297- access_key = session_info ["access_key" ],
3310+ main_access_key = session_info ["access_key" ],
32983311 session_type = session_info ["session_type" ],
32993312 name = session_info ["name" ],
33003313 cluster_mode = session_info ["cluster_mode" ],
33013314 kernels = kernel_bindings ,
33023315 environ = session_info .get ("environ" , {}),
3303- user_uuid = session_info ["user_uuid" ],
3316+ owner_id = session_info ["user_uuid" ],
33043317 user_email = user_info .email ,
33053318 user_name = user_info .username ,
33063319 )
@@ -4074,7 +4087,7 @@ async def _fetch_sessions_for_pull_by_ids(
40744087 sessions_map [session_id ] = SessionDataForPull (
40754088 session_id = session_id ,
40764089 creation_id = row .creation_id ,
4077- access_key = row .access_key ,
4090+ main_access_key = row .access_key ,
40784091 kernels = [],
40794092 )
40804093
@@ -4293,13 +4306,13 @@ async def _fetch_sessions_for_start_by_ids(
42934306 SessionDataForStart (
42944307 session_id = session_info ["id" ],
42954308 creation_id = session_info ["creation_id" ],
4296- access_key = session_info ["access_key" ],
4309+ main_access_key = session_info ["access_key" ],
42974310 session_type = session_info ["session_type" ],
42984311 name = session_info ["name" ],
42994312 cluster_mode = session_info ["cluster_mode" ],
43004313 kernels = kernel_bindings ,
43014314 environ = session_info .get ("environ" , {}),
4302- user_uuid = session_info ["user_uuid" ],
4315+ owner_id = session_info ["user_uuid" ],
43034316 user_email = user_info .email ,
43044317 user_name = user_info .username ,
43054318 )
@@ -4369,7 +4382,7 @@ async def search_sessions_with_kernels(
43694382 sessions_map [row .id ] = SessionDataForPull (
43704383 session_id = row .id ,
43714384 creation_id = row .creation_id ,
4372- access_key = row .access_key ,
4385+ main_access_key = row .access_key ,
43734386 kernels = [],
43744387 )
43754388
@@ -4625,13 +4638,13 @@ async def search_sessions_with_kernels_and_user(
46254638 SessionDataForStart (
46264639 session_id = session_info ["id" ],
46274640 creation_id = session_info ["creation_id" ],
4628- access_key = session_info ["access_key" ],
4641+ main_access_key = session_info ["access_key" ],
46294642 session_type = session_info ["session_type" ],
46304643 name = session_info ["name" ],
46314644 cluster_mode = session_info ["cluster_mode" ],
46324645 kernels = session_info ["kernels" ],
46334646 environ = session_info .get ("environ" ) or {},
4634- user_uuid = session_info ["user_uuid" ],
4647+ owner_id = session_info ["user_uuid" ],
46354648 user_email = user_info .email ,
46364649 user_name = user_info .username ,
46374650 )
0 commit comments