Skip to content

Add composite primary key support #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions docs/advanced/primary_key.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@

由于在 python 内部 `id` 为关键字,因此,我们设定默认主键入参为 `pk`。这仅用于函数入参,并不要求模型主键必须定义为 `pk`

```py title="e.g." hl_lines="2"
async def delete(self, db: AsyncSession, primary_key: int) -> int:
return self.delete_model(db, pk=primary_key)
```

## 主键定义

!!! tip 自动主键

我们在 SQLAlchemy CRUD Plus 内部通过 [inspect()](https://docs.sqlalchemy.org/en/20/core/inspection.html) 自动搜索表主键,
而非强制绑定主键列必须命名为 `id`

```py title="e.g." hl_lines="4"
## 单个主键

```py title="e.g."
class ModelIns(Base):
# define primary_key
primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)


class ModelIns2(Base):
# define primary_key
primary_key: Mapped[str] = mapped_column(primary_key=True, index=True)
```

## 复合主键

```python title="e.g."
class ModelIns(Base):
primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)
primary_key2: Mapped[str] = mapped_column(primary_key=True, index=True)
```
14 changes: 7 additions & 7 deletions docs/usage/delete_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def delete_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
flush: bool = False,
commit: bool = False,
) -> int:
```

**Parameters:**

| Name | Type | Description | Default |
|---------|--------------|----------------------------------|---------|
| session | AsyncSession | 数据库会话 | 必填 |
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
| commit | bool | [提交](../advanced/commit.md) | `False` |
| Name | Type | Description | Default |
|---------|--------------------------|----------------------------------|---------|
| session | AsyncSession | 数据库会话 | 必填 |
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
| commit | bool | [提交](../advanced/commit.md) | `False` |

**Returns:**

