Skip to content

Commit 8424db0

Browse files
committed
Add composite primary key support
1 parent 7d6b511 commit 8424db0

File tree

9 files changed

+148
-59
lines changed

9 files changed

+148
-59
lines changed

sqlalchemy_crud_plus/crud.py

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, model: Type[Model]):
2626
self.model = model
2727
self.primary_key = self._get_primary_key()
2828

29-
def _get_primary_key(self) -> Column:
29+
def _get_primary_key(self) -> Column | list[Column]:
3030
"""
3131
Dynamically retrieve the primary key column(s) for the model.
3232
"""
@@ -35,7 +35,21 @@ def _get_primary_key(self) -> Column:
3535
if len(primary_key) == 1:
3636
return primary_key[0]
3737
else:
38-
raise CompositePrimaryKeysError('Composite primary keys are not supported')
38+
return list(primary_key)
39+
40+
def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]:
41+
"""
42+
Get the primary key filter(s).
43+
44+
:param pk:
45+
:return:
46+
"""
47+
if isinstance(self.primary_key, list):
48+
if len(pk) != len(self.primary_key):
49+
raise CompositePrimaryKeysError(f'Expected {len(self.primary_key)} values for composite primary key')
50+
return [column == value for column, value in zip(self.primary_key, pk)]
51+
else:
52+
return [self.primary_key == pk]
3953

