Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions python/cocoindex/_internal/memo_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def _is_pydantic_model(obj: object) -> bool:
return hasattr(obj, "__pydantic_fields__") and not isinstance(obj, type) # type: ignore[attr-defined]


def _canonicalize_dataclass(obj: object, _seen: dict[int, int]) -> Fingerprintable:
def _canonicalize_dataclass(
obj: object,
_seen: dict[int, int],
state_methods: list[StateFnEntry],
) -> Fingerprintable:
"""Canonicalize a dataclass instance.

Preserves field definition order and includes all fields.
Expand All @@ -141,13 +145,17 @@ def _canonicalize_dataclass(obj: object, _seen: dict[int, int]) -> Fingerprintab
canonical_module_name(typ),
typ.__qualname__,
tuple(
(field.name, _canonicalize(getattr(obj, field.name), _seen))
(field.name, _canonicalize(getattr(obj, field.name), _seen, state_methods))
for field in fields
),
)


def _canonicalize_pydantic(obj: object, _seen: dict[int, int]) -> Fingerprintable:
def _canonicalize_pydantic(
obj: object,
_seen: dict[int, int],
state_methods: list[StateFnEntry],
) -> Fingerprintable:
"""Canonicalize a Pydantic v2 model instance.

Includes all fields (set and unset) to ensure determinism.
Expand All @@ -159,7 +167,10 @@ def _canonicalize_pydantic(obj: object, _seen: dict[int, int]) -> Fingerprintabl
"pydantic",
canonical_module_name(typ),
typ.__qualname__,
tuple((name, _canonicalize(getattr(obj, name), _seen)) for name in field_names),
tuple(
(name, _canonicalize(getattr(obj, name), _seen, state_methods))
for name in field_names
),
)


Expand Down Expand Up @@ -257,7 +268,7 @@ def _stable_sort_key(v: Fingerprintable) -> tuple[typing.Any, ...]:
def _canonicalize(
obj: object,
_seen: dict[int, int] | None,
state_methods: list[StateFnEntry] | None = None,
state_methods: list[StateFnEntry],
) -> Fingerprintable:
# 0) Cycle / shared-reference tracking for containers
if _seen is None:
Expand All @@ -281,10 +292,9 @@ def _canonicalize(
state_hook = getattr(obj, "__coco_memo_state__", None)
if state_hook is not None and callable(state_hook):
tag = "shook"
if state_methods is not None:
# raw function for type hint extraction (unbound method on class)
raw_fn = getattr(typ, "__coco_memo_state__")
state_methods.append(_make_state_fn_entry(state_hook, raw_fn))
# raw function for type hint extraction (unbound method on class)
raw_fn = getattr(typ, "__coco_memo_state__")
state_methods.append(_make_state_fn_entry(state_hook, raw_fn))
return (
tag,
canonical_module_name(typ),
Expand All @@ -299,9 +309,8 @@ def _canonicalize(
tag = "hook"
if memo.state_fn is not None:
tag = "shook"
if state_methods is not None:
bound = functools.partial(memo.state_fn, obj)
state_methods.append(_make_state_fn_entry(bound, memo.state_fn))
bound = functools.partial(memo.state_fn, obj)
state_methods.append(_make_state_fn_entry(bound, memo.state_fn))
return (
tag,
canonical_module_name(base),
Expand Down Expand Up @@ -342,11 +351,11 @@ def _canonicalize(

# 5) Dataclass instances
if _is_dataclass_instance(obj):
return _canonicalize_dataclass(obj, _seen)
return _canonicalize_dataclass(obj, _seen, state_methods)

# 6) Pydantic v2 models
if _is_pydantic_model(obj):
return _canonicalize_pydantic(obj, _seen)
return _canonicalize_pydantic(obj, _seen, state_methods)

# 7) Fallback
try:
Expand All @@ -364,9 +373,9 @@ def _make_call_canonical(
func: typing.Callable[..., object],
args: tuple[object, ...],
kwargs: dict[str, object],
state_methods: list[StateFnEntry],
*,
version: str | int | None = None,
state_methods: list[StateFnEntry] | None = None,
prefix_args: tuple[object, ...] = (),
) -> Fingerprintable:
function_identity = (
Expand All @@ -393,34 +402,38 @@ def _make_call_canonical(


def memo_fingerprint(obj: object) -> core.Fingerprint:
return core.fingerprint_simple_object(_canonicalize(obj, _seen=None))
# State methods are meaningless for an object-only fingerprint; collect
# into a throwaway list so the canonicalizer signature stays uniform.
return core.fingerprint_simple_object(
_canonicalize(obj, _seen=None, state_methods=[])
)


def fingerprint_call(
func: typing.Callable[..., object],
args: tuple[object, ...],
kwargs: dict[str, object],
state_methods: list[StateFnEntry],
*,
version: str | int | None = None,
state_methods: list[StateFnEntry] | None = None,
prefix_args: tuple[object, ...] = (),
) -> core.Fingerprint:
"""Compute the deterministic fingerprint for a function call.

Returns a `cocoindex._internal.core.Fingerprint` object (Python wrapper around a
stable 16-byte digest). Use `bytes(fp)` or `fp.as_bytes()` to get raw bytes.

If *state_methods* is provided, any state methods discovered during
canonicalization are appended to it (used by the execution layer for memo
state validation).
Any state methods discovered during canonicalization are appended to
*state_methods* (used by the execution layer for memo state validation).
Pass an empty list when state methods aren't needed.
"""

call_key_obj = _make_call_canonical(
func,
args,
kwargs,
state_methods,
version=version,
state_methods=state_methods,
prefix_args=prefix_args,
)
# One Python -> Rust call.
Expand Down
Loading
Loading