99import sqlalchemy as sa
1010
1111from ai .backend .common .data .permission .types import EntityType
12+ from ai .backend .manager .errors .repository import UnsupportedCompositePrimaryKeyError
1213from ai .backend .manager .models .base import Base
1314from ai .backend .manager .models .rbac_models .association_scopes_entities import (
1415 AssociationScopesEntitiesRow ,
2627class CascadeChild (ABC ):
2728 """A child table whose rows must be deleted before the parent's prune.
2829
29- Used for simple FK cascades. Each cascade DELETE runs as::
30+ Used for simple FK cascades. ``execute_pruner`` first locks and
31+ materializes the parent target IDs once, then issues each cascade
32+ DELETE as::
3033
31- DELETE FROM <row_class> WHERE <parent_id_column>
32- IN (SELECT <parent pk> FROM <parent>
33- WHERE <prune_condition AND conditions>)
34+ DELETE FROM <row_class> WHERE <parent_id_column> IN (<target_ids>)
35+
36+ where ``<target_ids>`` is the list returned from the single
37+ ``SELECT pk FOR UPDATE`` against the parent table.
3438
3539 Polymorphic / cross-cutting cleanups (e.g., RBAC associations) are not
3640 handled here — see :meth:`PrunerSpec.entity_type` for that.
@@ -94,18 +98,13 @@ class PrunerSpec[TRow: Base](ABC):
9498 def row_class (cls ) -> type [TRow ]:
9599 """ORM Row class for the parent entity table.
96100
97- Example:
98- return SessionRow
99- """
100- raise NotImplementedError
101-
102- @classmethod
103- @abstractmethod
104- def returning_id (cls ) -> Any :
105- """Primary-key column for the parent's ``DELETE ... RETURNING``.
101+ The single-column primary key is derived from
102+ ``row_class().__table__.primary_key`` by ``execute_pruner``;
103+ composite-PK tables are rejected with
104+ :class:`UnsupportedCompositePrimaryKeyError`.
106105
107106 Example:
108- return SessionRow.id
107+ return SessionRow
109108 """
110109 raise NotImplementedError
111110
@@ -175,11 +174,19 @@ async def execute_pruner[TRow: Base](
175174 PrunerResult with the count and PK list of deleted parent rows.
176175
177176 Raises:
178- RepositoryIntegrityError: If any DELETE violates a database constraint.
177+ UnsupportedCompositePrimaryKeyError: If the parent table has a
178+ composite primary key.
179+ RepositoryIntegrityError: If any DELETE (cascade, RBAC, or parent)
180+ violates a database constraint.
179181 """
180182 cls = type (spec )
181183 table = cls .row_class ().__table__
182- pk_col = cls .returning_id ()
184+ pk_columns = list (table .primary_key .columns )
185+ if len (pk_columns ) != 1 :
186+ raise UnsupportedCompositePrimaryKeyError (
187+ f"PrunerSpec only supports single-column primary keys (table: { table .name } )" ,
188+ )
189+ pk_col = pk_columns [0 ]
183190
184191 where = cls .prune_condition ()
185192 for f in spec .conditions :
@@ -190,24 +197,24 @@ async def execute_pruner[TRow: Base](
190197 if not target_ids :
191198 return PrunerResult (count = 0 , ids = [])
192199
193- for child in spec .cascade :
194- ccls = type (child )
195- cascade_table = ccls .row_class ().__table__
196- await db_sess .execute (
197- sa .delete (cascade_table ).where (ccls .parent_id_column ().in_ (target_ids ))
198- )
199-
200200 rbac_entity_type = cls .entity_type ()
201- if spec .cascade_rbac and rbac_entity_type is not None :
202- await db_sess .execute (
203- sa .delete (AssociationScopesEntitiesRow ).where (
204- AssociationScopesEntitiesRow .entity_type == rbac_entity_type ,
205- AssociationScopesEntitiesRow .entity_id .in_ ([str (i ) for i in target_ids ]),
201+ try :
202+ for child in spec .cascade :
203+ ccls = type (child )
204+ cascade_table = ccls .row_class ().__table__
205+ await db_sess .execute (
206+ sa .delete (cascade_table ).where (ccls .parent_id_column ().in_ (target_ids ))
206207 )
207- )
208208
209- stmt = sa .delete (table ).where (pk_col .in_ (target_ids )).returning (pk_col )
210- try :
209+ if spec .cascade_rbac and rbac_entity_type is not None :
210+ await db_sess .execute (
211+ sa .delete (AssociationScopesEntitiesRow ).where (
212+ AssociationScopesEntitiesRow .entity_type == rbac_entity_type ,
213+ AssociationScopesEntitiesRow .entity_id .in_ ([str (i ) for i in target_ids ]),
214+ )
215+ )
216+
217+ stmt = sa .delete (table ).where (pk_col .in_ (target_ids )).returning (pk_col )
211218 deleted = list ((await db_sess .scalars (stmt )).all ())
212219 except sa .exc .IntegrityError as e :
213220 raise parse_integrity_error (e ) from e
0 commit comments