4054
async def create_model(
4155
self,
@@ -55,10 +69,7 @@ async def create_model(
5569
:param kwargs: Additional model data not included in the pydantic schema.
5670
:return:
5771
"""
58-
if not kwargs:
59-
ins = self.model(**obj.model_dump())
60-
else:
61-
ins = self.model(**obj.model_dump(), **kwargs)
72+
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
6273

6374
session.add(ins)
6475

@@ -89,10 +100,7 @@ async def create_models(
89100
"""
90101
ins_list = []
91102
for obj in objs:
92-
if not kwargs:
93-
ins = self.model(**obj.model_dump())
94-
else:
95-
ins = self.model(**obj.model_dump(), **kwargs)
103+
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
96104
ins_list.append(ins)
97105

98106
session.add_all(ins_list)
@@ -118,12 +126,12 @@ async def count(
118126
:param kwargs: Query expressions.
119127
:return:
120128
"""
121-
filter_list = list(whereclause)
129+
filters = list(whereclause)
122130

123131
if kwargs:
124-
filter_list.extend(parse_filters(self.model, **kwargs))
132+
filters.extend(parse_filters(self.model, **kwargs))
125133

126-
stmt = select(func.count()).select_from(self.model).where(*filter_list)
134+
stmt = select(func.count()).select_from(self.model).where(*filters)
127135
query = await session.execute(stmt)
128136
total_count = query.scalar()
129137
return total_count if total_count is not None else 0
@@ -154,21 +162,20 @@ async def exists(
154162
async def select_model(
155163
self,
156164
session: AsyncSession,
157-
pk: int,
165+
pk: Any | Sequence[Any],
158166
*whereclause: ColumnExpressionArgument[bool],
159167
) -> Model | None:
160168
"""
161-
Query by ID
169+
Query by primary key(s)
162170
163171
:param session: The SQLAlchemy async session.
164-
:param pk: The database primary key value.
172+
:param pk: Single value for simple primary key, or tuple for composite primary key.
165173
:param whereclause: The WHERE clauses to apply to the query.
166174
:return:
167175
"""
168-
filter_list = list(whereclause)
169-
_filters = [self.primary_key == pk]
170-
_filters.extend(filter_list)
171-
stmt = select(self.model).where(*_filters)
176+
filters = self._get_pk_filter(pk)
177+
filters + list(whereclause)
178+
stmt = select(self.model).where(*filters)
172179
query = await session.execute(stmt)
173180
return query.scalars().first()
174181

@@ -186,10 +193,8 @@ async def select_model_by_column(
186193
:param kwargs: Query expressions.
187194
:return:
188195
"""
189-
filter_list = list(whereclause)
190-
_filters = parse_filters(self.model, **kwargs)
191-
_filters.extend(filter_list)
192-
stmt = select(self.model).where(*_filters)
196+
filters = parse_filters(self.model, **kwargs) + list(whereclause)
197+
stmt = select(self.model).where(*filters)
193198
query = await session.execute(stmt)
194199
return query.scalars().first()
195200

@@ -201,10 +206,8 @@ async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -
201206
:param kwargs: Query expressions.
202207
:return:
203208
"""
204-
filter_list = list(whereclause)
205-
_filters = parse_filters(self.model, **kwargs)
206-
_filters.extend(filter_list)
207-
stmt = select(self.model).where(*_filters)
209+
filters = parse_filters(self.model, **kwargs) + list(whereclause)
210+
stmt = select(self.model).where(*filters)
208211
return stmt
209212

210213
async def select_order(
@@ -270,7 +273,7 @@ async def select_models_order(
270273
async def update_model(
271274
self,
272275
session: AsyncSession,
273-
pk: int,
276+
pk: Any | Sequence[Any],
274277
obj: UpdateSchema | dict[str, Any],
275278
flush: bool = False,
276279
commit: bool = False,
@@ -280,21 +283,17 @@ async def update_model(
280283
Update an instance by model's primary key
281284
282285
:param session: The SQLAlchemy async session.
283-
:param pk: The database primary key value.
286+
:param pk: Single value for simple primary key, or tuple for composite primary key.
284287
:param obj: A pydantic schema or dictionary containing the update data
285288
:param flush: If `True`, flush all object changes to the database. Default is `False`.
286289
:param commit: If `True`, commits the transaction immediately. Default is `False`.
287290
:param kwargs: Additional model data not included in the pydantic schema.
288291
:return:
289292
"""
290-
if isinstance(obj, dict):
291-
instance_data = obj
292-
else:
293-
instance_data = obj.model_dump(exclude_unset=True)
294-
if kwargs:
295-
instance_data.update(kwargs)
296-
297-
stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
293+
filters = self._get_pk_filter(pk)
294+
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
295+
instance_data.update(kwargs)
296+
stmt = update(self.model).where(*filters).values(**instance_data)
298297
result = await session.execute(stmt)
299298

300299
if flush:
@@ -325,15 +324,13 @@ async def update_model_by_column(
325324
:return:
326325
"""
327326
filters = parse_filters(self.model, **kwargs)
327+
328328
total_count = await self.count(session, *filters)
329329
if not allow_multiple and total_count > 1:
330330
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
331-
if isinstance(obj, dict):
332-
instance_data = obj
333-
else:
334-
instance_data = obj.model_dump(exclude_unset=True)
335331

336-
stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore
332+
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
333+
stmt = update(self.model).where(*filters).values(**instance_data)
337334
result = await session.execute(stmt)
338335

339336
if flush:
@@ -346,20 +343,22 @@ async def update_model_by_column(
346343
async def delete_model(
347344
self,
348345
session: AsyncSession,
349-
pk: int,
346+
pk: Any | Sequence[Any],
350347
flush: bool = False,
351348
commit: bool = False,
352349
) -> int:
353350
"""
354351
Delete an instance by model's primary key
355352
356353
:param session: The SQLAlchemy async session.
357-
:param pk: The database primary key value.
354+
:param pk: Single value for simple primary key, or tuple for composite primary key.
358355
:param flush: If `True`, flush all object changes to the database. Default is `False`.
359356
:param commit: If `True`, commits the transaction immediately. Default is `False`.
360357
:return:
361358
"""
362-
stmt = delete(self.model).where(self.primary_key == pk)
359+
filters = self._get_pk_filter(pk)
360+
361+
stmt = delete(self.model).where(*filters)
363362
result = await session.execute(stmt)
364363

365364
if flush:
@@ -392,14 +391,16 @@ async def delete_model_by_column(
392391
:return:
393392
"""
394393
filters = parse_filters(self.model, **kwargs)
394+
395395
total_count = await self.count(session, *filters)
396396
if not allow_multiple and total_count > 1:
397397
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
398-
if logical_deletion:
399-
deleted_flag = {deleted_flag_column: True}
400-
stmt = update(self.model).where(*filters).values(**deleted_flag)
401-
else:
402-
stmt = delete(self.model).where(*filters)
398+
399+
stmt = (
400+
update(self.model).where(*filters).values(**{deleted_flag_column: True})
401+
if logical_deletion
402+
else delete(self.model).where(*filters)
403+
)
403404

404405
result = await session.execute(stmt)
405406

tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
77

8-
from tests.model import Base, Ins
8+
from tests.model import Base, Ins, InsPks
99

1010
_async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True)
1111
_async_session = async_sessionmaker(_async_engine, autoflush=False, expire_on_commit=False)
@@ -29,3 +29,12 @@ async def create_test_model():
2929
async with _async_session.begin() as session:
3030
data = [Ins(name=f'name_{i}') for i in range(1, 10)]
3131
session.add_all(data)
32+
33+
34+
@pytest_asyncio.fixture
35+
async def create_test_model_pks():
36+
async with _async_session.begin() as session:
37+
data = [InsPks(id=i, name=f'name_{i}', sex='men') for i in range(1, 5)]
38+
session.add_all(data)
39+
data = [InsPks(id=i, name=f'name_{i}', sex='women') for i in range(6, 10)]
40+
session.add_all(data)

tests/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,14 @@ class Ins(Base):
2020
del_flag: Mapped[bool] = mapped_column(default=False)
2121
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
2222
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)
23+
24+
25+
class InsPks(Base):
26+
__tablename__ = 'ins_pks'
27+
28+
id: Mapped[int] = mapped_column(primary_key=True, index=True)
29+
name: Mapped[str] = mapped_column(String(64))
30+
sex: Mapped[str] = mapped_column(String(16), primary_key=True, index=True)
31+
del_flag: Mapped[bool] = mapped_column(default=False)
32+
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
33+
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)

tests/schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@
55

66
class ModelTest(BaseModel):
77
name: str
8+
9+
10+
class ModelTestPks(BaseModel):
11+
id: int
12+
name: str
13+
sex: str

tests/test_create.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
from random import choice
4+
35
import pytest
46

57
from sqlalchemy import select
68

79
from sqlalchemy_crud_plus import CRUDPlus
8-
from tests.model import Ins
9-
from tests.schema import ModelTest
10+
from tests.model import Ins, InsPks
11+
from tests.schema import ModelTest, ModelTestPks
1012

1113

1214
@pytest.mark.asyncio
@@ -34,3 +36,30 @@ async def test_create_models(async_db_session):
3436
for i in range(1, 10):
3537
query = await session.scalar(select(Ins).where(Ins.id == i))
3638
assert query.name == f'name_{i}'
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_create_model_pks(async_db_session):
43+
async with async_db_session.begin() as session:
44+
crud = CRUDPlus(InsPks)
45+
for i in range(1, 10):
46+
data = ModelTestPks(id=i, name=f'name_{i}', sex=choice(['men', 'women']))
47+
await crud.create_model(session, data)
48+
async with async_db_session() as session:
49+
for i in range(1, 10):
50+
query = await session.scalar(select(InsPks).where(InsPks.id == i))
51+
assert query.name == f'name_{i}'
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_create_models_pks(async_db_session):
56+
async with async_db_session.begin() as session:
57+
crud = CRUDPlus(InsPks)
58+
data = []
59+
for i in range(1, 10):
60+
data.append(ModelTestPks(id=i, name=f'name_{i}', sex=choice(['men', 'women'])))
61+
await crud.create_models(session, data)
62+
async with async_db_session() as session:
63+
for i in range(1, 10):
64+
query = await session.scalar(select(InsPks).where(InsPks.id == i))
65+
assert query.name == f'name_{i}'

tests/test_delete.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from sqlalchemy_crud_plus import CRUDPlus
6-
from tests.model import Ins
6+
from tests.model import Ins, InsPks
77

88

99
@pytest.mark.asyncio
@@ -14,6 +14,14 @@ async def test_delete_model(create_test_model, async_db_session):
1414
assert result == 1
1515

1616

17+
@pytest.mark.asyncio
18+
async def test_delete_model_pks(create_test_model_pks, async_db_session):
19+
async with async_db_session.begin() as session:
20+
crud = CRUDPlus(InsPks)
21+
result = await crud.delete_model(session, (1, 'men'))
22+
assert result == 1
23+
24+
1725
@pytest.mark.asyncio
1826
async def test_delete_model_by_column(create_test_model, async_db_session):
1927
async with async_db_session.begin() as session:

tests/test_select.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy import Select
66

77
from sqlalchemy_crud_plus import CRUDPlus
8-
from tests.model import Ins
8+
from tests.model import Ins, InsPks
99

1010

1111
@pytest.mark.asyncio
@@ -62,6 +62,18 @@ async def test_select_model(create_test_model, async_db_session):
6262
assert result.name == f'name_{i}'
6363

6464

65+
@pytest.mark.asyncio
66+
async def test_select_model_pks(create_test_model_pks, async_db_session):
67+
async with async_db_session() as session:
68+
crud = CRUDPlus(InsPks)
69+
for i in range(1, 5):
70+
result = await crud.select_model(session, (i, 'men'))
71+
assert result.name == f'name_{i}'
72+
for i in range(6, 10):
73+
result = await crud.select_model(session, (i, 'women'))
74+
assert result.name == f'name_{i}'
75+
76+
6577
@pytest.mark.asyncio
6678
async def test_select_model_by_column(create_test_model, async_db_session):
6779
async with async_db_session() as session:

0 commit comments

Comments
 (0)