Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 127 additions & 105 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
import stat
import warnings
from collections.abc import Callable, Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from types import TracebackType
from typing import (
Any,
Literal,
Protocol,
assert_never,
Self,
cast,
overload,
override,
)

try:
Expand Down Expand Up @@ -416,30 +418,20 @@ def composite_hash(self, *args, **kwargs) -> dict[str, Any]:
return hash_parts


class _StreamState(Enum):
"""
Streaming lifecycle for cachew_wrapper.
Allowed transitions:
SETUP -> SOURCE when cachew deliberately yields from the wrapped function without caching.
SETUP -> CACHE_WRITE after cachew gets an exclusive write transaction.
SETUP -> FINISHED after a complete cache hit.
CACHE_WRITE -> SOURCE after a recoverable cache write failure; the same source iterator resumes uncached.
CACHE_WRITE -> FINISHED after the source iterator is exhausted and all items have been yielded.
SOURCE -> FINISHED after uncached streaming completes.
FINISHED is terminal; cleanup errors may be logged or raised, but must not trigger fallback that emits more items.
"""

SETUP = auto()
SOURCE = auto()
CACHE_WRITE = auto()
FINISHED = auto()
type _SessionPhase = Literal['setup', 'streaming', 'completed']


@dataclass
class CacheSession[ItemT]:
class CacheSession[ItemT](AbstractContextManager):
"""
Per-call state for an open backend transaction.
This keeps read/write generator helpers out of cachew_wrapper while preserving the direct recursive fallback path.
Per-call cache/backend lifetime.

Allowed phase transitions:
setup -> streaming when cachew starts yielding cached or source items.
streaming -> completed after all requested items have been yielded.
streaming -> streaming is allowed because synthetic cache reads can happen after source streaming has already started.
completed is terminal; cleanup/finalize errors may be logged or raised, but must never trigger fallback.
setup is the only phase where wrapper-level fallback is safe because no user-visible items have been yielded.
"""

backend: AbstractBackend
Expand All @@ -449,6 +441,46 @@ class CacheSession[ItemT]:
new_hash: SourceHash
chunk_by: int
logger: logging.Logger
phase: _SessionPhase

@override
def __enter__(self) -> Self:
self.backend.__enter__()
return self

@override
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> bool | None:
try:
self.backend.__exit__(exc_type, exc, tb)
except Exception as e:
# Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors.
# See test_early_exit_shutdown.
if exc_type is GeneratorExit and 'Cannot operate on a closed database' in str(e):
# Swallow only the cleanup noise; returning None lets the original GeneratorExit keep closing the generator.
return None
raise
return None

def mark_streaming(self) -> None:
assert self.phase in ('setup', 'streaming'), self.phase
self.phase = 'streaming'

def mark_completed(self) -> None:
assert self.phase == 'streaming', self.phase
self.phase = 'completed'

@property
def fallback_allowed(self) -> bool:
return self.phase == 'setup'

@property
def completed(self) -> bool:
return self.phase == 'completed'

def cached_items(self) -> Iterator[ItemT]:
total_cached = self.backend.cached_blobs_total()
Expand All @@ -457,25 +489,40 @@ def cached_items(self) -> Iterator[ItemT]:
f'loading {total_cached_s}objects from cachew ({self.backend_name}:{self.resolved_cache_path})'
)

self.mark_streaming()
try:
for blob in self.backend.cached_blobs():
j = orjson_loads(blob)
obj = self.marshall.load(j)
yield obj
except Exception as e:
# Preserve the original exception as __cause__, but surface an actionable cachew-level error.
raise CacheReadError(
f'failed to read cachew cache ({self.backend_name}:{self.resolved_cache_path}); remove the cache and try again'
) from e

def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int]:
if isinstance(self.backend, FileBackend):
# FIXME uhhh.. this is a bit crap
# but in sqlite mode we don't want to publish new hash before we write new items
# maybe should use tmp table for hashes as well?
self.backend.write_new_hash(self.new_hash)
else:
# happens later for sqlite
pass
def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int | None]:
data_iter = iter(datas)

