Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/brussels/__tests__/mixins/test_primary_key_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_id_column_definition() -> None:
server_default = column.server_default
assert server_default is not None
compiled = cast("Any", server_default).arg.compile(dialect=postgresql.dialect())
assert "gen_random_uuid" in str(compiled)
assert "uuidv7" in str(compiled)


def test_id_not_in_init_signature() -> None:
Expand All @@ -50,7 +50,7 @@ def test_id_not_in_init_signature() -> None:
widget_cls(id=uuid4(), name="widget")


def test_id_default_factory_generates_uuid_on_flush(engine: Engine) -> None:
def test_id_generated_for_sqlite_on_flush(engine: Engine) -> None:
DataclassBase.metadata.create_all(engine)

with Session(engine) as session:
Expand All @@ -59,3 +59,4 @@ def test_id_default_factory_generates_uuid_on_flush(engine: Engine) -> None:
session.flush()

assert isinstance(widget.id, UUID)
assert widget.id.version == 4
68 changes: 60 additions & 8 deletions src/brussels/mixins/primary_key.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from uuid import UUID, uuid4

from sqlalchemy import func
from sqlalchemy.orm import Mapped, MappedAsDataclass, declarative_mixin, mapped_column
from sqlalchemy import Table, event, text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Mapped, MappedAsDataclass, Session, declarative_mixin, mapped_column
from sqlalchemy.schema import DefaultClause


@declarative_mixin
Expand All @@ -13,25 +15,75 @@ class PrimaryKeyMixin(MappedAsDataclass):
duplicate inheritance is safely handled by Python's MRO (Method Resolution Order).

The id field is excluded from __init__ (init=False) and is automatically
generated both client-side (default_factory=uuid4) and server-side
(server_default=gen_random_uuid()) for maximum compatibility.
generated server-side (uuidv7()) for PostgreSQL and client-side
(uuid4 on insert) for SQLite.

Usage:
class MyModel(DataclassBase, PrimaryKeyMixin, TimestampMixin):
__tablename__ = "my_table"
name: Mapped[str]

The UUID is:
- Generated client-side by default (uuid4)
- Has server-side fallback (gen_random_uuid() for PostgreSQL)
- Generated server-side by default (uuidv7() for PostgreSQL)
- Generated client-side on insert (uuid4 for SQLite)
- Indexed and unique for efficient lookups
"""

id: Mapped[UUID] = mapped_column(
primary_key=True,
default_factory=uuid4,
server_default=func.gen_random_uuid(),
server_default=text("uuidv7()"),
index=True,
unique=True,
init=False,
)


_UUIDV7_SERVER_DEFAULT_KEY = "uuidv7_server_default"


@event.listens_for(Session, "before_flush")
def _assign_sqlite_uuid(
session: Session,
_flush_context: object,
_instances: object,
) -> None:
bind = session.get_bind()
if bind is None or bind.dialect.name != "sqlite":
return
for instance in session.new:
if isinstance(instance, PrimaryKeyMixin) and instance.id is None:
instance.id = uuid4()


@event.listens_for(Table, "before_create")
def _strip_uuidv7_server_default(table: Table, connection: Connection, **_kwargs: object) -> None:
if connection.dialect.name == "postgresql":
return

column = table.columns.get("id")
if column is None or column.server_default is None:
return

if not isinstance(column.server_default, DefaultClause):
return
if str(column.server_default.arg) != "uuidv7()":
return

column.info[_UUIDV7_SERVER_DEFAULT_KEY] = column.server_default
column.server_default = None


@event.listens_for(Table, "after_create")
def _restore_uuidv7_server_default(table: Table, connection: Connection, **_kwargs: object) -> None:
if connection.dialect.name == "postgresql":
return

column = table.columns.get("id")
if column is None:
return

server_default = column.info.pop(_UUIDV7_SERVER_DEFAULT_KEY, None)
if server_default is None:
return

column.server_default = server_default