Skip to content

Commit fc6dc56

Browse files
authored
Revamp merging.py module (#604)
1 parent d3f212c commit fc6dc56

4 files changed

Lines changed: 245 additions & 45 deletions

File tree

pingpong/__main__.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111
from sqlalchemy import text
1212
from sqlalchemy.ext.asyncio import create_async_engine
1313

14+
from pingpong.merge import (
15+
get_merged_user_tuples,
16+
list_all_permissions,
17+
merge_missing_assistant_permissions,
18+
merge_missing_class_file_permissions,
19+
merge_missing_thread_permissions,
20+
merge_missing_user_file_permissions,
21+
merge_permissions,
22+
merge,
23+
)
24+
1425
from .auth import encode_auth_token
1526
from .bg import get_server
1627
from .canvas import canvas_sync_all
@@ -118,6 +129,79 @@ async def _get_or_create(email) -> int:
118129
webbrowser.open(url)
119130

120131

132+
# This command lists all explicitly granted permissions for a user
133+
@auth.command("list_permissions")
134+
@click.argument("user_id", type=int)
135+
def list_permissions(user_id: int) -> None:
136+
async def _list_permissions() -> None:
137+
await config.authz.driver.init()
138+
async with config.authz.driver.get_client() as c:
139+
perms = await list_all_permissions(c, user_id)
140+
logging.info(f"Permissions for user {user_id}: {perms}")
141+
142+
asyncio.run(_list_permissions())
143+
144+
145+
# This command attempts to merge any outstanding permissions
146+
# from one user to another based on the users_merged_users table
147+
@auth.command("redo_permission_merges")
148+
def users_merge_permissions() -> None:
149+
async def _users_merge_permissions() -> None:
150+
await config.authz.driver.init()
151+
async with config.db.driver.async_session() as session:
152+
async with config.authz.driver.get_client() as c:
153+
logger.info("Merging permissions for all users...")
154+
async for row in get_merged_user_tuples(session):
155+
logging.info(
156+
f"Merging permissions for {row.merged_user_id} into {row.current_user_id}"
157+
)
158+
await merge_permissions(c, row.current_user_id, row.merged_user_id)
159+
160+
asyncio.run(_users_merge_permissions())
161+
162+
163+
# This command attempts to recover any missing permissions for a user
164+
# after a user(s) has/have been merged into said user. This command uses
165+
# fields in the database to infer which permissions the user should have
166+
@auth.command("add_missing_permissions")
167+
@click.argument("new_user_id", type=int)
168+
def add_missing_permissions(new_user_id: int) -> None:
169+
async def _add_missing_permissions() -> None:
170+
await config.authz.driver.init()
171+
async with config.db.driver.async_session() as session:
172+
async with config.authz.driver.get_client() as c:
173+
logger.info(f"Adding missing permissions for user {new_user_id}...")
174+
logger.info("Merging assistant permissions...")
175+
await merge_missing_assistant_permissions(c, session, new_user_id)
176+
logger.info("Merging thread permissions...")
177+
await merge_missing_thread_permissions(c, session, new_user_id)
178+
logger.info("Merging user file permissions...")
179+
await merge_missing_user_file_permissions(c, session, new_user_id)
180+
logger.info("Merging class file permissions...")
181+
await merge_missing_class_file_permissions(c, session, new_user_id)
182+
logger.info("Done!")
183+
184+
asyncio.run(_add_missing_permissions())
185+
186+
187+
# This command attempts to merge all permissions from old_user_id to new_user_id.
188+
# This command can be used if a user has been merged into another user
189+
# and some permissions were not transferred over, or the tuple was not added in users_merged_users.
190+
# In other words, it can be used with `old_user_id`s of users who have already been deleted.
191+
@auth.command("merge_users")
192+
@click.argument("new_user_id", type=int)
193+
@click.argument("old_user_id", type=int)
194+
def merge_users(new_user_id: int, old_user_id: int) -> None:
195+
async def _merge_users() -> None:
196+
await config.authz.driver.init()
197+
async with config.db.driver.async_session() as session:
198+
async with config.authz.driver.get_client() as c:
199+
await merge(session, c, new_user_id, old_user_id)
200+
await session.commit()
201+
202+
asyncio.run(_merge_users())
203+
204+
121205
def _load_alembic(alembic_config="alembic.ini") -> alembic.config.Config:
122206
"""Load the Alembic config."""
123207
al_cfg = alembic.config.Config(alembic_config)

pingpong/authz/openfga.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,29 @@ async def list_entities(self, target: str, relation: str, type_: str) -> List[in
7878
return [int(user.object.id) for user in response.users]
7979

8080
async def read(self, key: ReadRequestTupleKey) -> List[ClientTuple]:
81-
response: ReadResponse = await self._cli.read(key)
81+
client_tuples = []
82+
continuation_token = None
8283

83-
client_tuples = [
84-
ClientTuple(
85-
user=tuple.key.user,
86-
relation=tuple.key.relation,
87-
object=tuple.key.object,
84+
while True:
85+
response: ReadResponse = await self._cli.read(
86+
key, {"continuation_token": continuation_token}
8887
)
89-
for tuple in response.tuples
90-
]
88+
89+
client_tuples.extend(
90+
[
91+
ClientTuple(
92+
user=tuple.key.user,
93+
relation=tuple.key.relation,
94+
object=tuple.key.object,
95+
)
96+
for tuple in response.tuples
97+
]
98+
)
99+
100+
if not response.continuation_token:
101+
break
102+
103+
continuation_token = response.continuation_token
91104

92105
return client_tuples
93106

pingpong/merge.py

Lines changed: 135 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
3-
from sqlalchemy import delete, text, update
3+
from typing import AsyncGenerator
4+
from sqlalchemy import Select, and_, delete, select, text, update
45
from sqlalchemy.ext.asyncio import AsyncSession
56
import logging
67

@@ -14,10 +15,12 @@
1415
User,
1516
UserClassRole,
1617
UserInstitutionRole,
18+
_get_upsert_stmt,
1719
user_thread_association,
1820
user_merge_association,
1921
File,
2022
)
23+
from pingpong.schemas import MergedUserTuple
2124

2225
logger = logging.getLogger(__name__)
2326

@@ -28,19 +31,53 @@ async def merge(
2831
new_user_id: int,
2932
old_user_id: int,
3033
) -> "User":
31-
await asyncio.gather(
32-
merge_classes(session, new_user_id, old_user_id),
33-
merge_institutions(session, new_user_id, old_user_id),
34-
merge_assistants(session, new_user_id, old_user_id),
35-
merge_threads(session, new_user_id, old_user_id),
36-
merge_lms_users(session, new_user_id, old_user_id),
37-
merge_external_logins(session, new_user_id, old_user_id),
38-
merge_user_files(session, new_user_id, old_user_id),
39-
merge_permissions(client, new_user_id, old_user_id),
40-
)
34+
await merge_db_operations(session, new_user_id, old_user_id)
35+
await merge_permissions(client, new_user_id, old_user_id)
4136
return await merge_users(session, new_user_id, old_user_id)
4237

4338

39+
async def merge_db_operations(
40+
session: AsyncSession,
41+
new_user_id: int,
42+
old_user_id: int,
43+
):
44+
await merge_classes(session, new_user_id, old_user_id)
45+
await merge_institutions(session, new_user_id, old_user_id)
46+
await merge_assistants(session, new_user_id, old_user_id)
47+
await merge_threads(session, new_user_id, old_user_id)
48+
await merge_lms_users(session, new_user_id, old_user_id)
49+
await merge_external_logins(session, new_user_id, old_user_id)
50+
await merge_user_files(session, new_user_id, old_user_id)
51+
52+
53+
async def merge_missing_permissions(
54+
session: AsyncSession,
55+
client: OpenFgaAuthzClient,
56+
stmt: Select,
57+
obj_type: str,
58+
rel: str,
59+
new_user_id: int,
60+
) -> None:
61+
result = await session.execute(stmt)
62+
grants: list[Relation] = []
63+
revokes: list[Relation] = []
64+
65+
async def process_row(row):
66+
old_user_ids = await client.list_entities(f"{obj_type}:{row[0]}", rel, "user")
67+
68+
grants.extend([(f"user:{new_user_id}", rel, f"{obj_type}:{row[0]}")])
69+
revokes.extend(
70+
[
71+
(f"user:{old_id}", rel, f"{obj_type}:{row[0]}")
72+
for old_id in old_user_ids
73+
if old_id != new_user_id
74+
]
75+
)
76+
77+
await asyncio.gather(*(process_row(row) for row in result))
78+
await client.write_safe(grant=grants, revoke=revokes)
79+
80+
4481
async def merge_classes(
4582
session: AsyncSession, new_user_id: int, old_user_id: int
4683
) -> None:
@@ -95,6 +132,15 @@ async def merge_assistants(
95132
await session.execute(stmt)
96133

97134

135+
async def merge_missing_assistant_permissions(
136+
client: OpenFgaAuthzClient, session: AsyncSession, new_user_id: int
137+
) -> None:
138+
stmt = select(Assistant.id).where(Assistant.creator_id == new_user_id)
139+
await merge_missing_permissions(
140+
session, client, stmt, "assistant", "owner", new_user_id
141+
)
142+
143+
98144
async def merge_threads(
99145
session: AsyncSession, new_user_id: int, old_user_id: int
100146
) -> None:
@@ -106,6 +152,18 @@ async def merge_threads(
106152
await session.execute(stmt)
107153

108154

155+
async def merge_missing_thread_permissions(
156+
client: OpenFgaAuthzClient, session: AsyncSession, new_user_id: int
157+
) -> None:
158+
# Working assumption is that only one persion is associated with a thread, as we currently don't have multiparty threads
159+
stmt = select(user_thread_association.c.thread_id).where(
160+
user_thread_association.c.user_id == new_user_id
161+
)
162+
await merge_missing_permissions(
163+
session, client, stmt, "thread", "party", new_user_id
164+
)
165+
166+
109167
async def merge_lms_users(
110168
session: AsyncSession, new_user_id: int, old_user_id: int
111169
) -> None:
@@ -148,6 +206,28 @@ async def merge_user_files(
148206
await session.execute(stmt)
149207

150208

209+
async def merge_missing_user_file_permissions(
210+
client: OpenFgaAuthzClient, session: AsyncSession, new_user_id: int
211+
) -> None:
212+
stmt = select(File.id).where(
213+
and_(File.uploader_id == new_user_id, File.private.is_(True))
214+
)
215+
await merge_missing_permissions(
216+
session, client, stmt, "user_file", "owner", new_user_id
217+
)
218+
219+
220+
async def merge_missing_class_file_permissions(
221+
client: OpenFgaAuthzClient, session: AsyncSession, new_user_id: int
222+
) -> None:
223+
stmt = select(File.id).where(
224+
and_(File.uploader_id == new_user_id, File.private.is_(False))
225+
)
226+
await merge_missing_permissions(
227+
session, client, stmt, "class_file", "owner", new_user_id
228+
)
229+
230+
151231
def get_types() -> list[str]:
152232
"""Get a list of object types used in the authz model."""
153233
with open(config.authz.driver.model_config) as f:
@@ -180,14 +260,23 @@ async def list_all_permissions(
180260
return all_relations
181261

182262

263+
async def get_merged_user_tuples(
264+
session: AsyncSession,
265+
) -> AsyncGenerator[MergedUserTuple, None]:
266+
stmt = select(
267+
user_merge_association.c.user_id, user_merge_association.c.merged_user_id
268+
)
269+
result = await session.execute(stmt)
270+
for row in result:
271+
yield MergedUserTuple(current_user_id=row[0], merged_user_id=row[1])
272+
273+
183274
async def merge_permissions(
184275
client: OpenFgaAuthzClient, new_user_id: int, old_user_id: int
185276
) -> None:
186277
logging.info(f"Merging permissions for {old_user_id} into {new_user_id}")
187278
old_permissions = await list_all_permissions(client, old_user_id)
188279
new_permissions = [(f"user:{new_user_id}", r, o) for _, r, o in old_permissions]
189-
logging.info(f"Revoking {old_permissions}")
190-
logging.info(f"Granting {new_permissions}")
191280
await client.write_safe(grant=new_permissions, revoke=old_permissions)
192281

193282

@@ -196,33 +285,42 @@ async def merge_users(
196285
) -> "User":
197286
old_user = await User.get_by_id(session, old_user_id)
198287
new_user = await User.get_by_id(session, new_user_id)
288+
if not new_user:
289+
raise ValueError(f"New user {new_user_id} not found.")
199290
logging.info(f"Merging user {old_user_id} into {new_user_id}")
200-
logging.info(
201-
f"Old user: {old_user.id}, {old_user.email}, {old_user.state}, {'Super admin' if old_user.super_admin else 'Not super admin'}"
202-
)
203-
match old_user.state:
204-
case "verified":
205-
new_user.state = (
206-
"verified" if new_user.state != "banned" else new_user.state
207-
)
208-
case "banned":
209-
new_user.state = "banned"
210-
case _:
211-
pass
212-
213-
new_user.super_admin = new_user.super_admin or old_user.super_admin
214-
stmt = (
215-
update(user_merge_association)
216-
.where(user_merge_association.c.user_id == old_user_id)
217-
.values(user_id=new_user_id)
291+
if not old_user:
292+
logging.warning(
293+
f"Old user {old_user_id} not found, continuing with adding the merge tuple only."
294+
)
295+
if old_user:
296+
match old_user.state:
297+
case "verified":
298+
new_user.state = (
299+
"verified" if new_user.state != "banned" else new_user.state
300+
)
301+
case "banned":
302+
new_user.state = "banned"
303+
case _:
304+
pass
305+
306+
new_user.super_admin = new_user.super_admin or old_user.super_admin
307+
update_merged_account_tuple_stmt = (
308+
update(user_merge_association)
309+
.where(user_merge_association.c.user_id == old_user_id)
310+
.values(user_id=new_user_id)
311+
)
312+
await session.execute(update_merged_account_tuple_stmt)
313+
delete_old_user_stmt = delete(User).where(User.id == old_user_id)
314+
await session.execute(delete_old_user_stmt)
315+
add_new_merge_tuple_stmt = (
316+
_get_upsert_stmt(session)(user_merge_association)
317+
.values(user_id=new_user_id, merged_user_id=old_user_id)
318+
.on_conflict_do_nothing(
319+
index_elements=["user_id", "merged_user_id"],
320+
)
218321
)
219-
await session.execute(stmt)
220-
stmt_ = delete(User).where(User.id == old_user_id)
221-
await session.execute(stmt_)
322+
await session.execute(add_new_merge_tuple_stmt)
222323
session.add(new_user)
223324
await session.flush()
224325
await session.refresh(new_user)
225-
logging.info(
226-
f"New user: {new_user.id}, {new_user.email}, {new_user.state}, {'Super admin' if new_user.super_admin else 'Not super admin'}"
227-
)
228326
return new_user

pingpong/schemas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def has_real_name(self) -> bool:
136136
return bool(self.display_name or self.first_name or self.last_name)
137137

138138

139+
class MergedUserTuple(BaseModel):
140+
current_user_id: int
141+
merged_user_id: int
142+
143+
139144
class User(BaseModel, UserNameMixin):
140145
id: int
141146
state: UserState

0 commit comments

Comments
 (0)