Skip to content

Commit f23a963

Browse files
fregataaclaude
andcommitted
refactor(BA-5936): apply Copilot review feedback
- Wrap all DELETEs (cascade, RBAC, parent) in a single try/except for IntegrityError so cascade failures also surface as RepositoryIntegrityError instead of raw SQLAlchemy errors. - Drop PrunerSpec.returning_id() — derive the parent PK column directly from row_class().__table__.primary_key, mirroring the Purger pattern. Reject composite-PK tables with UnsupportedCompositePrimaryKeyError. - Update CascadeChild docstring to reflect the materialized-id-list approach (the previous SQL-subquery wording was stale). - Tighten test_no_cascade_with_fk_violation_raises to expect ForeignKeyViolationError, also covering the parse_integrity_error path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 301dd9c commit f23a963

2 files changed

Lines changed: 46 additions & 37 deletions

File tree

src/ai/backend/manager/repositories/base/pruner.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sqlalchemy as sa
1010

1111
from ai.backend.common.data.permission.types import EntityType
12+
from ai.backend.manager.errors.repository import UnsupportedCompositePrimaryKeyError
1213
from ai.backend.manager.models.base import Base
1314
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
1415
AssociationScopesEntitiesRow,
@@ -26,11 +27,14 @@
2627
class CascadeChild(ABC):
2728
"""A child table whose rows must be deleted before the parent's prune.
2829
29-
Used for simple FK cascades. Each cascade DELETE runs as::
30+
Used for simple FK cascades. ``execute_pruner`` first locks and
31+
materializes the parent target IDs once, then issues each cascade
32+
DELETE as::
3033
31-
DELETE FROM <row_class> WHERE <parent_id_column>
32-
IN (SELECT <parent pk> FROM <parent>
33-
WHERE <prune_condition AND conditions>)
34+
DELETE FROM <row_class> WHERE <parent_id_column> IN (<target_ids>)
35+
36+
where ``<target_ids>`` is the list returned from the single
37+
``SELECT pk FOR UPDATE`` against the parent table.
3438
3539
Polymorphic / cross-cutting cleanups (e.g., RBAC associations) are not
3640
handled here — see :meth:`PrunerSpec.entity_type` for that.
@@ -94,18 +98,13 @@ class PrunerSpec[TRow: Base](ABC):
9498
def row_class(cls) -> type[TRow]:
9599
"""ORM Row class for the parent entity table.
96100
97-
Example:
98-
return SessionRow
99-
"""
100-
raise NotImplementedError
101-
102-
@classmethod
103-
@abstractmethod
104-
def returning_id(cls) -> Any:
105-
"""Primary-key column for the parent's ``DELETE ... RETURNING``.
101+
The single-column primary key is derived from
102+
``row_class().__table__.primary_key`` by ``execute_pruner``;
103+
composite-PK tables are rejected with
104+
:class:`UnsupportedCompositePrimaryKeyError`.
106105
107106
Example:
108-
return SessionRow.id
107+
return SessionRow
109108
"""
110109
raise NotImplementedError
111110

@@ -175,11 +174,19 @@ async def execute_pruner[TRow: Base](
175174
PrunerResult with the count and PK list of deleted parent rows.
176175
177176
Raises:
178-
RepositoryIntegrityError: If any DELETE violates a database constraint.
177+
UnsupportedCompositePrimaryKeyError: If the parent table has a
178+
composite primary key.
179+
RepositoryIntegrityError: If any DELETE (cascade, RBAC, or parent)
180+
violates a database constraint.
179181
"""
180182
cls = type(spec)
181183
table = cls.row_class().__table__
182-
pk_col = cls.returning_id()
184+
pk_columns = list(table.primary_key.columns)
185+
if len(pk_columns) != 1:
186+
raise UnsupportedCompositePrimaryKeyError(
187+
f"PrunerSpec only supports single-column primary keys (table: {table.name})",
188+
)
189+
pk_col = pk_columns[0]
183190

