From 3b29c1e8654a9ca204171c18a0aaea85d95dd05d Mon Sep 17 00:00:00 2001 From: Dima Gerasimov Date: Tue, 26 May 2026 23:49:40 +0100 Subject: [PATCH] core: refactor error handling/state transitions should be way more straightforward now --- src/cachew/__init__.py | 232 +++++++++++++++++--------------- src/cachew/tests/test_cachew.py | 27 +++- 2 files changed, 153 insertions(+), 106 deletions(-) diff --git a/src/cachew/__init__.py b/src/cachew/__init__.py index cfb1b00..936aa8f 100644 --- a/src/cachew/__init__.py +++ b/src/cachew/__init__.py @@ -6,16 +6,18 @@ import stat import warnings from collections.abc import Callable, Generator, Iterable, Iterator +from contextlib import AbstractContextManager from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path +from types import TracebackType from typing import ( Any, Literal, Protocol, - assert_never, + Self, cast, overload, + override, ) try: @@ -416,30 +418,20 @@ def composite_hash(self, *args, **kwargs) -> dict[str, Any]: return hash_parts -class _StreamState(Enum): - """ - Streaming lifecycle for cachew_wrapper. - Allowed transitions: - SETUP -> SOURCE when cachew deliberately yields from the wrapped function without caching. - SETUP -> CACHE_WRITE after cachew gets an exclusive write transaction. - SETUP -> FINISHED after a complete cache hit. - CACHE_WRITE -> SOURCE after a recoverable cache write failure; the same source iterator resumes uncached. - CACHE_WRITE -> FINISHED after the source iterator is exhausted and all items have been yielded. - SOURCE -> FINISHED after uncached streaming completes. - FINISHED is terminal; cleanup errors may be logged or raised, but must not trigger fallback that emits more items. - """ - - SETUP = auto() - SOURCE = auto() - CACHE_WRITE = auto() - FINISHED = auto() +type _SessionPhase = Literal['setup', 'streaming', 'completed'] @dataclass -class CacheSession[ItemT]: +class CacheSession[ItemT](AbstractContextManager): """ - Per-call state for an open backend transaction. - This keeps read/write generator helpers out of cachew_wrapper while preserving the direct recursive fallback path. + Per-call cache/backend lifetime. + + Allowed phase transitions: + setup -> streaming when cachew starts yielding cached or source items. + streaming -> completed after all requested items have been yielded. + streaming -> streaming is allowed because synthetic cache reads can happen after source streaming has already started. + completed is terminal; cleanup/finalize errors may be logged or raised, but must never trigger fallback. + setup is the only phase where wrapper-level fallback is safe because no user-visible items have been yielded. """ backend: AbstractBackend @@ -449,6 +441,46 @@ class CacheSession[ItemT]: new_hash: SourceHash chunk_by: int logger: logging.Logger + phase: _SessionPhase + + @override + def __enter__(self) -> Self: + self.backend.__enter__() + return self + + @override + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: + try: + self.backend.__exit__(exc_type, exc, tb) + except Exception as e: + # Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors. + # See test_early_exit_shutdown. + if exc_type is GeneratorExit and 'Cannot operate on a closed database' in str(e): + # Swallow only the cleanup noise; returning None lets the original GeneratorExit keep closing the generator. + return None + raise + return None + + def mark_streaming(self) -> None: + assert self.phase in ('setup', 'streaming'), self.phase + self.phase = 'streaming' + + def mark_completed(self) -> None: + assert self.phase == 'streaming', self.phase + self.phase = 'completed' + + @property + def fallback_allowed(self) -> bool: + return self.phase == 'setup' + + @property + def completed(self) -> bool: + return self.phase == 'completed' def cached_items(self) -> Iterator[ItemT]: total_cached = self.backend.cached_blobs_total() @@ -457,25 +489,40 @@ def cached_items(self) -> Iterator[ItemT]: f'loading {total_cached_s}objects from cachew ({self.backend_name}:{self.resolved_cache_path})' ) + self.mark_streaming() try: for blob in self.backend.cached_blobs(): j = orjson_loads(blob) obj = self.marshall.load(j) yield obj except Exception as e: + # Preserve the original exception as __cause__, but surface an actionable cachew-level error. raise CacheReadError( f'failed to read cachew cache ({self.backend_name}:{self.resolved_cache_path}); remove the cache and try again' ) from e - def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int]: - if isinstance(self.backend, FileBackend): - # FIXME uhhh.. this is a bit crap - # but in sqlite mode we don't want to publish new hash before we write new items - # maybe should use tmp table for hashes as well? - self.backend.write_new_hash(self.new_hash) - else: - # happens later for sqlite - pass + def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int | None]: + data_iter = iter(datas) + + def cache_write_error(e: Exception) -> None: + msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})' + cache_error = CacheWriteError(msg) + cache_error.__cause__ = e + cachew_error(cache_error, logger=self.logger) + + try: + if isinstance(self.backend, FileBackend): + # FIXME uhhh.. this is a bit crap + # but in sqlite mode we don't want to publish new hash before we write new items + # maybe should use tmp table for hashes as well? + self.backend.write_new_hash(self.new_hash) + else: + # happens later for sqlite + pass + except Exception as e: + cache_write_error(e) + yield from data_iter + return None flush_blobs = self.backend.flush_blobs @@ -488,20 +535,25 @@ def flush() -> None: chunk = [] total_objects = 0 - for obj in datas: + for obj in data_iter: total_objects += 1 yield obj try: dct = self.marshall.dump(obj) blob = orjson_dumps(dct) + chunk.append(blob) + if len(chunk) >= self.chunk_by: + flush() except Exception as e: - msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})' - raise CacheWriteError(msg) from e - chunk.append(blob) - if len(chunk) >= self.chunk_by: - flush() - flush() + cache_write_error(e) + yield from data_iter + return None + try: + flush() + except Exception as e: + cache_write_error(e) + return None return total_objects def finalize_cache(self, *, total_objects: int) -> None: @@ -545,111 +597,81 @@ def cachew_wrapper[**P, ItemT]( synthetic_key = C.synthetic_key - # NOTE: annoyingly huge try/catch ahead... - # but it lets us save a function call, hence a stack frame - # see test_recursive* - early_exit = False - stream_state = _StreamState.SETUP + session: CacheSession[ItemT] | None = None try: BackendCls = BACKENDS[C.backend] - new_hash_d = C.composite_hash(*args, **kwargs) new_hash: SourceHash = json.dumps(new_hash_d) logger.debug(f'new hash: {new_hash}') - marshall: CachewMarshall[ItemT] = CachewMarshall(Type_=C.cls_) - with BackendCls(cache_path=resolved_cache_path, logger=logger) as backend: - session = CacheSession( - backend=backend, - backend_name=C.backend, - resolved_cache_path=resolved_cache_path, - marshall=marshall, - new_hash=new_hash, - chunk_by=C.chunk_by, - logger=logger, - ) + backend = BackendCls(cache_path=resolved_cache_path, logger=logger) + session = CacheSession( + backend=backend, + backend_name=C.backend, + resolved_cache_path=resolved_cache_path, + marshall=marshall, + new_hash=new_hash, + chunk_by=C.chunk_by, + logger=logger, + phase='setup', + ) + with session: old_hash = backend.get_old_hash() logger.debug(f'old hash: {old_hash}') if new_hash == old_hash: logger.debug('hash matched: loading from cache') yield from session.cached_items() - stream_state = _StreamState.FINISHED + session.mark_completed() return logger.debug('hash mismatch: computing data and writing to db') got_write = backend.get_exclusive_write() if not got_write: - # NOTE: this is the bit we really have to watch out for and not put in a helper function - # otherwise it's causing an extra stack frame on every call - # the rest (reading from cachew or writing to cachew) happens once per function call? so not a huge deal - stream_state = _StreamState.SOURCE + # NOTE: this is the bit we really have to watch out for and not put in a helper function. + # Otherwise it's causing an extra stack frame on every recursive call. + session.mark_streaming() yield from func(*args, **kwargs) - stream_state = _StreamState.FINISHED + session.mark_completed() return + source_kwargs: Any = kwargs if synthetic_key is not None: missing_synthetic_values = _synthetic.missing_synthetic_key_values_for_hashes( old_hash=old_hash, new_hash_d=new_hash_d, ) if missing_synthetic_values is not None: - # can reuse cache - kwargs[_synthetic.CACHEW_CACHED] = session.cached_items() # ty: ignore[invalid-assignment] - kwargs[synthetic_key] = missing_synthetic_values # ty: ignore[invalid-assignment] + synthetic_kwargs = dict(kwargs) + synthetic_kwargs[_synthetic.CACHEW_CACHED] = session.cached_items() + synthetic_kwargs[synthetic_key] = missing_synthetic_values + source_kwargs = synthetic_kwargs - # at this point we're guaranteed to have an exclusive write transaction - fit = iter(func(*args, **kwargs)) - try: - stream_state = _StreamState.CACHE_WRITE - total_objects = yield from session.write_items_to_cache(fit) - except GeneratorExit: - # GeneratorExit itself is not caught below, but SQLAlchemy cleanup during interpreter shutdown can raise a normal Exception while unwinding. - early_exit = True - raise - except CacheWriteError as e: - # If there is an error during marshalling, etc, we can't just reemit func(*args, **kwargs), we might end up with dupes - # so we try to switch back to the original iterator (fit) -- note it's reused/mutated iin write_items_to_cache - cachew_error(e, logger=logger) - stream_state = _StreamState.SOURCE - yield from fit - stream_state = _StreamState.FINISHED + session.mark_streaming() + fit = iter(func(*args, **source_kwargs)) + total_objects = yield from session.write_items_to_cache(fit) + session.mark_completed() + + if total_objects is None: return - stream_state = _StreamState.FINISHED + session.finalize_cache(total_objects=total_objects) - except CacheReadError: - # Cache read failures bypass THROW_ON_ERROR because fallback can duplicate already-yielded cached items. - # This can be thrown from session.cached_items() - raise except Exception as e: - if stream_state is _StreamState.SOURCE: - # SOURCE means the wrapped function is already streaming uncached, so its exceptions must propagate unchanged. - raise - - # Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors. - # See test_early_exit_shutdown. - if early_exit and 'Cannot operate on a closed database' in str(e): - return - - if stream_state is _StreamState.CACHE_WRITE: - # CACHE_WRITE may have already yielded source items, so fallback could duplicate emitted output. - raise - - cachew_error(e, logger=logger) - - if stream_state is _StreamState.FINISHED: - # FINISHED means all requested items were emitted; cachew_error handles THROW_ON_ERROR, but fallback must not emit more data. + if session is not None and session.completed: + # All requested items were emitted, so only report the cache cleanup/finalize error. + cachew_error(e, logger=logger) return - if stream_state is _StreamState.SETUP: - # SETUP means no user-visible items have been yielded, so fallback is safe. + if session is None or session.fallback_allowed: + # No user-visible items were emitted yet, so a full uncached fallback cannot duplicate output. + cachew_error(e, logger=logger) yield from func(*args, **kwargs) return - assert_never(stream_state) + raise __all__ = [ diff --git a/src/cachew/tests/test_cachew.py b/src/cachew/tests/test_cachew.py index 1d12b49..19455c1 100644 --- a/src/cachew/tests/test_cachew.py +++ b/src/cachew/tests/test_cachew.py @@ -926,7 +926,7 @@ def fuzz_cachew_impl(): from .. import cachew_wrapper patch = '''\ -@@ -189,6 +189,11 @@ +@@ -76,6 +76,11 @@ old_hash = backend.get_old_hash() logger.debug(f'old hash: {old_hash}') @@ -1101,6 +1101,31 @@ def fun() -> Iterator[int]: assert calls == 1 +def test_write_source_call_error_propagates_without_retry( + tmp_path: Path, + restore_settings, +) -> None: + """ + If the wrapped function fails before returning an iterable, cachew must treat it as a source error. + """ + settings.THROW_ON_ERROR = False + + class UserError(Exception): + pass + + calls = 0 + + @cachew(tmp_path, cls=int) + def fun() -> list[int]: + nonlocal calls + calls += 1 + raise UserError('boom') + + with pytest.raises(UserError, match='boom'): + list(fun()) + assert calls == 1 + + def test_defensive_read_error_after_yield_raises_cache_read_error( tmp_path: Path, restore_settings,