Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/brussels/__tests__/mixins/test_primary_key_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy.orm import Mapped, Session, mapped_column

from brussels.base import DataclassBase
from brussels.mixins import PrimaryKeyMixin, TimestampMixin
from brussels.mixins import PrimaryKeyMixin, TimestampMixin, UUIDv7PrimaryKeyMixin


class Widget(DataclassBase, PrimaryKeyMixin, TimestampMixin):
Expand All @@ -18,6 +18,12 @@ class Widget(DataclassBase, PrimaryKeyMixin, TimestampMixin):
name: Mapped[str] = mapped_column()


class UUIDv7Widget(DataclassBase, UUIDv7PrimaryKeyMixin, TimestampMixin):
    """Fixture model exercising the UUIDv7 primary-key mixin."""

    __tablename__ = "primary_key_v7_widgets"

    name: Mapped[str] = mapped_column()


@pytest.fixture
def engine() -> Iterator[Engine]:
engine = create_engine("sqlite:///:memory:")
Expand Down Expand Up @@ -57,3 +63,41 @@ def test_id_default_factory_generates_uuid_on_flush(engine: Engine) -> None:
session.flush()

assert isinstance(widget.id, UUID)


def test_uuidv7_id_column_definition() -> None:
    """The id column is a primary key whose insert and server defaults call uuidv7()."""
    id_column = cast("Table", UUIDv7Widget.__table__).c.id

    assert id_column.primary_key is True

    # Both the SQLAlchemy-side insert default and the DDL server default
    # must compile to the PostgreSQL uuidv7() function.
    for default_clause in (id_column.default, id_column.server_default):
        assert default_clause is not None
        compiled = cast("Any", default_clause).arg.compile(dialect=postgresql.dialect())
        assert "uuidv7" in str(compiled)


def test_uuidv7_id_not_in_init_signature() -> None:
    """id must not be settable through the generated dataclass __init__."""
    init_signature = inspect.signature(UUIDv7Widget)
    assert "id" not in init_signature.parameters

    # Passing id explicitly is a TypeError at runtime as well.
    permissive_cls = cast("Any", UUIDv7Widget)
    with pytest.raises(TypeError):
        permissive_cls(id=uuid4(), name="widget")


def test_uuidv7_id_is_not_populated_before_flush() -> None:
    """The id stays None until the database generates it at flush time."""
    unflushed = UUIDv7Widget(name="widget")
    assert unflushed.id is None


def test_uuidv7_models_satisfy_primary_key_mixin_contract() -> None:
    """UUIDv7 models remain usable wherever PrimaryKeyMixin is required."""
    instance = UUIDv7Widget(name="widget")
    assert isinstance(instance, PrimaryKeyMixin)
146 changes: 144 additions & 2 deletions src/brussels/__tests__/types/file/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from datetime import UTC, datetime
from typing import TYPE_CHECKING, cast
from uuid import UUID
from uuid import UUID, uuid4

import pytest
from sqlalchemy.orm import Mapped, Session, mapped_column

from brussels.base import Base, DataclassBase
from brussels.mixins import PrimaryKeyMixin
from brussels.mixins import PrimaryKeyMixin, UUIDv7PrimaryKeyMixin

try:
from obstore.store import MemoryStore
Expand All @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from obstore import GetOptions, PutMode
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncSession

from brussels.types.file._types import RemoteMetadataField

Expand All @@ -45,6 +46,12 @@ class OtherFileModel(DataclassBase, PrimaryKeyMixin):
file: Mapped[RemoteMetadata | None] = mapped_column(RemoteStorage(store=MemoryStore()), nullable=True, default=None)


class UUIDv7FileModel(DataclassBase, UUIDv7PrimaryKeyMixin):
    """Fixture model with a UUIDv7 primary key and an optional remote-file column."""

    __tablename__ = "uuidv7_file_model"

    file: Mapped[RemoteMetadata | None] = mapped_column(RemoteStorage(store=MemoryStore()), nullable=True, default=None)


class TablelessPrimaryKeyModel(PrimaryKeyMixin):
    """Primary-key mixin subclass that deliberately declares no __tablename__."""

Expand All @@ -54,10 +61,29 @@ def _configure_store(store_ops: FakeStoreOps) -> None:
remote_storage.store = store_ops


def _configure_uuidv7_store(store_ops: FakeStoreOps) -> None:
    """Point the UUIDv7FileModel file column's RemoteStorage at the fake store."""
    column_type = UUIDv7FileModel.__table__.c["file"].type
    cast("RemoteStorage", column_type).store = store_ops


def _file_handle(model: FileModel) -> RemoteFile:
    """Wrap model's file column in a RemoteFile handle."""
    handle = RemoteFile.from_metadata(model, FileModel.file)
    return handle


def _uuidv7_file_handle(model: UUIDv7FileModel) -> RemoteFile:
    """Wrap a UUIDv7FileModel's file column in a RemoteFile handle."""
    handle = RemoteFile.from_metadata(model, UUIDv7FileModel.file)
    return handle


