diff --git a/docs/advanced/primary_key.md b/docs/advanced/primary_key.md index e5fc028..775c188 100644 --- a/docs/advanced/primary_key.md +++ b/docs/advanced/primary_key.md @@ -2,20 +2,28 @@ 由于在 python 内部 `id` 为关键字,因此,我们设定默认主键入参为 `pk`。这仅用于函数入参,并不要求模型主键必须定义为 `pk` -```py title="e.g." hl_lines="2" -async def delete(self, db: AsyncSession, primary_key: int) -> int: - return self.delete_model(db, pk=primary_key) -``` - -## 主键定义 - !!! tip 自动主键 我们在 SQLAlchemy CRUD Plus 内部通过 [inspect()](https://docs.sqlalchemy.org/en/20/core/inspection.html) 自动搜索表主键, 而非强制绑定主键列必须命名为 `id` -```py title="e.g." hl_lines="4" +## 单个主键 + +```py title="e.g." class ModelIns(Base): # define primary_key primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) + + +class ModelIns2(Base): + # define primary_key + primary_key: Mapped[str] = mapped_column(primary_key=True, index=True) +``` + +## 复合主键 + +```python title="e.g." +class ModelIns(Base): + primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) + primary_key2: Mapped[str] = mapped_column(primary_key=True, index=True) ``` diff --git a/docs/usage/delete_model.md b/docs/usage/delete_model.md index 0cbc0cf..1bc3a2d 100644 --- a/docs/usage/delete_model.md +++ b/docs/usage/delete_model.md @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def delete_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], flush: bool = False, commit: bool = False, ) -> int: @@ -32,12 +32,12 @@ async def delete_model( **Parameters:** -| Name | Type | Description | Default | -|---------|--------------|----------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| pk | int | [主键](../advanced/primary_key.md) | 必填 | -| flush | bool | [冲洗](../advanced/flush.md) | `False` | -| commit | bool | [提交](../advanced/commit.md) | `False` | +| Name | Type | Description | Default | +|---------|--------------------------|----------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 | +| flush | bool | [冲洗](../advanced/flush.md) | `False` | +| commit | bool | [提交](../advanced/commit.md) | `False` | **Returns:** diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 68f5ebe..fdb4d30 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -24,17 +24,17 @@ class CRUDIns(CRUDPlus[ModelIns]): async def select_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: ``` **Parameters:** -| Name | Type | Description | Default | -|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| pk | int | [主键](../advanced/primary_key.md) | 必填 | +| Name | Type | Description | Default | +|--------------|----------------------------------|-----------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 | | *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | **Returns:** diff --git a/docs/usage/update_model.md b/docs/usage/update_model.md index c5a34e6..81a47c0 100644 --- a/docs/usage/update_model.md +++ b/docs/usage/update_model.md @@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def update_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -44,7 +44,7 @@ async def update_model( | Name | Type | Description | Default | |---------|-------------------------------|----------------------------------|---------| | session | AsyncSession | 数据库会话 | 必填 | -| pk | int | [主键](../advanced/primary_key.md) | 必填 | +| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 | | obj | `TypeVar `\|` dict[str, Any]` | 更新数据参数 | 必填 | | flush | bool | [冲洗](../advanced/flush.md) | `False` | | commit | bool | [提交](../advanced/commit.md) | `False` | diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index cecf393..9329416 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -26,7 +26,7 @@ def __init__(self, model: Type[Model]): self.model = model self.primary_key = self._get_primary_key() - def _get_primary_key(self) -> Column: + def _get_primary_key(self) -> Column | list[Column]: """ Dynamically retrieve the primary key column(s) for the model. """ @@ -35,7 +35,21 @@ def _get_primary_key(self) -> Column: if len(primary_key) == 1: return primary_key[0] else: - raise CompositePrimaryKeysError('Composite primary keys are not supported') + return list(primary_key) + + def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]: + """ + Get the primary key filter(s). + + :param pk: Single value for simple primary key, or tuple for composite primary key. + :return: + """ + if isinstance(self.primary_key, list): + if len(pk) != len(self.primary_key): + raise CompositePrimaryKeysError(f'Expected {len(self.primary_key)} values for composite primary key') + return [column == value for column, value in zip(self.primary_key, pk)] + else: + return [self.primary_key == pk] async def create_model( self, @@ -55,10 +69,7 @@ async def create_model( :param kwargs: Additional model data not included in the pydantic schema. :return: """ - if not kwargs: - ins = self.model(**obj.model_dump()) - else: - ins = self.model(**obj.model_dump(), **kwargs) + ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs) session.add(ins) @@ -89,10 +100,7 @@ async def create_models( """ ins_list = [] for obj in objs: - if not kwargs: - ins = self.model(**obj.model_dump()) - else: - ins = self.model(**obj.model_dump(), **kwargs) + ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs) ins_list.append(ins) session.add_all(ins_list) @@ -118,12 +126,12 @@ async def count( :param kwargs: Query expressions. :return: """ - filter_list = list(whereclause) + filters = list(whereclause) if kwargs: - filter_list.extend(parse_filters(self.model, **kwargs)) + filters.extend(parse_filters(self.model, **kwargs)) - stmt = select(func.count()).select_from(self.model).where(*filter_list) + stmt = select(func.count()).select_from(self.model).where(*filters) query = await session.execute(stmt) total_count = query.scalar() return total_count if total_count is not None else 0 @@ -154,21 +162,20 @@ async def exists( async def select_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: """ - Query by ID + Query by primary key(s) :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: Single value for simple primary key, or tuple for composite primary key. :param whereclause: The WHERE clauses to apply to the query. :return: """ - filter_list = list(whereclause) - _filters = [self.primary_key == pk] - _filters.extend(filter_list) - stmt = select(self.model).where(*_filters) + filters = self._get_pk_filter(pk) + filters + list(whereclause) + stmt = select(self.model).where(*filters) query = await session.execute(stmt) return query.scalars().first() @@ -186,10 +193,8 @@ async def select_model_by_column( :param kwargs: Query expressions. :return: """ - filter_list = list(whereclause) - _filters = parse_filters(self.model, **kwargs) - _filters.extend(filter_list) - stmt = select(self.model).where(*_filters) + filters = parse_filters(self.model, **kwargs) + list(whereclause) + stmt = select(self.model).where(*filters) query = await session.execute(stmt) return query.scalars().first() @@ -201,10 +206,8 @@ async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) - :param kwargs: Query expressions. :return: """ - filter_list = list(whereclause) - _filters = parse_filters(self.model, **kwargs) - _filters.extend(filter_list) - stmt = select(self.model).where(*_filters) + filters = parse_filters(self.model, **kwargs) + list(whereclause) + stmt = select(self.model).where(*filters) return stmt async def select_order( @@ -270,7 +273,7 @@ async def select_models_order( async def update_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -280,21 +283,17 @@ async def update_model( Update an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: Single value for simple primary key, or tuple for composite primary key. :param obj: A pydantic schema or dictionary containing the update data :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. :param kwargs: Additional model data not included in the pydantic schema. :return: """ - if isinstance(obj, dict): - instance_data = obj - else: - instance_data = obj.model_dump(exclude_unset=True) - if kwargs: - instance_data.update(kwargs) - - stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) + filters = self._get_pk_filter(pk) + instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + instance_data.update(kwargs) + stmt = update(self.model).where(*filters).values(**instance_data) result = await session.execute(stmt) if flush: @@ -325,15 +324,13 @@ async def update_model_by_column( :return: """ filters = parse_filters(self.model, **kwargs) + total_count = await self.count(session, *filters) if not allow_multiple and total_count > 1: raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.') - if isinstance(obj, dict): - instance_data = obj - else: - instance_data = obj.model_dump(exclude_unset=True) - stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore + instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + stmt = update(self.model).where(*filters).values(**instance_data) result = await session.execute(stmt) if flush: @@ -346,7 +343,7 @@ async def update_model_by_column( async def delete_model( self, session: AsyncSession, - pk: int, + pk: Any | Sequence[Any], flush: bool = False, commit: bool = False, ) -> int: @@ -354,12 +351,14 @@ async def delete_model( Delete an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: Single value for simple primary key, or tuple for composite primary key. :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: """ - stmt = delete(self.model).where(self.primary_key == pk) + filters = self._get_pk_filter(pk) + + stmt = delete(self.model).where(*filters) result = await session.execute(stmt) if flush: @@ -392,14 +391,16 @@ async def delete_model_by_column( :return: """ filters = parse_filters(self.model, **kwargs) + total_count = await self.count(session, *filters) if not allow_multiple and total_count > 1: raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.') - if logical_deletion: - deleted_flag = {deleted_flag_column: True} - stmt = update(self.model).where(*filters).values(**deleted_flag) - else: - stmt = delete(self.model).where(*filters) + + stmt = ( + update(self.model).where(*filters).values(**{deleted_flag_column: True}) + if logical_deletion + else delete(self.model).where(*filters) + ) result = await session.execute(stmt) diff --git a/tests/conftest.py b/tests/conftest.py index 6e16e3e..a0cf431 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from tests.model import Base, Ins +from tests.model import Base, Ins, InsPks _async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True) _async_session = async_sessionmaker(_async_engine, autoflush=False, expire_on_commit=False) @@ -29,3 +29,12 @@ async def create_test_model(): async with _async_session.begin() as session: data = [Ins(name=f'name_{i}') for i in range(1, 10)] session.add_all(data) + + +@pytest_asyncio.fixture +async def create_test_model_pks(): + async with _async_session.begin() as session: + data = [InsPks(id=i, name=f'name_{i}', sex='men') for i in range(1, 5)] + session.add_all(data) + data = [InsPks(id=i, name=f'name_{i}', sex='women') for i in range(6, 10)] + session.add_all(data) diff --git a/tests/model.py b/tests/model.py index c512de9..ec221da 100644 --- a/tests/model.py +++ b/tests/model.py @@ -20,3 +20,14 @@ class Ins(Base): del_flag: Mapped[bool] = mapped_column(default=False) created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now) updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now) + + +class InsPks(Base): + __tablename__ = 'ins_pks' + + id: Mapped[int] = mapped_column(primary_key=True, index=True) + name: Mapped[str] = mapped_column(String(64)) + sex: Mapped[str] = mapped_column(String(16), primary_key=True, index=True) + del_flag: Mapped[bool] = mapped_column(default=False) + created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now) + updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now) diff --git a/tests/schema.py b/tests/schema.py index e85bf46..a508be8 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -5,3 +5,9 @@ class ModelTest(BaseModel): name: str + + +class ModelTestPks(BaseModel): + id: int + name: str + sex: str diff --git a/tests/test_create.py b/tests/test_create.py index 92740df..d13cf8e 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from random import choice + import pytest from sqlalchemy import select from sqlalchemy_crud_plus import CRUDPlus -from tests.model import Ins -from tests.schema import ModelTest +from tests.model import Ins, InsPks +from tests.schema import ModelTest, ModelTestPks @pytest.mark.asyncio @@ -34,3 +36,30 @@ async def test_create_models(async_db_session): for i in range(1, 10): query = await session.scalar(select(Ins).where(Ins.id == i)) assert query.name == f'name_{i}' + + +@pytest.mark.asyncio +async def test_create_model_pks(async_db_session): + async with async_db_session.begin() as session: + crud = CRUDPlus(InsPks) + for i in range(1, 10): + data = ModelTestPks(id=i, name=f'name_{i}', sex=choice(['men', 'women'])) + await crud.create_model(session, data) + async with async_db_session() as session: + for i in range(1, 10): + query = await session.scalar(select(InsPks).where(InsPks.id == i)) + assert query.name == f'name_{i}' + + +@pytest.mark.asyncio +async def test_create_models_pks(async_db_session): + async with async_db_session.begin() as session: + crud = CRUDPlus(InsPks) + data = [] + for i in range(1, 10): + data.append(ModelTestPks(id=i, name=f'name_{i}', sex=choice(['men', 'women']))) + await crud.create_models(session, data) + async with async_db_session() as session: + for i in range(1, 10): + query = await session.scalar(select(InsPks).where(InsPks.id == i)) + assert query.name == f'name_{i}' diff --git a/tests/test_delete.py b/tests/test_delete.py index f91963e..96184fd 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -3,7 +3,7 @@ import pytest from sqlalchemy_crud_plus import CRUDPlus -from tests.model import Ins +from tests.model import Ins, InsPks @pytest.mark.asyncio @@ -14,6 +14,14 @@ async def test_delete_model(create_test_model, async_db_session): assert result == 1 +@pytest.mark.asyncio +async def test_delete_model_pks(create_test_model_pks, async_db_session): + async with async_db_session.begin() as session: + crud = CRUDPlus(InsPks) + result = await crud.delete_model(session, (1, 'men')) + assert result == 1 + + @pytest.mark.asyncio async def test_delete_model_by_column(create_test_model, async_db_session): async with async_db_session.begin() as session: diff --git a/tests/test_select.py b/tests/test_select.py index 9f7dc9c..025ce3c 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -5,7 +5,7 @@ from sqlalchemy import Select from sqlalchemy_crud_plus import CRUDPlus -from tests.model import Ins +from tests.model import Ins, InsPks @pytest.mark.asyncio @@ -62,6 +62,18 @@ async def test_select_model(create_test_model, async_db_session): assert result.name == f'name_{i}' +@pytest.mark.asyncio +async def test_select_model_pks(create_test_model_pks, async_db_session): + async with async_db_session() as session: + crud = CRUDPlus(InsPks) + for i in range(1, 5): + result = await crud.select_model(session, (i, 'men')) + assert result.name == f'name_{i}' + for i in range(6, 10): + result = await crud.select_model(session, (i, 'women')) + assert result.name == f'name_{i}' + + @pytest.mark.asyncio async def test_select_model_by_column(create_test_model, async_db_session): async with async_db_session() as session: diff --git a/tests/test_update.py b/tests/test_update.py index 309e152..59e1e8f 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -3,8 +3,8 @@ import pytest from sqlalchemy_crud_plus import CRUDPlus -from tests.model import Ins -from tests.schema import ModelTest +from tests.model import Ins, InsPks +from tests.schema import ModelTest, ModelTestPks @pytest.mark.asyncio @@ -18,6 +18,17 @@ async def test_update_model(create_test_model, async_db_session): assert result.name == 'name_update_1' +@pytest.mark.asyncio +async def test_update_model_pks(create_test_model_pks, async_db_session): + async with async_db_session.begin() as session: + crud = CRUDPlus(InsPks) + data = ModelTestPks(id=1, name='name_update_1', sex='men') + result = await crud.update_model(session, (1, 'men'), data) + assert result == 1 + result = await session.get(InsPks, (1, 'men')) + assert result.name == 'name_update_1' + + @pytest.mark.asyncio async def test_update_model_by_column(create_test_model, async_db_session): async with async_db_session.begin() as session: diff --git a/uv.lock b/uv.lock index fd19857..cc305b8 100644 --- a/uv.lock +++ b/uv.lock @@ -394,7 +394,7 @@ wheels = [ [[package]] name = "sqlalchemy-crud-plus" -version = "1.5.0" +version = "1.8.0" source = { editable = "." } dependencies = [ { name = "pydantic" },