diff --git a/src/gcache/__init__.py b/src/gcache/__init__.py index 1a0b2eb..ebdd287 100644 --- a/src/gcache/__init__.py +++ b/src/gcache/__init__.py @@ -1,4 +1,29 @@ -from gcache.config import CacheConfigProvider, CacheLayer, GCacheConfig, GCacheKey, GCacheKeyConfig, RedisConfig +from gcache.config import ( + CacheCallContext, + CacheConfigProvider, + CacheHitDecision, + CacheHitHook, + CacheLayer, + EvictAndFallback, + GCacheConfig, + GCacheKey, + GCacheKeyConfig, + RedisConfig, + ReturnCached, +) from gcache.gcache import GCache -__all__ = ["CacheConfigProvider", "CacheLayer", "GCache", "GCacheConfig", "GCacheKey", "GCacheKeyConfig", "RedisConfig"] +__all__ = [ + "CacheCallContext", + "CacheConfigProvider", + "CacheHitDecision", + "CacheHitHook", + "CacheLayer", + "EvictAndFallback", + "GCache", + "GCacheConfig", + "GCacheKey", + "GCacheKeyConfig", + "RedisConfig", + "ReturnCached", +] diff --git a/src/gcache/_internal/cache_hit.py b/src/gcache/_internal/cache_hit.py new file mode 100644 index 0000000..217d2d3 --- /dev/null +++ b/src/gcache/_internal/cache_hit.py @@ -0,0 +1,89 @@ +import inspect +from dataclasses import dataclass +from typing import Any + +from gcache._internal.metrics import GCacheMetrics +from gcache._internal.state import _GLOBAL_GCACHE_STATE +from gcache.config import CacheCallContext, CacheHitHook, CacheLayer, EvictAndFallback, GCacheKey, ReturnCached + + +@dataclass(frozen=True, slots=True) +class BypassCurrentLayer: + """ + Internal signal to ignore this layer for the current request only. + + This is used when the hook machinery is unreliable rather than the cached + value itself being known-bad. Two cases currently map here: + - the hook raised an exception + - the hook returned an unsupported decision type + + In those situations we want fail-open behavior: + - do not fail the caller request + - do not evict the cached value, because the problem may be in the hook + - continue through the normal fallback chain as if this layer had no usable hit + """ + + +async def run_cache_hit_hook( + *, + key: GCacheKey, + layer: CacheLayer, + value: Any, + on_cache_hit: CacheHitHook | None, +) -> ReturnCached | EvictAndFallback | BypassCurrentLayer: + """ + Execute the optional cache-hit hook and normalize the result. + + Returning `BypassCurrentLayer` is reserved for hook execution/contract + failures. It lets cache layers skip a hit without deleting it, which keeps + hook bugs from turning into request failures or unnecessary evictions. + """ + if on_cache_hit is None: + return ReturnCached() + + context = CacheCallContext(key=key, layer=layer) + + try: + decision = on_cache_hit(context, value) + if inspect.isawaitable(decision): + decision = await decision + except Exception: + _GLOBAL_GCACHE_STATE.logger.error( + "Error executing cache hit hook", + extra={"use_case": key.use_case, "key_type": key.key_type, "layer": layer.name}, + exc_info=True, + ) + GCacheMetrics.HIT_HOOK_ERROR_COUNTER.labels(key.use_case, key.key_type, layer.name).inc() + return BypassCurrentLayer() + + if isinstance(decision, ReturnCached): + GCacheMetrics.HIT_HOOK_ACTION_COUNTER.labels( + key.use_case, + key.key_type, + layer.name, + "return", + "none", + ).inc() + return decision + + if isinstance(decision, EvictAndFallback): + GCacheMetrics.HIT_HOOK_ACTION_COUNTER.labels( + key.use_case, + key.key_type, + layer.name, + "evict", + decision.reason or "none", + ).inc() + return decision + + _GLOBAL_GCACHE_STATE.logger.error( + "Cache hit hook returned invalid decision type", + extra={ + "use_case": key.use_case, + "key_type": key.key_type, + "layer": layer.name, + "decision_type": type(decision).__name__, + }, + ) + GCacheMetrics.HIT_HOOK_ERROR_COUNTER.labels(key.use_case, key.key_type, layer.name).inc() + return BypassCurrentLayer() diff --git a/src/gcache/_internal/cache_interface.py b/src/gcache/_internal/cache_interface.py index 0f6495e..ff3abc5 100644 --- a/src/gcache/_internal/cache_interface.py +++ b/src/gcache/_internal/cache_interface.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable from typing import Any -from gcache.config import CacheConfigProvider, CacheLayer, GCacheKey, GCacheKeyConfig +from gcache.config import CacheConfigProvider, CacheHitHook, CacheLayer, GCacheKey, GCacheKeyConfig #: Async callable that fetches the actual value on cache miss. #: Invoked by cache implementations when the requested key is not found or is stale. @@ -26,7 +26,13 @@ async def _resolve_config(self, key: GCacheKey) -> GCacheKeyConfig | None: return config @abstractmethod - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: pass @abstractmethod diff --git a/src/gcache/_internal/local_cache.py b/src/gcache/_internal/local_cache.py index 899c968..e123bb7 100644 --- a/src/gcache/_internal/local_cache.py +++ b/src/gcache/_internal/local_cache.py @@ -3,10 +3,11 @@ from cachetools import TTLCache +from gcache._internal.cache_hit import BypassCurrentLayer, run_cache_hit_hook from gcache._internal.cache_interface import CacheInterface, Fallback from gcache._internal.constants import LOCAL_CACHE_MAX_SIZE from gcache._internal.state import _GLOBAL_GCACHE_STATE -from gcache.config import CacheConfigProvider, CacheLayer, GCacheKey +from gcache.config import CacheConfigProvider, CacheHitHook, CacheLayer, EvictAndFallback, GCacheKey, ReturnCached from gcache.exceptions import MissingKeyConfig @@ -44,14 +45,41 @@ async def _get_ttl_cache(self, key: GCacheKey) -> TTLCache: return cache - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def _exec_fallback(self, key: GCacheKey, fallback: Fallback) -> Any: + value = await fallback() + await self.put(key, value) + return value + + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: _GLOBAL_GCACHE_STATE.logger.debug("Calling local cache") cache = await self._get_ttl_cache(key) if key not in cache: - await self.put(key, await fallback()) - - return cache[key] + return await self._exec_fallback(key, fallback) + + cached_value = cache[key] + + decision = await run_cache_hit_hook( + key=key, + layer=self.layer(), + value=cached_value, + on_cache_hit=on_cache_hit, + ) + if isinstance(decision, ReturnCached): + return cached_value + if isinstance(decision, EvictAndFallback): + cache.pop(key, None) + return await self._exec_fallback(key, fallback) + if isinstance(decision, BypassCurrentLayer): + return await fallback() + + return cached_value async def put(self, key: GCacheKey, value: Any) -> None: (await self._get_ttl_cache(key))[key] = value diff --git a/src/gcache/_internal/metrics.py b/src/gcache/_internal/metrics.py index 6e7965d..39755e9 100644 --- a/src/gcache/_internal/metrics.py +++ b/src/gcache/_internal/metrics.py @@ -12,6 +12,8 @@ class GCacheMetrics: REQUEST_COUNTER: Counter ERROR_COUNTER: Counter INVALIDATION_COUNTER: Counter + HIT_HOOK_ACTION_COUNTER: Counter + HIT_HOOK_ERROR_COUNTER: Counter # Histograms GET_TIMER: Histogram @@ -55,6 +57,18 @@ def initialize(cls, prefix: str = "") -> None: documentation="Cache invalidation counter", ) + cls.HIT_HOOK_ACTION_COUNTER = Counter( + name=prefix + "gcache_hit_hook_action_counter", + labelnames=["use_case", "key_type", "layer", "action", "reason"], + documentation="Cache hit hook action counter", + ) + + cls.HIT_HOOK_ERROR_COUNTER = Counter( + name=prefix + "gcache_hit_hook_error_counter", + labelnames=["use_case", "key_type", "layer"], + documentation="Cache hit hook error counter", + ) + cls.GET_TIMER = Histogram( name=prefix + "gcache_get_timer", labelnames=["use_case", "key_type", "layer"], diff --git a/src/gcache/_internal/noop_cache.py b/src/gcache/_internal/noop_cache.py index e8aad59..01164cf 100644 --- a/src/gcache/_internal/noop_cache.py +++ b/src/gcache/_internal/noop_cache.py @@ -1,7 +1,7 @@ from typing import Any from gcache._internal.cache_interface import CacheInterface, Fallback -from gcache.config import CacheLayer, GCacheKey +from gcache.config import CacheHitHook, CacheLayer, GCacheKey class NoopCache(CacheInterface): @@ -9,7 +9,13 @@ class NoopCache(CacheInterface): NOOP Cache that does nothing but invoke fallback on get. """ - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: return await fallback() async def put(self, key: GCacheKey, value: Any) -> None: diff --git a/src/gcache/_internal/redis_cache.py b/src/gcache/_internal/redis_cache.py index da2510f..47f6609 100644 --- a/src/gcache/_internal/redis_cache.py +++ b/src/gcache/_internal/redis_cache.py @@ -9,11 +9,20 @@ from redis.asyncio import Redis, RedisCluster +from gcache._internal.cache_hit import BypassCurrentLayer, run_cache_hit_hook from gcache._internal.cache_interface import CacheInterface, Fallback from gcache._internal.constants import ASYNC_PICKLE_THRESHOLD_BYTES, WATERMARK_TTL_SECONDS from gcache._internal.metrics import GCacheMetrics from gcache._internal.state import _GLOBAL_GCACHE_STATE -from gcache.config import CacheConfigProvider, CacheLayer, GCacheKey, RedisConfig +from gcache.config import ( + CacheConfigProvider, + CacheHitHook, + CacheLayer, + EvictAndFallback, + GCacheKey, + RedisConfig, + ReturnCached, +) from gcache.exceptions import MissingKeyConfig @@ -128,7 +137,13 @@ async def _async_pickle_loads(data: bytes) -> Any: loop = asyncio.get_event_loop() return await loop.run_in_executor(RedisCache._executor, pickle.loads, data) - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: _GLOBAL_GCACHE_STATE.logger.debug("Calling Redis Cache") watermark_ms = None @@ -142,6 +157,9 @@ async def get(self, key: GCacheKey, fallback: Fallback) -> Any: val_pickle = await self.client.get(key.urn) if val_pickle is not None: start_sec = time.monotonic() + serialization_timer = GCacheMetrics.SERIALIZATION_TIMER.labels( + key.use_case, key.key_type, self.layer().name, "load" + ) deserialized_value: RedisValue = ( pickle.loads(val_pickle) @@ -149,22 +167,34 @@ async def get(self, key: GCacheKey, fallback: Fallback) -> Any: else await RedisCache._async_pickle_loads(val_pickle) ) + # Ignore invalidated remote entries before payload deserialization or hook execution. + if watermark_ms is not None: + watermark_ms = int(watermark_ms) + if watermark_ms >= deserialized_value.created_at_ms: + serialization_timer.observe(time.monotonic() - start_sec) + return await self._exec_fallback(key, watermark_ms, fallback) + # Load payload using custom serializer if present. payload = deserialized_value.payload if key.serializer is not None: payload = await key.serializer.load(payload) - ( - GCacheMetrics.SERIALIZATION_TIMER.labels(key.use_case, key.key_type, self.layer().name, "load").observe( - time.monotonic() - start_sec - ) + serialization_timer.observe(time.monotonic() - start_sec) + + decision = await run_cache_hit_hook( + key=key, + layer=self.layer(), + value=payload, + on_cache_hit=on_cache_hit, ) + if isinstance(decision, EvictAndFallback): + await self.delete(key) + return await self._exec_fallback(key, watermark_ms, fallback) + if isinstance(decision, BypassCurrentLayer): + return await fallback() + if not isinstance(decision, ReturnCached): + return await fallback() - # Check if cache val is expired. - if watermark_ms is not None: - watermark_ms = int(watermark_ms) - if watermark_ms >= deserialized_value.created_at_ms: - return await self._exec_fallback(key, watermark_ms, fallback) return payload else: return await self._exec_fallback(key, watermark_ms, fallback) diff --git a/src/gcache/_internal/wrappers.py b/src/gcache/_internal/wrappers.py index 52b2388..5937e9f 100644 --- a/src/gcache/_internal/wrappers.py +++ b/src/gcache/_internal/wrappers.py @@ -6,7 +6,7 @@ from gcache._internal.cache_interface import CacheInterface, Fallback from gcache._internal.metrics import GCacheMetrics from gcache._internal.state import _GLOBAL_GCACHE_STATE, GCacheContext -from gcache.config import CacheConfigProvider, CacheLayer, GCacheKey +from gcache.config import CacheConfigProvider, CacheHitHook, CacheLayer, GCacheKey class CacheWrapper(CacheInterface): @@ -58,7 +58,13 @@ def __init__( super().__init__(cache_config_provider, cache) GCacheMetrics.initialize(metrics_prefix) - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: if await self._should_cache(key): start_time = time.monotonic() fallback_time = 0.0 @@ -84,7 +90,11 @@ async def instrumented_fallback() -> Any: ) try: - return await self.wrapped.get(key, instrumented_fallback) + return await self.wrapped.get( + key, + instrumented_fallback, + on_cache_hit=on_cache_hit, + ) except Exception as e: _GLOBAL_GCACHE_STATE.logger.error(f"Error getting value from cache: {e}", exc_info=True) GCacheMetrics.ERROR_COUNTER.labels( @@ -161,11 +171,17 @@ def __init__( super().__init__(cache_config_provider, cache) self.fallback_cache = fallback_cache - async def get(self, key: GCacheKey, fallback: Fallback) -> Any: + async def get( + self, + key: GCacheKey, + fallback: Fallback, + *, + on_cache_hit: CacheHitHook | None = None, + ) -> Any: async def cache_fallback() -> Any: - return await self.fallback_cache.get(key, fallback) + return await self.fallback_cache.get(key, fallback, on_cache_hit=on_cache_hit) - return await self.wrapped.get(key, cache_fallback) + return await self.wrapped.get(key, cache_fallback, on_cache_hit=on_cache_hit) async def delete(self, key: GCacheKey) -> bool: ret = await self.wrapped.delete(key) diff --git a/src/gcache/config.py b/src/gcache/config.py index 2cf6322..abb37c6 100644 --- a/src/gcache/config.py +++ b/src/gcache/config.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from enum import Enum from logging import Logger, LoggerAdapter -from typing import Any, Union +from typing import Any, TypeAlias, Union from pydantic import BaseModel, ConfigDict, field_validator from redis.asyncio import Redis, RedisCluster @@ -128,6 +128,38 @@ async def load(self, data: bytes | str) -> Any: pass +@dataclass(frozen=True, slots=True) +class CacheCallContext: + """ + Context provided to cache-hit hooks. + + The hook gets the cache key and the layer that produced the hit. + """ + + key: "GCacheKey" + layer: CacheLayer + + +@dataclass(frozen=True, slots=True) +class ReturnCached: + """Return the current cached value unchanged.""" + + +@dataclass(frozen=True, slots=True) +class EvictAndFallback: + """ + Evict the current layer's cached value and continue fallback resolution. + + The reason is intended for low-cardinality logging and metrics. + """ + + reason: str | None = None + + +CacheHitDecision: TypeAlias = ReturnCached | EvictAndFallback +CacheHitHook: TypeAlias = Callable[[CacheCallContext, Any], CacheHitDecision | Awaitable[CacheHitDecision]] + + @dataclass(frozen=True, slots=True) class GCacheKey: key_type: str @@ -210,6 +242,7 @@ class GCacheConfig(BaseModel): metrics_prefix: str = "api_" redis_config: RedisConfig | None = None redis_client_factory: Callable[[], Redis | RedisCluster] | None = None + on_cache_hit: CacheHitHook | None = None logger: Logger | LoggerAdapter | None = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/gcache/gcache.py b/src/gcache/gcache.py index 88507c3..5c2d96d 100644 --- a/src/gcache/gcache.py +++ b/src/gcache/gcache.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable, Generator from contextlib import contextmanager from functools import partial -from typing import Any +from typing import Any, Literal from gcache._internal.event_loop_thread import EventLoopThread, EventLoopThreadPool from gcache._internal.local_cache import LocalCache @@ -15,6 +15,7 @@ from gcache._internal.state import _GLOBAL_GCACHE_STATE, GCacheContext from gcache._internal.wrappers import CacheChain, CacheController, DisabledReasons from gcache.config import ( + CacheHitHook, GCacheConfig, GCacheKey, GCacheKeyConfig, @@ -30,6 +31,13 @@ ) +class _UseGlobalOnCacheHit: + """Private sentinel for decorator hooks that inherit the global config hook.""" + + +_USE_GLOBAL_ON_CACHE_HIT = _UseGlobalOnCacheHit() + + class GCache: """ Main entry point for the GCache caching library. @@ -154,6 +162,7 @@ def cached( track_for_invalidation: bool = False, default_config: GCacheKeyConfig | None = None, serializer: Serializer | None = None, + on_cache_hit: CacheHitHook | Literal[False] | _UseGlobalOnCacheHit = _USE_GLOBAL_ON_CACHE_HIT, ) -> Any: """ Decorator which caches a function which can be either sync or async. @@ -175,6 +184,9 @@ def cached( :param serializer: Optional serializer to use to serialize and deserialize cache values. Care must be taken that the returned value matches the signature of cached function, as otherwise you may get runtime type/attribute errors. + :param on_cache_hit: Optional cache-hit hook for this use case. If omitted, the global + hook from GCacheConfig is used. Pass False to explicitly disable the + global hook for this decorator. :return: """ @@ -205,6 +217,7 @@ def decorator(func: Any) -> Any: adapter_for_key = not isinstance(id_arg, str) id_arg_name = id_arg[0] if adapter_for_key else id_arg + resolved_on_cache_hit = self._resolve_on_cache_hit(on_cache_hit) # If name of id arg is in arg_adapters then we should include it in the cache key args. # Otherwise we should ignore it. @@ -288,7 +301,7 @@ async def async_wrapped(*args: Any, **kwargs: Any) -> Any: async def f(): # type: ignore[no-untyped-def, misc] return func(*args, **kwargs) - return await self._cache.get(key, f) + return await self._cache.get(key, f, on_cache_hit=resolved_on_cache_hit) if inspect.iscoroutinefunction(func): return functools.wraps(func)(async_wrapped) @@ -310,6 +323,17 @@ def sync_wrapped(*args: Any, **kwargs: Any) -> Any: return decorator + def _resolve_on_cache_hit( + self, on_cache_hit: CacheHitHook | Literal[False] | _UseGlobalOnCacheHit + ) -> CacheHitHook | None: + if on_cache_hit is _USE_GLOBAL_ON_CACHE_HIT: + return self.config.on_cache_hit + if on_cache_hit is False: + return None + if callable(on_cache_hit): + return on_cache_hit + raise TypeError("on_cache_hit must be callable, False, or omitted to use the global config hook") + async def ainvalidate(self, key_type: str, id: str, future_buffer_ms: int = 0) -> None: """ Invalidate all cache entries matching the given key type and ID (async version). diff --git a/tests/test_gcache.py b/tests/test_gcache.py index 5ec515d..ee320bd 100644 --- a/tests/test_gcache.py +++ b/tests/test_gcache.py @@ -1,10 +1,11 @@ +import asyncio import json import logging import pickle import threading from collections.abc import Generator from random import random -from typing import Any +from typing import Any, cast import pytest import redislite @@ -12,11 +13,13 @@ from gcache import ( CacheLayer, + EvictAndFallback, GCache, GCacheConfig, GCacheKey, GCacheKeyConfig, RedisConfig, + ReturnCached, ) from gcache._internal.cache_interface import Fallback from gcache._internal.local_cache import LocalCache @@ -41,6 +44,10 @@ def get_func_metric(name: str) -> float: return 0 +def get_local_cache_for_use_case(gcache: GCache, use_case: str) -> LocalCache: + return cast(LocalCache, gcache._local_cache.wrapped) + + def test_gcache_sync(gcache: GCache, redis_server: redislite.Redis, reset_prometheus_registry: Generator) -> None: v: int = 0 @@ -81,6 +88,328 @@ def cached_func(test: int = 123) -> int: assert len(keys) == 2 +def test_global_on_cache_hit_inherited( + cache_config_provider: FakeCacheConfigProvider, + redis_server: redislite.Redis, + reset_global_state: None, +) -> None: + redis_server.flushall() + calls: list[CacheLayer] = [] + + def global_hook(ctx, value): # type: ignore[no-untyped-def] + calls.append(ctx.layer) + return ReturnCached() + + gcache = GCache( + GCacheConfig( + cache_config_provider=cache_config_provider, + urn_prefix="urn:galileo:test", + redis_config=RedisConfig(port=REDIS_PORT), + on_cache_hit=global_hook, + ) + ) + try: + + @gcache.cached(key_type="Test", id_arg="test", use_case="test_global_on_cache_hit_inherited") + def cached_func(test: int = 123) -> int: + return test + + with gcache.enable(): + assert cached_func() == 123 + assert calls == [] + assert cached_func() == 123 + + assert calls == [CacheLayer.LOCAL] + finally: + gcache.__del__() + + +def test_decorator_on_cache_hit_overrides_global( + cache_config_provider: FakeCacheConfigProvider, + redis_server: redislite.Redis, + reset_global_state: None, +) -> None: + redis_server.flushall() + global_calls: list[CacheLayer] = [] + decorator_calls: list[CacheLayer] = [] + + def global_hook(ctx, value): # type: ignore[no-untyped-def] + global_calls.append(ctx.layer) + return ReturnCached() + + def decorator_hook(ctx, value): # type: ignore[no-untyped-def] + decorator_calls.append(ctx.layer) + return ReturnCached() + + gcache = GCache( + GCacheConfig( + cache_config_provider=cache_config_provider, + urn_prefix="urn:galileo:test", + redis_config=RedisConfig(port=REDIS_PORT), + on_cache_hit=global_hook, + ) + ) + try: + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_decorator_on_cache_hit_overrides_global", + on_cache_hit=decorator_hook, + ) + def cached_func(test: int = 123) -> int: + return test + + with gcache.enable(): + assert cached_func() == 123 + assert cached_func() == 123 + + assert global_calls == [] + assert decorator_calls == [CacheLayer.LOCAL] + finally: + gcache.__del__() + + +def test_decorator_can_disable_global_on_cache_hit( + cache_config_provider: FakeCacheConfigProvider, + redis_server: redislite.Redis, + reset_global_state: None, +) -> None: + redis_server.flushall() + global_calls: list[CacheLayer] = [] + + def global_hook(ctx, value): # type: ignore[no-untyped-def] + global_calls.append(ctx.layer) + return ReturnCached() + + gcache = GCache( + GCacheConfig( + cache_config_provider=cache_config_provider, + urn_prefix="urn:galileo:test", + redis_config=RedisConfig(port=REDIS_PORT), + on_cache_hit=global_hook, + ) + ) + try: + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_decorator_can_disable_global_on_cache_hit", + on_cache_hit=False, + ) + def cached_func(test: int = 123) -> int: + return test + + with gcache.enable(): + assert cached_func() == 123 + assert cached_func() == 123 + + assert global_calls == [] + finally: + gcache.__del__() + + +def test_local_on_cache_hit_evicts_and_falls_back_to_remote(gcache: GCache) -> None: + source_calls = 0 + hook_layers: list[CacheLayer] = [] + + def hook(ctx, value): # type: ignore[no-untyped-def] + hook_layers.append(ctx.layer) + if ctx.layer == CacheLayer.LOCAL and value["status"] == "bad": + return EvictAndFallback("bad_local_payload") + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_local_on_cache_hit_evicts_and_falls_back_to_remote", + on_cache_hit=hook, + ) + def cached_func(test: int) -> dict[str, object]: + nonlocal source_calls + source_calls += 1 + return {"status": "good", "source_calls": source_calls} + + with gcache.enable(): + assert cached_func(test=1) == {"status": "good", "source_calls": 1} + + local_cache = get_local_cache_for_use_case(gcache, "test_local_on_cache_hit_evicts_and_falls_back_to_remote") + cached_entries = local_cache.caches["test_local_on_cache_hit_evicts_and_falls_back_to_remote"] + cached_value = next(iter(cached_entries.values())) + cached_value["status"] = "bad" + + assert cached_func(test=1) == {"status": "good", "source_calls": 1} + assert cached_func(test=1) == {"status": "good", "source_calls": 1} + + assert source_calls == 1 + assert hook_layers == [CacheLayer.LOCAL, CacheLayer.REMOTE, CacheLayer.LOCAL] + + +def test_local_on_cache_hit_returns_validated_value_without_reread(gcache: GCache) -> None: + use_case = "test_local_on_cache_hit_returns_validated_value_without_reread" + hook_calls = 0 + + def hook(ctx, value): # type: ignore[no-untyped-def] + nonlocal hook_calls + hook_calls += 1 + if ctx.layer == CacheLayer.LOCAL: + local_cache = get_local_cache_for_use_case(gcache, use_case) + local_cache.caches[use_case][ctx.key] = {"value": "replaced"} + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case=use_case, + on_cache_hit=hook, + ) + def cached_func(test: int) -> dict[str, str]: + return {"value": "cached"} + + with gcache.enable(): + assert cached_func(test=1) == {"value": "cached"} + assert cached_func(test=1) == {"value": "cached"} + assert cached_func(test=1) == {"value": "replaced"} + + assert hook_calls == 2 + + +def test_remote_on_cache_hit_evicts_and_falls_back_to_source( + gcache: GCache, + cache_config_provider: FakeCacheConfigProvider, +) -> None: + source_calls = 0 + remote_hit_enabled = False + + cache_config_provider.configs["test_remote_on_cache_hit_evicts_and_falls_back_to_source"] = GCacheKeyConfig( + ttl_sec={CacheLayer.LOCAL: 60, CacheLayer.REMOTE: 60}, + ramp={CacheLayer.LOCAL: 0, CacheLayer.REMOTE: 100}, + ) + + def hook(ctx, value): # type: ignore[no-untyped-def] + if ctx.layer == CacheLayer.REMOTE and remote_hit_enabled: + return EvictAndFallback("force_remote_refetch") + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_remote_on_cache_hit_evicts_and_falls_back_to_source", + on_cache_hit=hook, + ) + def cached_func(test: int) -> int: + nonlocal source_calls + source_calls += 1 + return source_calls + + with gcache.enable(): + assert cached_func(test=1) == 1 + remote_hit_enabled = True + assert cached_func(test=1) == 2 + remote_hit_enabled = False + assert cached_func(test=1) == 2 + + assert source_calls == 2 + + +def test_on_cache_hit_exception_bypasses_current_layer(gcache: GCache) -> None: + source_calls = 0 + + def hook(ctx, value): # type: ignore[no-untyped-def] + if ctx.layer == CacheLayer.LOCAL: + raise RuntimeError("boom") + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_on_cache_hit_exception_bypasses_current_layer", + on_cache_hit=hook, + ) + def cached_func(test: int) -> dict[str, object]: + nonlocal source_calls + source_calls += 1 + return {"source_calls": source_calls} + + with gcache.enable(): + assert cached_func(test=1) == {"source_calls": 1} + assert cached_func(test=1) == {"source_calls": 1} + + local_cache = get_local_cache_for_use_case(gcache, "test_on_cache_hit_exception_bypasses_current_layer") + assert len(local_cache.caches["test_on_cache_hit_exception_bypasses_current_layer"]) == 1 + + assert source_calls == 1 + + +@pytest.mark.asyncio +async def test_async_on_cache_hit_hook_is_awaited(gcache: GCache) -> None: + source_calls = 0 + hook_layers: list[CacheLayer] = [] + + async def hook(ctx, value): # type: ignore[no-untyped-def] + await asyncio.sleep(0) + hook_layers.append(ctx.layer) + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case="test_async_on_cache_hit_hook_is_awaited", + on_cache_hit=hook, + ) + async def cached_func(test: int) -> int: + nonlocal source_calls + source_calls += 1 + return source_calls + + with gcache.enable(): + assert await cached_func(test=1) == 1 + assert await cached_func(test=1) == 1 + + assert source_calls == 1 + assert hook_layers == [CacheLayer.LOCAL] + + +def test_invalid_on_cache_hit_decision_bypasses_current_layer( + gcache: GCache, + cache_config_provider: FakeCacheConfigProvider, +) -> None: + source_calls = 0 + return_invalid_decision = True + use_case = "test_invalid_on_cache_hit_decision_bypasses_current_layer" + + cache_config_provider.configs[use_case] = GCacheKeyConfig( + ttl_sec={CacheLayer.LOCAL: 60, CacheLayer.REMOTE: 60}, + ramp={CacheLayer.LOCAL: 100, CacheLayer.REMOTE: 0}, + ) + + def hook(ctx, value): # type: ignore[no-untyped-def] + if ctx.layer == CacheLayer.LOCAL and return_invalid_decision: + return object() + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case=use_case, + on_cache_hit=hook, + ) + def cached_func(test: int) -> int: + nonlocal source_calls + source_calls += 1 + return source_calls + + with gcache.enable(): + assert cached_func(test=1) == 1 + assert cached_func(test=1) == 2 + + return_invalid_decision = False + assert cached_func(test=1) == 1 + + assert source_calls == 2 + + @pytest.mark.asyncio async def test_gcache_async( gcache: GCache, redis_server: redislite.Redis, reset_prometheus_registry: Generator @@ -327,7 +656,7 @@ def cached_func(test: int = 123) -> int: def test_error_in_cache(gcache: GCache, cache_config_provider: FakeCacheConfigProvider) -> None: class FailingCache(LocalCache): - async def get(self, key: GCacheKey, fallback: Fallback) -> None: + async def get(self, key: GCacheKey, fallback: Fallback, **kwargs) -> None: # type: ignore[no-untyped-def] raise Exception("I'm giving up!") gcache._cache = CacheController(FailingCache(cache_config_provider), cache_config_provider) # type: ignore[assignment] @@ -452,6 +781,53 @@ async def cached_func(test: int) -> int: assert keys[0] == b"{urn:galileo:test:Test:123}#watermark" +@pytest.mark.asyncio +async def test_invalidation_skips_remote_hook( + gcache: GCache, + cache_config_provider: FakeCacheConfigProvider, +) -> None: + use_case = "test_invalidation_skips_remote_hook" + source_calls = 0 + remote_hook_calls = 0 + + cache_config_provider.configs[use_case] = GCacheKeyConfig( + ttl_sec={CacheLayer.LOCAL: 60, CacheLayer.REMOTE: 60}, + ramp={CacheLayer.LOCAL: 0, CacheLayer.REMOTE: 100}, + ) + + def hook(ctx, value): # type: ignore[no-untyped-def] + nonlocal remote_hook_calls + if ctx.layer == CacheLayer.REMOTE: + remote_hook_calls += 1 + return ReturnCached() + + @gcache.cached( + key_type="Test", + id_arg="test", + use_case=use_case, + track_for_invalidation=True, + on_cache_hit=hook, + ) + async def cached_func(test: int) -> int: + nonlocal source_calls + source_calls += 1 + return source_calls + + with gcache.enable(): + assert await cached_func(123) == 1 + + await gcache.ainvalidate("Test", "123") + await asyncio.sleep(0.01) + + assert await cached_func(123) == 2 + assert remote_hook_calls == 0 + + assert await cached_func(123) == 2 + + assert source_calls == 2 + assert remote_hook_calls == 1 + + @pytest.mark.asyncio async def test_flush_all(gcache: GCache, redis_server: redislite.Redis) -> None: with gcache.enable():