Expand Down
10 changes: 5 additions & 5 deletions docs/usage/select_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def select_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
*whereclause: ColumnExpressionArgument[bool],
) -> Model | None:
```

**Parameters:**

| Name | Type | Description | Default |
|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------|
| session | AsyncSession | 数据库会话 | 必填 |
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
| Name | Type | Description | Default |
|--------------|----------------------------------|-----------------------------------------------------------------------------------------------------|---------|
| session | AsyncSession | 数据库会话 | 必填 |
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | |

**Returns:**
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/update_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def update_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
obj: UpdateSchema | dict[str, Any],
flush: bool = False,
commit: bool = False,
Expand All @@ -44,7 +44,7 @@ async def update_model(
| Name | Type | Description | Default |
|---------|-------------------------------|----------------------------------|---------|
| session | AsyncSession | 数据库会话 | 必填 |
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
| obj | `TypeVar `\|` dict[str, Any]` | 更新数据参数 | 必填 |
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
| commit | bool | [提交](../advanced/commit.md) | `False` |
Expand Down
103 changes: 52 additions & 51 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, model: Type[Model]):
self.model = model
self.primary_key = self._get_primary_key()

def _get_primary_key(self) -> Column:
def _get_primary_key(self) -> Column | list[Column]:
"""
Dynamically retrieve the primary key column(s) for the model.
"""
Expand All @@ -35,7 +35,21 @@ def _get_primary_key(self) -> Column:
if len(primary_key) == 1:
return primary_key[0]
else:
raise CompositePrimaryKeysError('Composite primary keys are not supported')
return list(primary_key)

def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]:
"""
Get the primary key filter(s).

:param pk: Single value for simple primary key, or tuple for composite primary key.
:return:
"""
if isinstance(self.primary_key, list):
if len(pk) != len(self.primary_key):
raise CompositePrimaryKeysError(f'Expected {len(self.primary_key)} values for composite primary key')
return [column == value for column, value in zip(self.primary_key, pk)]
else:
return [self.primary_key == pk]

async def create_model(
self,
Expand All @@ -55,10 +69,7 @@ async def create_model(
:param kwargs: Additional model data not included in the pydantic schema.
:return:
"""
if not kwargs:
ins = self.model(**obj.model_dump())
else:
ins = self.model(**obj.model_dump(), **kwargs)
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)

session.add(ins)

Expand Down Expand Up @@ -89,10 +100,7 @@ async def create_models(
"""
ins_list = []
for obj in objs:
if not kwargs:
ins = self.model(**obj.model_dump())
else:
ins = self.model(**obj.model_dump(), **kwargs)
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
ins_list.append(ins)

session.add_all(ins_list)
Expand All @@ -118,12 +126,12 @@ async def count(
:param kwargs: Query expressions.
:return:
"""
filter_list = list(whereclause)
filters = list(whereclause)

if kwargs:
filter_list.extend(parse_filters(self.model, **kwargs))
filters.extend(parse_filters(self.model, **kwargs))

stmt = select(func.count()).select_from(self.model).where(*filter_list)
stmt = select(func.count()).select_from(self.model).where(*filters)
query = await session.execute(stmt)
total_count = query.scalar()
return total_count if total_count is not None else 0
Expand Down Expand Up @@ -154,21 +162,20 @@ async def exists(
async def select_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
*whereclause: ColumnExpressionArgument[bool],
) -> Model | None:
"""
Query by ID
Query by primary key(s)

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: Single value for simple primary key, or tuple for composite primary key.
:param whereclause: The WHERE clauses to apply to the query.
:return:
"""
filter_list = list(whereclause)
_filters = [self.primary_key == pk]
_filters.extend(filter_list)
stmt = select(self.model).where(*_filters)
filters = self._get_pk_filter(pk)
filters + list(whereclause)
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().first()

Expand All @@ -186,10 +193,8 @@ async def select_model_by_column(
:param kwargs: Query expressions.
:return:
"""
filter_list = list(whereclause)
_filters = parse_filters(self.model, **kwargs)
_filters.extend(filter_list)
stmt = select(self.model).where(*_filters)
filters = parse_filters(self.model, **kwargs) + list(whereclause)
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().first()

Expand All @@ -201,10 +206,8 @@ async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -
:param kwargs: Query expressions.
:return:
"""
filter_list = list(whereclause)
_filters = parse_filters(self.model, **kwargs)
_filters.extend(filter_list)
stmt = select(self.model).where(*_filters)
filters = parse_filters(self.model, **kwargs) + list(whereclause)
stmt = select(self.model).where(*filters)
return stmt

async def select_order(
Expand Down Expand Up @@ -270,7 +273,7 @@ async def select_models_order(
async def update_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
obj: UpdateSchema | dict[str, Any],
flush: bool = False,
commit: bool = False,
Expand All @@ -280,21 +283,17 @@ async def update_model(
Update an instance by model's primary key

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: Single value for simple primary key, or tuple for composite primary key.
:param obj: A pydantic schema or dictionary containing the update data
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Additional model data not included in the pydantic schema.
:return:
"""
if isinstance(obj, dict):
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
if kwargs:
instance_data.update(kwargs)

stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
filters = self._get_pk_filter(pk)
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
instance_data.update(kwargs)
stmt = update(self.model).where(*filters).values(**instance_data)
result = await session.execute(stmt)

if flush:
Expand Down Expand Up @@ -325,15 +324,13 @@ async def update_model_by_column(
:return:
"""
filters = parse_filters(self.model, **kwargs)

total_count = await self.count(session, *filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
if isinstance(obj, dict):
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)

stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
stmt = update(self.model).where(*filters).values(**instance_data)
result = await session.execute(stmt)

if flush:
Expand All @@ -346,20 +343,22 @@ async def update_model_by_column(
async def delete_model(
self,
session: AsyncSession,
pk: int,
pk: Any | Sequence[Any],
flush: bool = False,
commit: bool = False,
) -> int:
"""
Delete an instance by model's primary key

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: Single value for simple primary key, or tuple for composite primary key.
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:return:
"""
stmt = delete(self.model).where(self.primary_key == pk)
filters = self._get_pk_filter(pk)

stmt = delete(self.model).where(*filters)
result = await session.execute(stmt)

if flush:
Expand Down Expand Up @@ -392,14 +391,16 @@ async def delete_model_by_column(
:return:
"""
filters = parse_filters(self.model, **kwargs)

total_count = await self.count(session, *filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
if logical_deletion:
deleted_flag = {deleted_flag_column: True}
stmt = update(self.model).where(*filters).values(**deleted_flag)
else:
stmt = delete(self.model).where(*filters)

stmt = (
update(self.model).where(*filters).values(**{deleted_flag_column: True})
if logical_deletion
else delete(self.model).where(*filters)
)

result = await session.execute(stmt)

Expand Down
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from tests.model import Base, Ins
from tests.model import Base, Ins, InsPks

_async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True)
_async_session = async_sessionmaker(_async_engine, autoflush=False, expire_on_commit=False)
Expand All @@ -29,3 +29,12 @@ async def create_test_model():
async with _async_session.begin() as session:
data = [Ins(name=f'name_{i}') for i in range(1, 10)]
session.add_all(data)


@pytest_asyncio.fixture
async def create_test_model_pks():
async with _async_session.begin() as session:
data = [InsPks(id=i, name=f'name_{i}', sex='men') for i in range(1, 5)]
session.add_all(data)
data = [InsPks(id=i, name=f'name_{i}', sex='women') for i in range(6, 10)]
session.add_all(data)
11 changes: 11 additions & 0 deletions tests/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,14 @@ class Ins(Base):
del_flag: Mapped[bool] = mapped_column(default=False)
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)


class InsPks(Base):
__tablename__ = 'ins_pks'

id: Mapped[int] = mapped_column(primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(64))
sex: Mapped[str] = mapped_column(String(16), primary_key=True, index=True)
del_flag: Mapped[bool] = mapped_column(default=False)
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)
6 changes: 6 additions & 0 deletions tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@

class ModelTest(BaseModel):
name: str


class ModelTestPks(BaseModel):
id: int
name: str
sex: str
Loading