class AsyncSessionShim:
    """Minimal AsyncSession stand-in wrapping a synchronous Session.

    Implements only the async flush() used by put_async; the optional
    on_flush callback lets a test simulate side effects of a real flush
    (e.g. the database assigning a primary key).
    """

    def __init__(self, sync_session: Session, *, on_flush=None) -> None:
        # on_flush: optional zero-argument callable invoked on each flush()
        # (left unannotated to avoid a Callable import in this test module).
        self.sync_session = sync_session
        self._on_flush = on_flush

    async def flush(self) -> None:
        """Mimic AsyncSession.flush by invoking the configured callback."""
        callback = self._on_flush
        if callback is not None:
            callback()


def test_from_metadata_rejects_models_without_primary_key_mixin() -> None:
model = NoPrimaryKeyMixinModel(id=1)

Expand Down Expand Up @@ -105,6 +131,14 @@ def test_from_metadata_accepts_models_with_primary_key_mixin() -> None:
assert remote_file.field_name == "file"


def test_from_metadata_accepts_uuidv7_primary_key_models() -> None:
    """from_metadata must accept models using the UUIDv7 primary-key mixin."""
    model = UUIDv7FileModel()

    handle = RemoteFile.from_metadata(model, UUIDv7FileModel.file)

    assert handle.field_name == "file"


def test_put_sync_without_sqlalchemy_session_raises_and_does_not_call_store() -> None:
store_ops = FakeStoreOps()
_configure_store(store_ops)
Expand All @@ -120,6 +154,62 @@ def test_put_sync_without_sqlalchemy_session_raises_and_does_not_call_store() ->
assert model.file is None


def test_put_sync_flush_populates_uuidv7_id_before_key_generation(
    engine: Engine,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """put(flush=True) must flush first so the DB-assigned id can seed the storage key."""
    store_ops = FakeStoreOps()
    _configure_uuidv7_store(store_ops)

    with Session(engine) as session:
        model = UUIDv7FileModel()
        session.add(model)
        generated_id = uuid4()
        observed_flushes = 0

        def fake_flush() -> None:
            # Simulate the database populating the server-generated id on
            # the first flush, and count how many flushes occur overall.
            nonlocal observed_flushes
            observed_flushes += 1
            if model.id is None:
                model.id = generated_id

        monkeypatch.setattr(session, "flush", fake_flush)

        result = _uuidv7_file_handle(model).put(
            b"hello",
            content_type="text/plain",
            session=session,
            flush=True,
        )

        assert result == {"e_tag": None, "version": None}
        assert observed_flushes == 2
        assert model.id == generated_id
        assert model.file is not None
        assert model.file.key == f"{generated_id}/file"
        assert model.file.status == "pending"
        assert store_ops.calls == []


def test_put_sync_rejects_uuidv7_model_without_flush(engine: Engine) -> None:
    """Without flush=True an unflushed UUIDv7 model has no id, so put must fail."""
    store_ops = FakeStoreOps()
    _configure_uuidv7_store(store_ops)

    with Session(engine) as session:
        unflushed = UUIDv7FileModel()
        session.add(unflushed)

        with pytest.raises(ValueError, match=r"Pass flush=True or flush the model before calling put"):
            _uuidv7_file_handle(unflushed).put(
                b"hello",
                content_type="text/plain",
                session=session,
            )

        # Nothing was written: neither metadata nor the backing store changed.
        assert unflushed.file is None
        assert store_ops.calls == []


def test_put_sync_defers_and_commits_when_sqlalchemy_session_is_attached(engine: Engine) -> None:
store_ops = FakeStoreOps()
_configure_store(store_ops)
Expand Down Expand Up @@ -256,6 +346,58 @@ async def test_put_rejects_invalid_model_id_type(engine: Engine) -> None:
await _file_handle(model).put_async(b"data")


@pytest.mark.asyncio
async def test_put_async_flush_populates_uuidv7_id_before_key_generation(engine: Engine) -> None:
    """put_async(flush=True) must flush first so the DB-assigned id can seed the key."""
    store_ops = FakeStoreOps()
    _configure_uuidv7_store(store_ops)

    with Session(engine) as session:
        model = UUIDv7FileModel()
        session.add(model)
        generated_id = uuid4()
        observed_flushes = 0

        def record_flush() -> None:
            # Simulate the database populating the server-generated id on
            # the first flush, and count how many flushes occur overall.
            nonlocal observed_flushes
            observed_flushes += 1
            if model.id is None:
                model.id = generated_id

        shim = AsyncSessionShim(session, on_flush=record_flush)
        await _uuidv7_file_handle(model).put_async(
            b"data",
            content_type="text/plain",
            session=cast("AsyncSession", shim),
            flush=True,
        )

        assert observed_flushes == 2
        assert model.id == generated_id
        assert model.file is not None
        assert model.file.key == f"{generated_id}/file"
        assert model.file.status == "pending"
        assert store_ops.calls == []


@pytest.mark.asyncio
async def test_put_async_rejects_uuidv7_model_without_flush(engine: Engine) -> None:
    """Without flush=True an unflushed UUIDv7 model has no id, so put_async must fail."""
    store_ops = FakeStoreOps()
    _configure_uuidv7_store(store_ops)

    with Session(engine) as session:
        unflushed = UUIDv7FileModel()
        session.add(unflushed)

        with pytest.raises(ValueError, match=r"Pass flush=True or flush the model before calling put_async"):
            await _uuidv7_file_handle(unflushed).put_async(
                b"data",
                content_type="text/plain",
                session=session,
            )

        # Nothing was written: neither metadata nor the backing store changed.
        assert unflushed.file is None
        assert store_ops.calls == []


@pytest.mark.asyncio
async def test_put_allows_uuid_model_id(engine: Engine) -> None:
store_ops = FakeStoreOps()
Expand Down
4 changes: 2 additions & 2 deletions src/brussels/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from brussels.mixins.ordered import OrderedMixin
from brussels.mixins.primary_key import PrimaryKeyMixin
from brussels.mixins.primary_key import PrimaryKeyMixin, UUIDv7PrimaryKeyMixin
from brussels.mixins.timestamp import TimestampMixin

__all__ = ["OrderedMixin", "PrimaryKeyMixin", "TimestampMixin"]
__all__ = ["OrderedMixin", "PrimaryKeyMixin", "TimestampMixin", "UUIDv7PrimaryKeyMixin"]
27 changes: 27 additions & 0 deletions src/brussels/mixins/primary_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,30 @@ class MyModel(DataclassBase, PrimaryKeyMixin, TimestampMixin):
server_default=func.gen_random_uuid(),
init=False,
)