def cache_write_error(e: Exception) -> None:
msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})'
cache_error = CacheWriteError(msg)
cache_error.__cause__ = e
cachew_error(cache_error, logger=self.logger)

try:
if isinstance(self.backend, FileBackend):
# FIXME uhhh.. this is a bit crap
# but in sqlite mode we don't want to publish new hash before we write new items
# maybe should use tmp table for hashes as well?
self.backend.write_new_hash(self.new_hash)
else:
# happens later for sqlite
pass
except Exception as e:
cache_write_error(e)
yield from data_iter
return None

flush_blobs = self.backend.flush_blobs

Expand All @@ -488,20 +535,25 @@ def flush() -> None:
chunk = []

total_objects = 0
for obj in datas:
for obj in data_iter:
total_objects += 1
yield obj

try:
dct = self.marshall.dump(obj)
blob = orjson_dumps(dct)
chunk.append(blob)
if len(chunk) >= self.chunk_by:
flush()
except Exception as e:
msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})'
raise CacheWriteError(msg) from e
chunk.append(blob)
if len(chunk) >= self.chunk_by:
flush()
flush()
cache_write_error(e)
yield from data_iter
return None
try:
flush()
except Exception as e:
cache_write_error(e)
return None
return total_objects

def finalize_cache(self, *, total_objects: int) -> None:
Expand Down Expand Up @@ -545,111 +597,81 @@ def cachew_wrapper[**P, ItemT](

synthetic_key = C.synthetic_key

# NOTE: annoyingly huge try/catch ahead...
# but it lets us save a function call, hence a stack frame
# see test_recursive*
early_exit = False
stream_state = _StreamState.SETUP
session: CacheSession[ItemT] | None = None
try:
BackendCls = BACKENDS[C.backend]

new_hash_d = C.composite_hash(*args, **kwargs)
new_hash: SourceHash = json.dumps(new_hash_d)
logger.debug(f'new hash: {new_hash}')

marshall: CachewMarshall[ItemT] = CachewMarshall(Type_=C.cls_)

with BackendCls(cache_path=resolved_cache_path, logger=logger) as backend:
session = CacheSession(
backend=backend,
backend_name=C.backend,
resolved_cache_path=resolved_cache_path,
marshall=marshall,
new_hash=new_hash,
chunk_by=C.chunk_by,
logger=logger,
)
backend = BackendCls(cache_path=resolved_cache_path, logger=logger)
session = CacheSession(
backend=backend,
backend_name=C.backend,
resolved_cache_path=resolved_cache_path,
marshall=marshall,
new_hash=new_hash,
chunk_by=C.chunk_by,
logger=logger,
phase='setup',
)

with session:
old_hash = backend.get_old_hash()
logger.debug(f'old hash: {old_hash}')

if new_hash == old_hash:
logger.debug('hash matched: loading from cache')
yield from session.cached_items()
stream_state = _StreamState.FINISHED
session.mark_completed()
return

logger.debug('hash mismatch: computing data and writing to db')

got_write = backend.get_exclusive_write()
if not got_write:
# NOTE: this is the bit we really have to watch out for and not put in a helper function
# otherwise it's causing an extra stack frame on every call
# the rest (reading from cachew or writing to cachew) happens once per function call? so not a huge deal
stream_state = _StreamState.SOURCE
# NOTE: this is the bit we really have to watch out for and not put in a helper function.
# Otherwise it's causing an extra stack frame on every recursive call.
session.mark_streaming()
yield from func(*args, **kwargs)
stream_state = _StreamState.FINISHED
session.mark_completed()
return

source_kwargs: Any = kwargs
if synthetic_key is not None:
missing_synthetic_values = _synthetic.missing_synthetic_key_values_for_hashes(
old_hash=old_hash,
new_hash_d=new_hash_d,
)
if missing_synthetic_values is not None:
# can reuse cache
kwargs[_synthetic.CACHEW_CACHED] = session.cached_items() # ty: ignore[invalid-assignment]
kwargs[synthetic_key] = missing_synthetic_values # ty: ignore[invalid-assignment]
synthetic_kwargs = dict(kwargs)
synthetic_kwargs[_synthetic.CACHEW_CACHED] = session.cached_items()
synthetic_kwargs[synthetic_key] = missing_synthetic_values
source_kwargs = synthetic_kwargs

# at this point we're guaranteed to have an exclusive write transaction
fit = iter(func(*args, **kwargs))
try:
stream_state = _StreamState.CACHE_WRITE
total_objects = yield from session.write_items_to_cache(fit)
except GeneratorExit:
# GeneratorExit itself is not caught below, but SQLAlchemy cleanup during interpreter shutdown can raise a normal Exception while unwinding.
early_exit = True
raise
except CacheWriteError as e:
# If there is an error during marshalling, etc, we can't just reemit func(*args, **kwargs), we might end up with dupes
# so we try to switch back to the original iterator (fit) -- note it's reused/mutated iin write_items_to_cache
cachew_error(e, logger=logger)
stream_state = _StreamState.SOURCE
yield from fit
stream_state = _StreamState.FINISHED
session.mark_streaming()
fit = iter(func(*args, **source_kwargs))
total_objects = yield from session.write_items_to_cache(fit)
session.mark_completed()

if total_objects is None:
return
stream_state = _StreamState.FINISHED

session.finalize_cache(total_objects=total_objects)
except CacheReadError:
# Cache read failures bypass THROW_ON_ERROR because fallback can duplicate already-yielded cached items.
# This can be thrown from session.cached_items()
raise
except Exception as e:
if stream_state is _StreamState.SOURCE:
# SOURCE means the wrapped function is already streaming uncached, so its exceptions must propagate unchanged.
raise

# Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors.
# See test_early_exit_shutdown.
if early_exit and 'Cannot operate on a closed database' in str(e):
return

if stream_state is _StreamState.CACHE_WRITE:
# CACHE_WRITE may have already yielded source items, so fallback could duplicate emitted output.
raise

cachew_error(e, logger=logger)

if stream_state is _StreamState.FINISHED:
# FINISHED means all requested items were emitted; cachew_error handles THROW_ON_ERROR, but fallback must not emit more data.
if session is not None and session.completed:
# All requested items were emitted, so only report the cache cleanup/finalize error.
cachew_error(e, logger=logger)
return

if stream_state is _StreamState.SETUP:
# SETUP means no user-visible items have been yielded, so fallback is safe.
if session is None or session.fallback_allowed:
# No user-visible items were emitted yet, so a full uncached fallback cannot duplicate output.
cachew_error(e, logger=logger)
yield from func(*args, **kwargs)
return

assert_never(stream_state)
raise


__all__ = [
Expand Down
27 changes: 26 additions & 1 deletion src/cachew/tests/test_cachew.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def fuzz_cachew_impl():
from .. import cachew_wrapper

patch = '''\
@@ -189,6 +189,11 @@
@@ -76,6 +76,11 @@
old_hash = backend.get_old_hash()
logger.debug(f'old hash: {old_hash}')

Expand Down Expand Up @@ -1101,6 +1101,31 @@ def fun() -> Iterator[int]:
assert calls == 1


def test_write_source_call_error_propagates_without_retry(
tmp_path: Path,
restore_settings,
) -> None:
"""
If the wrapped function fails before returning an iterable, cachew must treat it as a source error.
"""
settings.THROW_ON_ERROR = False

class UserError(Exception):
pass

calls = 0

@cachew(tmp_path, cls=int)
def fun() -> list[int]:
nonlocal calls
calls += 1
raise UserError('boom')

with pytest.raises(UserError, match='boom'):
list(fun())
assert calls == 1


def test_defensive_read_error_after_yield_raises_cache_read_error(
tmp_path: Path,
restore_settings,
Expand Down