Skip to content

优化主键支持类型,select_model,update_model,delete_model 方法支持非int类型主键和复合主键操作 #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion docs/usage/delete_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def delete_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any]],
flush: bool = False,
commit: bool = False,
) -> int:
Expand All @@ -44,3 +44,43 @@ async def delete_model(
| Type | Description |
|------|-------------|
| int | 删除数量 |


## example

```python
# Model with composite primary key
class UserComposite(Base):
__tablename__ = "users_composite"
id = Column(String, primary_key=True)
name = Column(String, primary_key=True)
email = Column(String)

class UserCreate(BaseModel):
id: str
name: str | None
email: str

async def example(session: AsyncSession):
# Composite primary key model
crud = CRUDPlus(UserComposite)

# Create
await crud.create_model(
session, UserCreate(id="123", name="John", email="[email protected]"), commit=True
)


# Delete by composite primary key (dictionary)
await crud.delete_model(session, {"id": "123", "name": "John"}, commit=True)

# Create
await crud.create_model(
session, UserCreate(id="456", name="Jack", email="[email protected]"), commit=True
)

# Delete by composite primary key (tuple)
await crud.delete_model(session, ("456", "Jack"), commit=True)


```
38 changes: 37 additions & 1 deletion docs/usage/select_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def select_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any]],
*whereclause: ColumnExpressionArgument[bool],
) -> Model | None:
```
Expand All @@ -42,3 +42,39 @@ async def select_model(
| Type | Description |
|---------------------|-------------|
| `TypeVar `\|` None` | 模型实例 |


## example

```python
# Model with composite primary key
class UserComposite(Base):
__tablename__ = "users_composite"
id = Column(String, primary_key=True)
name = Column(String, primary_key=True)
email = Column(String)

class UserCreate(BaseModel):
id: str
name: str | None
email: str

async def example(session: AsyncSession):
# Composite primary key model
crud = CRUDPlus(UserComposite)

# Create
await crud.create_model(
session, UserCreate(id="123", name="John", email="[email protected]"), commit=True
)

# Select by composite primary key (dictionary)
user = await crud.select_model(session, {"id": "123", "name": "John"})
print(user.email) # [email protected]

# Select by composite primary key (tuple)
user = await crud.select_model(session, ("123", "John"))
print(user.email) # [email protected]


```
41 changes: 39 additions & 2 deletions docs/usage/update_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class UpdateIns(BaseModel):


class CRUDIns(CRUDPlus[ModelIns]):
async def create(self, db: AsyncSession, pk: int, obj: UpdateIns) -> int:
async def update(self, db: AsyncSession, pk: Union[Any, Dict[str, Any]], obj: UpdateIns) -> int:
return await self.update_model(db, pk, obj)
```

