Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import logging
import stat
import warnings
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Generator, Iterable, Iterator
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import (
Any,
Literal,
Protocol,
assert_never,
cast,
overload,
)
Expand All @@ -37,7 +39,7 @@ def orjson_dumps(*args: Any, **kwargs: Any) -> bytes: # type: ignore[misc]
from .backend.common import AbstractBackend
from .backend.file import FileBackend
from .backend.sqlite import SqliteBackend
from .common import DEPENDENCIES, CacheReadError, CachewException, SourceHash
from .common import DEPENDENCIES, CacheReadError, CachewException, CacheWriteError, SourceHash
from .logging_helper import make_logger
from .marshall.cachew import CachewMarshall

Expand Down Expand Up @@ -414,6 +416,25 @@ def composite_hash(self, *args, **kwargs) -> dict[str, Any]:
return hash_parts


class _StreamState(Enum):
"""
Streaming lifecycle for cachew_wrapper.
Allowed transitions:
SETUP -> SOURCE when cachew deliberately yields from the wrapped function without caching.
SETUP -> CACHE_WRITE after cachew gets an exclusive write transaction.
SETUP -> FINISHED after a complete cache hit.
CACHE_WRITE -> SOURCE after a recoverable cache write failure; the same source iterator resumes uncached.
CACHE_WRITE -> FINISHED after the source iterator is exhausted and all items have been yielded.
SOURCE -> FINISHED after uncached streaming completes.
FINISHED is terminal; cleanup errors may be logged or raised, but must not trigger fallback that emits more items.
"""

SETUP = auto()
SOURCE = auto()
CACHE_WRITE = auto()
FINISHED = auto()


