11import asyncio
22import json
3- from sqlalchemy import delete , text , update
3+ from typing import AsyncGenerator
4+ from sqlalchemy import Select , and_ , delete , select , text , update
45from sqlalchemy .ext .asyncio import AsyncSession
56import logging
67
1415 User ,
1516 UserClassRole ,
1617 UserInstitutionRole ,
18+ _get_upsert_stmt ,
1719 user_thread_association ,
1820 user_merge_association ,
1921 File ,
2022)
23+ from pingpong .schemas import MergedUserTuple
2124
2225logger = logging .getLogger (__name__ )
2326
@@ -28,19 +31,53 @@ async def merge(
2831 new_user_id : int ,
2932 old_user_id : int ,
3033) -> "User" :
31- await asyncio .gather (
32- merge_classes (session , new_user_id , old_user_id ),
33- merge_institutions (session , new_user_id , old_user_id ),
34- merge_assistants (session , new_user_id , old_user_id ),
35- merge_threads (session , new_user_id , old_user_id ),
36- merge_lms_users (session , new_user_id , old_user_id ),
37- merge_external_logins (session , new_user_id , old_user_id ),
38- merge_user_files (session , new_user_id , old_user_id ),
39- merge_permissions (client , new_user_id , old_user_id ),
40- )
34+ await merge_db_operations (session , new_user_id , old_user_id )
35+ await merge_permissions (client , new_user_id , old_user_id )
4136 return await merge_users (session , new_user_id , old_user_id )
4237
4338
39+ async def merge_db_operations (
40+ session : AsyncSession ,
41+ new_user_id : int ,
42+ old_user_id : int ,
43+ ):
44+ await merge_classes (session , new_user_id , old_user_id )
45+ await merge_institutions (session , new_user_id , old_user_id )
46+ await merge_assistants (session , new_user_id , old_user_id )
47+ await merge_threads (session , new_user_id , old_user_id )
48+ await merge_lms_users (session , new_user_id , old_user_id )
49+ await merge_external_logins (session , new_user_id , old_user_id )
50+ await merge_user_files (session , new_user_id , old_user_id )
51+
52+
53+ async def merge_missing_permissions (
54+ session : AsyncSession ,
55+ client : OpenFgaAuthzClient ,
56+ stmt : Select ,
57+ obj_type : str ,
58+ rel : str ,
59+ new_user_id : int ,
60+ ) -> None :
61+ result = await session .execute (stmt )
62+ grants : list [Relation ] = []
63+ revokes : list [Relation ] = []
64+
65+ async def process_row (row ):
66+ old_user_ids = await client .list_entities (f"{ obj_type } :{ row [0 ]} " , rel , "user" )
67+
68+ grants .extend ([(f"user:{ new_user_id } " , rel , f"{ obj_type } :{ row [0 ]} " )])
69+ revokes .extend (
70+ [
71+ (f"user:{ old_id } " , rel , f"{ obj_type } :{ row [0 ]} " )
72+ for old_id in old_user_ids
73+ if old_id != new_user_id
74+ ]
75+ )
76+
77+ await asyncio .gather (* (process_row (row ) for row in result ))
78+ await client .write_safe (grant = grants , revoke = revokes )
79+
80+
4481async def merge_classes (
4582 session : AsyncSession , new_user_id : int , old_user_id : int
4683) -> None :
@@ -95,6 +132,15 @@ async def merge_assistants(
95132 await session .execute (stmt )
96133
97134
135+ async def merge_missing_assistant_permissions (
136+ client : OpenFgaAuthzClient , session : AsyncSession , new_user_id : int
137+ ) -> None :
138+ stmt = select (Assistant .id ).where (Assistant .creator_id == new_user_id )
139+ await merge_missing_permissions (
140+ session , client , stmt , "assistant" , "owner" , new_user_id
141+ )
142+
143+
98144async def merge_threads (
99145 session : AsyncSession , new_user_id : int , old_user_id : int
100146) -> None :
@@ -106,6 +152,18 @@ async def merge_threads(
106152 await session .execute (stmt )
107153
108154
155+ async def merge_missing_thread_permissions (
156+ client : OpenFgaAuthzClient , session : AsyncSession , new_user_id : int
157+ ) -> None :
158+ # Working assumption is that only one persion is associated with a thread, as we currently don't have multiparty threads
159+ stmt = select (user_thread_association .c .thread_id ).where (
160+ user_thread_association .c .user_id == new_user_id
161+ )
162+ await merge_missing_permissions (
163+ session , client , stmt , "thread" , "party" , new_user_id
164+ )
165+
166+
109167async def merge_lms_users (
110168 session : AsyncSession , new_user_id : int , old_user_id : int
111169) -> None :
@@ -148,6 +206,28 @@ async def merge_user_files(
148206 await session .execute (stmt )
149207
150208
209+ async def merge_missing_user_file_permissions (
210+ client : OpenFgaAuthzClient , session : AsyncSession , new_user_id : int
211+ ) -> None :
212+ stmt = select (File .id ).where (
213+ and_ (File .uploader_id == new_user_id , File .private .is_ (True ))
214+ )
215+ await merge_missing_permissions (
216+ session , client , stmt , "user_file" , "owner" , new_user_id
217+ )
218+
219+
220+ async def merge_missing_class_file_permissions (
221+ client : OpenFgaAuthzClient , session : AsyncSession , new_user_id : int
222+ ) -> None :
223+ stmt = select (File .id ).where (
224+ and_ (File .uploader_id == new_user_id , File .private .is_ (False ))
225+ )
226+ await merge_missing_permissions (
227+ session , client , stmt , "class_file" , "owner" , new_user_id
228+ )
229+
230+
151231def get_types () -> list [str ]:
152232 """Get a list of object types used in the authz model."""
153233 with open (config .authz .driver .model_config ) as f :
@@ -180,14 +260,23 @@ async def list_all_permissions(
180260 return all_relations
181261
182262
263+ async def get_merged_user_tuples (
264+ session : AsyncSession ,
265+ ) -> AsyncGenerator [MergedUserTuple , None ]:
266+ stmt = select (
267+ user_merge_association .c .user_id , user_merge_association .c .merged_user_id
268+ )
269+ result = await session .execute (stmt )
270+ for row in result :
271+ yield MergedUserTuple (current_user_id = row [0 ], merged_user_id = row [1 ])
272+
273+
183274async def merge_permissions (
184275 client : OpenFgaAuthzClient , new_user_id : int , old_user_id : int
185276) -> None :
186277 logging .info (f"Merging permissions for { old_user_id } into { new_user_id } " )
187278 old_permissions = await list_all_permissions (client , old_user_id )
188279 new_permissions = [(f"user:{ new_user_id } " , r , o ) for _ , r , o in old_permissions ]
189- logging .info (f"Revoking { old_permissions } " )
190- logging .info (f"Granting { new_permissions } " )
191280 await client .write_safe (grant = new_permissions , revoke = old_permissions )
192281
193282
@@ -196,33 +285,42 @@ async def merge_users(
196285) -> "User" :
197286 old_user = await User .get_by_id (session , old_user_id )
198287 new_user = await User .get_by_id (session , new_user_id )
288+ if not new_user :
289+ raise ValueError (f"New user { new_user_id } not found." )
199290 logging .info (f"Merging user { old_user_id } into { new_user_id } " )
200- logging .info (
201- f"Old user: { old_user .id } , { old_user .email } , { old_user .state } , { 'Super admin' if old_user .super_admin else 'Not super admin' } "
202- )
203- match old_user .state :
204- case "verified" :
205- new_user .state = (
206- "verified" if new_user .state != "banned" else new_user .state
207- )
208- case "banned" :
209- new_user .state = "banned"
210- case _:
211- pass
212-
213- new_user .super_admin = new_user .super_admin or old_user .super_admin
214- stmt = (
215- update (user_merge_association )
216- .where (user_merge_association .c .user_id == old_user_id )
217- .values (user_id = new_user_id )
291+ if not old_user :
292+ logging .warning (
293+ f"Old user { old_user_id } not found, continuing with adding the merge tuple only."
294+ )
295+ if old_user :
296+ match old_user .state :
297+ case "verified" :
298+ new_user .state = (
299+ "verified" if new_user .state != "banned" else new_user .state
300+ )
301+ case "banned" :
302+ new_user .state = "banned"
303+ case _:
304+ pass
305+
306+ new_user .super_admin = new_user .super_admin or old_user .super_admin
307+ update_merged_account_tuple_stmt = (
308+ update (user_merge_association )
309+ .where (user_merge_association .c .user_id == old_user_id )
310+ .values (user_id = new_user_id )
311+ )
312+ await session .execute (update_merged_account_tuple_stmt )
313+ delete_old_user_stmt = delete (User ).where (User .id == old_user_id )
314+ await session .execute (delete_old_user_stmt )
315+ add_new_merge_tuple_stmt = (
316+ _get_upsert_stmt (session )(user_merge_association )
317+ .values (user_id = new_user_id , merged_user_id = old_user_id )
318+ .on_conflict_do_nothing (
319+ index_elements = ["user_id" , "merged_user_id" ],
320+ )
218321 )
219- await session .execute (stmt )
220- stmt_ = delete (User ).where (User .id == old_user_id )
221- await session .execute (stmt_ )
322+ await session .execute (add_new_merge_tuple_stmt )
222323 session .add (new_user )
223324 await session .flush ()
224325 await session .refresh (new_user )
225- logging .info (
226- f"New user: { new_user .id } , { new_user .email } , { new_user .state } , { 'Super admin' if new_user .super_admin else 'Not super admin' } "
227- )
228326 return new_user
0 commit comments