@declarative_mixin
class UUIDv7PrimaryKeyMixin(PrimaryKeyMixin):
    """Mixin that adds a PostgreSQL 18+ UUIDv7 primary key column.

    Extends PrimaryKeyMixin so models continue to satisfy APIs that require the
    existing primary-key mixin contract. The id field is excluded from __init__
    (init=False) and is generated during insert using PostgreSQL's uuidv7()
    function. Because the UUID is database generated, ID-derived workflows
    require the row to be flushed or already persisted.

    Usage:
        class MyModel(DataclassBase, UUIDv7PrimaryKeyMixin, TimestampMixin):
            __tablename__ = "my_table"
            name: Mapped[str]

    This mixin is only supported on PostgreSQL 18+ because it relies on the
    built-in uuidv7() database function.
    """

    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        # Rendered into INSERT statements issued through SQLAlchemy.
        insert_default=func.uuidv7(),
        # DDL-level DEFAULT covering rows inserted outside SQLAlchemy.
        server_default=func.uuidv7(),
        # Keep id out of the generated dataclass __init__.
        init=False,
    )
55 changes: 55 additions & 0 deletions src/brussels/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def _model_id(self) -> str:
raise TypeError(msg)
return str(model_id)

@staticmethod
def _missing_model_id_message(*, operation: str) -> str:
return (
f"RemoteStorage {operation} requires model.id to be set. "
f"Pass flush=True or flush the model before calling {operation}."
)

@staticmethod
def _flush_sync(*, session: Session | None, flush: bool) -> None:
if not flush or session is None:
Expand All @@ -86,6 +93,44 @@ async def _flush_async(*, session: Session | AsyncSession | None, flush: bool) -
if isawaitable(maybe_awaitable):
await maybe_awaitable

def _ensure_model_id_ready_sync(
self,
*,
session: Session | None,
flush: bool,
operation: str,
) -> None:
if getattr(self.model, "id", None) is not None:
return
if not flush:
raise ValueError(self._missing_model_id_message(operation=operation))

self._flush_sync(session=session, flush=True)
if getattr(self.model, "id", None) is not None:
return

msg = f"RemoteStorage {operation} requires model.id to be set after flush."
raise ValueError(msg)

async def _ensure_model_id_ready_async(
self,
*,
session: Session | AsyncSession | None,
flush: bool,
operation: str,
) -> None:
if getattr(self.model, "id", None) is not None:
return
if not flush:
raise ValueError(self._missing_model_id_message(operation=operation))

await self._flush_async(session=session, flush=True)
if getattr(self.model, "id", None) is not None:
return

msg = f"RemoteStorage {operation} requires model.id to be set after flush."
raise ValueError(msg)

@staticmethod
def _resolve_sync_session(
*,
Expand Down Expand Up @@ -168,6 +213,11 @@ def put( # noqa: PLR0913
flush: bool = False,
) -> PutResult:
resolved_session = self._required_sync_session(session=session, operation="put")
self._ensure_model_id_ready_sync(
session=session or resolved_session,
flush=flush,
operation="put",
)
metadata = self._prepare_pending_metadata(
key=key,
content_type=content_type,
Expand Down Expand Up @@ -207,6 +257,11 @@ async def put_async( # noqa: PLR0913
flush: bool = False,
) -> PutResult:
resolved_session = self._required_sync_session(session=session, operation="put")
await self._ensure_model_id_ready_async(
session=session or resolved_session,
flush=flush,
operation="put_async",
)
metadata = self._prepare_pending_metadata(
key=key,
content_type=content_type,
Expand Down
Loading