184191
where = cls.prune_condition()
185192
for f in spec.conditions:
@@ -190,24 +197,24 @@ async def execute_pruner[TRow: Base](
190197
if not target_ids:
191198
return PrunerResult(count=0, ids=[])
192199

193-
for child in spec.cascade:
194-
ccls = type(child)
195-
cascade_table = ccls.row_class().__table__
196-
await db_sess.execute(
197-
sa.delete(cascade_table).where(ccls.parent_id_column().in_(target_ids))
198-
)
199-
200200
rbac_entity_type = cls.entity_type()
201-
if spec.cascade_rbac and rbac_entity_type is not None:
202-
await db_sess.execute(
203-
sa.delete(AssociationScopesEntitiesRow).where(
204-
AssociationScopesEntitiesRow.entity_type == rbac_entity_type,
205-
AssociationScopesEntitiesRow.entity_id.in_([str(i) for i in target_ids]),
201+
try:
202+
for child in spec.cascade:
203+
ccls = type(child)
204+
cascade_table = ccls.row_class().__table__
205+
await db_sess.execute(
206+
sa.delete(cascade_table).where(ccls.parent_id_column().in_(target_ids))
206207
)
207-
)
208208

209-
stmt = sa.delete(table).where(pk_col.in_(target_ids)).returning(pk_col)
210-
try:
209+
if spec.cascade_rbac and rbac_entity_type is not None:
210+
await db_sess.execute(
211+
sa.delete(AssociationScopesEntitiesRow).where(
212+
AssociationScopesEntitiesRow.entity_type == rbac_entity_type,
213+
AssociationScopesEntitiesRow.entity_id.in_([str(i) for i in target_ids]),
214+
)
215+
)
216+
217+
stmt = sa.delete(table).where(pk_col.in_(target_ids)).returning(pk_col)
211218
deleted = list((await db_sess.scalars(stmt)).all())
212219
except sa.exc.IntegrityError as e:
213220
raise parse_integrity_error(e) from e

tests/unit/manager/repositories/base/test_pruner.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlalchemy.dialects.postgresql import UUID as PGUUID
1515

1616
from ai.backend.common.data.permission.types import EntityType, ScopeType
17+
from ai.backend.manager.errors.repository import ForeignKeyViolationError
1718
from ai.backend.manager.models.base import Base
1819
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
1920
AssociationScopesEntitiesRow,
@@ -77,10 +78,6 @@ class TerminatedTestParentPrunerSpec(PrunerSpec[PrunerTestParentRow]):
7778
def row_class(cls) -> type[PrunerTestParentRow]:
7879
return PrunerTestParentRow
7980

80-
@classmethod
81-
def returning_id(cls) -> Any:
82-
return PrunerTestParentRow.id
83-
8481
@classmethod
8582
def prune_condition(cls) -> sa.ColumnElement[bool]:
8683
return PrunerTestParentRow.status == "terminated"
@@ -325,10 +322,15 @@ async def test_cascade_skipped_for_non_terminal_parents(
325322
async def test_no_cascade_with_fk_violation_raises(
326323
self, parent_child_tables: ExtendedAsyncSAEngine
327324
) -> None:
328-
"""Without the cascade, FK constraint blocks the parent DELETE."""
325+
"""Without the cascade, FK constraint blocks the parent DELETE.
326+
327+
Also verifies that ``execute_pruner`` translates the SQLAlchemy
328+
``IntegrityError`` into ``ForeignKeyViolationError`` via
329+
``parse_integrity_error``.
330+
"""
329331
await self._seed_with_children(parent_child_tables)
330332

331-
with pytest.raises(Exception):
333+
with pytest.raises(ForeignKeyViolationError):
332334
async with parent_child_tables.begin_session() as db_sess:
333335
spec = TerminatedTestParentPrunerSpec() # no cascade
334336
await execute_pruner(db_sess, spec)

0 commit comments

Comments
 (0)