@dataclass
class CacheSession[ItemT]:
"""
Expand Down Expand Up @@ -446,7 +467,7 @@ def cached_items(self) -> Iterator[ItemT]:
f'failed to read cachew cache ({self.backend_name}:{self.resolved_cache_path}); remove the cache and try again'
) from e

def write_to_cache(self, datas: Iterable[ItemT]) -> Iterator[ItemT]:
def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int]:
if isinstance(self.backend, FileBackend):
# FIXME uhhh.. this is a bit crap
# but in sqlite mode we don't want to publish new hash before we write new items
Expand All @@ -471,13 +492,19 @@ def flush() -> None:
total_objects += 1
yield obj

dct = self.marshall.dump(obj)
blob = orjson_dumps(dct)
try:
dct = self.marshall.dump(obj)
blob = orjson_dumps(dct)
except Exception as e:
msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})'
raise CacheWriteError(msg) from e
chunk.append(blob)
if len(chunk) >= self.chunk_by:
flush()
flush()
return total_objects

def finalize_cache(self, *, total_objects: int) -> None:
self.backend.finalize(self.new_hash)
self.logger.info(
f'wrote {total_objects} objects to cachew ({self.backend_name}:{self.resolved_cache_path})'
Expand Down Expand Up @@ -522,8 +549,7 @@ def cachew_wrapper[**P, ItemT](
# but it lets us save a function call, hence a stack frame
# see test_recursive*
early_exit = False
running_uncached = False
served_from_cache = False
stream_state = _StreamState.SETUP
try:
BackendCls = BACKENDS[C.backend]

Expand All @@ -550,7 +576,7 @@ def cachew_wrapper[**P, ItemT](
if new_hash == old_hash:
logger.debug('hash matched: loading from cache')
yield from session.cached_items()
served_from_cache = True
stream_state = _StreamState.FINISHED
return

logger.debug('hash mismatch: computing data and writing to db')
Expand All @@ -560,9 +586,9 @@ def cachew_wrapper[**P, ItemT](
# NOTE: this is the bit we really have to watch out for and not put in a helper function
# otherwise it's causing an extra stack frame on every call
# the rest (reading from cachew or writing to cachew) happens once per function call? so not a huge deal
running_uncached = True
stream_state = _StreamState.SOURCE
yield from func(*args, **kwargs)
running_uncached = False
stream_state = _StreamState.FINISHED
return

if synthetic_key is not None:
Expand All @@ -576,35 +602,54 @@ def cachew_wrapper[**P, ItemT](
kwargs[synthetic_key] = missing_synthetic_values # ty: ignore[invalid-assignment]

# at this point we're guaranteed to have an exclusive write transaction
fit = iter(func(*args, **kwargs))
try:
yield from session.write_to_cache(func(*args, **kwargs))
stream_state = _StreamState.CACHE_WRITE
total_objects = yield from session.write_items_to_cache(fit)
except GeneratorExit:
# GeneratorExit itself is not caught below, but SQLAlchemy cleanup during interpreter shutdown can raise a normal Exception while unwinding.
early_exit = True
raise
except CacheWriteError as e:
# If there is an error during marshalling, etc, we can't just reemit func(*args, **kwargs), we might end up with dupes
# so we try to switch back to the original iterator (fit) -- note it's reused/mutated iin write_items_to_cache
cachew_error(e, logger=logger)
stream_state = _StreamState.SOURCE
yield from fit
stream_state = _StreamState.FINISHED
return
stream_state = _StreamState.FINISHED
session.finalize_cache(total_objects=total_objects)
except CacheReadError:
# Cache read failures bypass THROW_ON_ERROR because fallback can duplicate already-yielded cached items.
# This can be thrown from session.cached_items()
raise
except Exception as e:
if running_uncached:
if stream_state is _StreamState.SOURCE:
# SOURCE means the wrapped function is already streaming uncached, so its exceptions must propagate unchanged.
raise

# Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors.
# See test_early_exit_shutdown.
if early_exit and 'Cannot operate on a closed database' in str(e):
return

if stream_state is _StreamState.CACHE_WRITE:
# CACHE_WRITE may have already yielded source items, so fallback could duplicate emitted output.
raise

cachew_error(e, logger=logger)

if served_from_cache:
# this can happen if we fully read from the cache, but hit some error while shutting backend down
# - we're past reading, so we emitted all items user wanted from cache
# - we don't want to yield any items from original func
# so it's safe to simply return
if stream_state is _StreamState.FINISHED:
# FINISHED means all requested items were emitted; cachew_error handles THROW_ON_ERROR, but fallback must not emit more data.
return

yield from func(*args, **kwargs)
if stream_state is _StreamState.SETUP:
# SETUP means no user-visible items have been yielded, so fallback is safe.
yield from func(*args, **kwargs)
return

assert_never(stream_state)


__all__ = [
Expand Down
8 changes: 8 additions & 0 deletions src/cachew/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class CacheReadError(CachewException):
pass


class CacheWriteError(CachewException):
"""
Internal signal for defensive cache write fallback.
"""

pass


@dataclass
class TypeNotSupported(CachewException):
type_: type
Expand Down
29 changes: 28 additions & 1 deletion src/cachew/tests/test_cachew.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,6 @@ def orig2():
assert list(fun()) == [123]


@pytest.mark.xfail(reason='cache write errors after yielding currently restart the source iterator', strict=True)
def test_defensive_write_error_after_yield_does_not_duplicate(
tmp_path: Path,
restore_settings,
Expand Down Expand Up @@ -1074,6 +1073,34 @@ def fun() -> Iterator[BB]:
assert calls == 1


def test_write_source_error_after_yield_propagates_without_retry(
tmp_path: Path,
restore_settings,
) -> None:
"""
If the wrapped iterator fails while cachew is writing, the source error must not trigger defensive retry.
"""
settings.THROW_ON_ERROR = False

class UserError(Exception):
pass

calls = 0

@cachew(tmp_path)
def fun() -> Iterator[int]:
nonlocal calls
calls += 1
yield 1
raise UserError('boom')

it = iter(fun())
assert next(it) == 1
with pytest.raises(UserError, match='boom'):
next(it)
assert calls == 1


def test_defensive_read_error_after_yield_raises_cache_read_error(
tmp_path: Path,
restore_settings,
Expand Down