|
| 1 | +--- |
| 2 | +title: Pydantic models in SQLAlchemy |
| 3 | +tags: |
| 4 | + - observations |
| 5 | +--- |
| 6 | +Been using this snippet a lot recently in SQLAlchemy to have Pydantic models be de/serialized to JSON transparently for a nice abstraction. |
| 7 | + |
| 8 | +```python |
| 9 | +from typing import Any, override
| 10 | +
| 11 | +from sqlalchemy import DateTime, Dialect, func
| 12 | +from sqlalchemy.types import JSON, TypeDecorator, TypeEngine
| 13 | +from sqlalchemy.dialects.postgresql import JSONB
| 14 | +from pydantic import BaseModel
| 13 | + |
| 14 | + |
| 15 | +class PydanticModelType[T: BaseModel](TypeDecorator[T]): |
| 16 | + cache_ok = True |
| 17 | + impl = JSON() |
| 18 | + |
| 19 | + def __init__(self, pydantic_type: type[T]) -> None: |
| 20 | + self.pydantic_type = pydantic_type |
| 21 | + super().__init__() |
| 22 | + |
| 23 | + @override |
| 24 | + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[T]: |
| 25 | + # Use JSONB for PostgreSQL and JSON for other databases. |
| 26 | + if dialect.name == "postgresql": |
| 27 | + return dialect.type_descriptor(JSONB()) |
| 28 | + |
| 29 | + return dialect.type_descriptor(JSON()) |
| 30 | + |
| 31 | + @override |
| 32 | + def process_bind_param(self, value: T | None, dialect: Dialect) -> dict[str, Any] | None: |
| 33 | + if value is not None: |
| 34 | + return value.model_dump() |
| 35 | + return value |
| 36 | + |
| 37 | + @override |
| 38 | + def process_result_value(self, value: dict[str, Any] | None, dialect: Dialect) -> T | None: |
| 39 | + if value is not None: |
| 40 | + return self.pydantic_type.model_validate(value) |
| 41 | + return value |
| 42 | +``` |
| 43 | + |
| 44 | +If using Python < 3.12, replace the `class PydanticModelType[T: BaseModel]` syntax by declaring `T = TypeVar("T", bound=BaseModel)` before the class and having it also subclass `Generic[T]` (and import `override` from `typing_extensions`, since `typing.override` is 3.12+).
| 45 | + |
| 46 | +Then, inside a subclass of `DeclarativeBase`, you can use it with the modern `Mapped/mapped_column` syntax like:
| 47 | + |
| 48 | +```python |
| 49 | +from sqlalchemy.orm import ( |
| 50 | + DeclarativeBase, |
| 51 | + Mapped, |
| 52 | + mapped_column, |
| 53 | +) |
| 54 | +from pydantic import BaseModel |
| 55 | + |
| 56 | +class BaseTable(DeclarativeBase): |
| 57 | + pass |
| 58 | + |
| 59 | +class MyModel(BaseModel): |
| 60 | + pass |
| 61 | + |
| 62 | +class MyTable(BaseTable): |
| 63 | + __tablename__ = "my_table" |
| 64 | + config: Mapped[MyModel] = mapped_column(PydanticModelType(MyModel)) |
| 65 | +``` |
| 66 | + |
| 67 | +Unfortunately, SQLAlchemy has no way to infer that `MyModel` should be passed to `PydanticModelType`, so we have to repeat ourselves.
| 68 | + |
| 69 | +Additionally, when instantiating your SQLAlchemy engine, you need to pass Pydantic's JSON serializer. Otherwise, SQLAlchemy will fall back to Python's default `json.dumps`, which raises a `TypeError` on values like `datetime` that `model_dump()` can leave in the dict.
| 70 | + |
| 71 | +```python |
| 72 | +from typing import Any
| 73 | +
| 74 | +from pydantic_core import to_json
| 75 | +from sqlalchemy import create_engine
| 73 | + |
| 74 | +def _json_serializer(obj: Any) -> str: |
| 75 | + return to_json(obj).decode("utf-8") |
| 76 | + |
| 77 | +engine = create_engine(..., json_serializer=_json_serializer) |
| 78 | +``` |
| 79 | + |
| 80 | +Note that there isn't a really good way to do change detection via the SQLAlchemy Session API, so you should always manually track your changes and commit when needed. |
| 81 | + |
| 82 | +This package looks like an interesting way to solve the problem, but I have no clue if it works: [`pydantic-changedetect`](https://github.com/team23/pydantic-changedetect/tree/main). |
0 commit comments