Expand All @@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
async def update_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any]],
obj: UpdateSchema | dict[str, Any],
flush: bool = False,
commit: bool = False,
Expand Down Expand Up @@ -70,3 +70,40 @@ async def update_model(
| Type | Description |
|------|-------------|
| int | 更新数量 |


## example

```python
# Model with composite primary key
class UserComposite(Base):
__tablename__ = "users_composite"
id = Column(String, primary_key=True)
name = Column(String, primary_key=True)
email = Column(String)

class UserCreate(BaseModel):
id: str
name: str | None
email: str

async def example(session: AsyncSession):
# Composite primary key model
crud = CRUDPlus(UserComposite)

# Create
await crud.create_model(
session, UserCreate(id="123", name="John", email="[email protected]"), commit=True
)

# Update by composite primary key (dictionary)
await crud.update_model(
session, {"id": "123", "name": "John"}, {"email": "[email protected]"}, commit=True
)

# Update by composite primary key (tuple)
await crud.update_model(
session, ("123", "John"), {"email": "[email protected]"}, commit=True
)

```
95 changes: 69 additions & 26 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Generic, Iterable, Sequence, Type
from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict, Tuple

from sqlalchemy import (
Column,
Expand All @@ -25,17 +25,56 @@ class CRUDPlus(Generic[Model]):
def __init__(self, model: Type[Model]):
self.model = model
self.primary_key = self._get_primary_key()
self._pk_column_names = [pk_col.name for pk_col in self.primary_keys] # Cache column names

def _get_primary_key(self) -> Column:
def _get_primary_keys(self) -> list[Column]:
"""
Dynamically retrieve the primary key column(s) for the model.
Retrieve the primary key columns for the model.
"""
mapper = inspect(self.model)
primary_key = mapper.primary_key
if len(primary_key) == 1:
return primary_key[0]
return list(mapper.primary_key)

@property
def primary_key_columns(self) -> list[str]:
"""
Return the names of the primary key columns in order.
"""
return self._pk_column_names

def _validate_pk_input(self, pk: Union[Any, Dict[str, Any], Tuple[Any, ...]]) -> Dict[str, Any]:
"""
Validate and normalize primary key input to a dictionary mapping column names to values.

:param pk: A single value for single primary key, a dictionary, or a tuple for composite primary keys.
:return: Dictionary mapping primary key column names to their values.
:raises ValueError: If the input format is invalid or missing required primary key columns.
"""
if len(self.primary_keys) == 1:
pk_col = self._pk_column_names[0]
if isinstance(pk, dict):
if pk_col not in pk:
raise ValueError(f"Primary key column '{pk_col}' missing in dictionary")
return {pk_col: pk[pk_col]}
return {pk_col: pk}
else:
raise CompositePrimaryKeysError('Composite primary keys are not supported')
if isinstance(pk, dict):
missing = set(self._pk_column_names) - set(pk.keys())
if missing:
raise ValueError(
f"Missing primary key columns: {missing}. Expected keys: {self._pk_column_names}"
)
return {k: v for k, v in pk.items() if k in self._pk_column_names}
elif isinstance(pk, tuple):
if len(pk) != len(self.primary_keys):
raise ValueError(
f"Expected {len(self.primary_keys)} primary key values, got {len(pk)}. "
f"Expected columns: {self._pk_column_names}"
)
return dict(zip(self._pk_column_names, pk))
raise ValueError(
f"Composite primary keys require a dictionary or tuple with keys/values for {self._pk_column_names}, "
f"got {type(pk)}"
)

async def create_model(
self,
Expand Down Expand Up @@ -154,21 +193,22 @@ async def exists(
async def select_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any], Tuple[Any, ...]],
*whereclause: ColumnExpressionArgument[bool],
) -> Model | None:
"""
Query by ID

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: A single value for a single primary key (e.g., int, str), a dictionary
mapping column names to values, or a tuple of values (in column order) for
composite primary keys.
:param whereclause: The WHERE clauses to apply to the query.
:return:
"""
filter_list = list(whereclause)
_filters = [self.primary_key == pk]
_filters.extend(filter_list)
stmt = select(self.model).where(*_filters)
pk_dict = self._validate_pk_input(pk)
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + list(whereclause)
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().first()

Expand Down Expand Up @@ -270,7 +310,7 @@ async def select_models_order(
async def update_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any], Tuple[Any, ...]],
obj: UpdateSchema | dict[str, Any],
flush: bool = False,
commit: bool = False,
Expand All @@ -280,21 +320,20 @@ async def update_model(
Update an instance by model's primary key

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: A single value for a single primary key (e.g., int, str), a dictionary
mapping column names to values, or a tuple of values (in column order) for
composite primary keys.
:param obj: A pydantic schema or dictionary containing the update data
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:param kwargs: Additional model data not included in the pydantic schema.
:return:
"""
if isinstance(obj, dict):
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
if kwargs:
instance_data.update(kwargs)

stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
pk_dict = self._validate_pk_input(pk)
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
instance_data.update(kwargs)
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()]
stmt = update(self.model).where(*filters).values(**instance_data)
result = await session.execute(stmt)

if flush:
Expand Down Expand Up @@ -346,20 +385,24 @@ async def update_model_by_column(
async def delete_model(
self,
session: AsyncSession,
pk: int,
pk: Union[Any, Dict[str, Any], Tuple[Any, ...]],
flush: bool = False,
commit: bool = False,
) -> int:
"""
Delete an instance by model's primary key

:param session: The SQLAlchemy async session.
:param pk: The database primary key value.
:param pk: A single value for a single primary key (e.g., int, str), a dictionary
mapping column names to values, or a tuple of values (in column order) for
composite primary keys.
:param flush: If `True`, flush all object changes to the database. Default is `False`.
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:return:
"""
stmt = delete(self.model).where(self.primary_key == pk)
pk_dict = self._validate_pk_input(pk)
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()]
stmt = delete(self.model).where(*filters)
result = await session.execute(stmt)

if flush:
Expand Down
Loading