Skip to content

Latest commit

 

History

History
288 lines (212 loc) · 8.44 KB

File metadata and controls

288 lines (212 loc) · 8.44 KB

Cookbook: Cache Invalidation Strategies

Patterns for keeping your RAG cache fresh and accurate.

Invalidation Strategies Overview

| Strategy       | Use Case          | Complexity | Freshness |
|----------------|-------------------|------------|-----------|
| TTL-based      | General purpose   | Low        | Good      |
| Event-driven   | Real-time updates | Medium     | Excellent |
| Version-based  | Schema changes    | Low        | Good      |
| Semantic drift | Content changes   | High       | Excellent |

1. TTL-Based Invalidation

The simplest approach — entries expire after a set time.

# Set TTL when storing
await cache.store(query, response, model)
client.expire(key, 3600)  # 1 hour

# Sliding window - extend on hit
async def lookup_with_sliding_ttl(query: str, ttl_seconds: int = 3600):
    """Look up *query* in the cache; on a hit, push the entry's TTL out again.

    Frequently-accessed entries therefore stay cached (sliding window)
    while unused ones expire naturally.

    Args:
        query: The user query to look up.
        ttl_seconds: TTL (in seconds) re-applied on every hit. Defaults to
            1 hour, matching the original hard-coded value.

    Returns:
        The cache lookup result (``result.hit`` indicates a hit).
    """
    result = await cache.lookup(query)
    if result.hit:
        # Reset the countdown so hot entries keep living.
        client.expire(result.key, ttl_seconds)
    return result

Tiered TTL

Different TTLs for different content types:

# Per-content-type TTLs: the more volatile the content, the shorter the TTL.
TTL_CONFIG = {
    "static_faq": 86400 * 7,    # 1 week
    "dynamic_content": 3600,    # 1 hour
    "real_time": 60,            # 1 minute
    "user_specific": 1800,      # 30 minutes
}

# Fallback for content types not listed in TTL_CONFIG (1 hour).
DEFAULT_TTL = 3600


async def store_with_tiered_ttl(query: str, response: str, content_type: str):
    """Store a query/response pair and apply the TTL tier for *content_type*.

    Unknown content types fall back to ``DEFAULT_TTL``.

    Returns:
        The cache key of the stored entry (the original discarded it), so
        callers can track or tag the entry.
    """
    key = await cache.store(query, response)
    client.expire(key, TTL_CONFIG.get(content_type, DEFAULT_TTL))
    return key

2. Event-Driven Invalidation

Invalidate cache when source data changes.

Publisher Pattern

# When source document is updated
async def on_document_updated(doc_id: str):
    """Invalidate all cache entries related to this document.

    Deletes every cache key that embeds *doc_id*, then publishes an
    invalidation event so other processes/subscribers can react too.
    """
    # SCAN instead of KEYS: KEYS blocks the server while it walks the entire
    # keyspace, which is unsafe on a production Valkey/Redis instance.
    pattern = f"cache:*:{doc_id}:*"
    keys = list(client.scan_iter(match=pattern))

    if keys:
        client.delete(*keys)
        print(f"Invalidated {len(keys)} cache entries for {doc_id}")

    # Publish event for subscribers
    client.publish("cache:invalidate", json.dumps({
        "type": "document_updated",
        "doc_id": doc_id,
        "timestamp": time.time(),
    }))

Subscriber Pattern

import asyncio

async def cache_invalidation_listener():
    """Subscribe to the invalidation channel and apply each event."""

    pubsub = client.pubsub()
    pubsub.subscribe("cache:invalidate")

    async for message in pubsub.listen():
        # Pub/sub also delivers subscribe confirmations; only real
        # messages carry a payload we care about.
        if message["type"] != "message":
            continue

        event = json.loads(message["data"])
        kind = event["type"]

        if kind == "document_updated":
            await invalidate_by_document(event["doc_id"])
        elif kind == "model_updated":
            await invalidate_by_model(event["model"])
        elif kind == "full_flush":
            await cache.clear()

3. Version-Based Invalidation

Invalidate by incrementing version in cache keys.

class VersionedCache:
    """Cache wrapper whose keys embed a version number.

    Bumping the version makes every previously stored entry unreachable
    (logical invalidation); the stale entries are then reclaimed by TTL.
    """

    def __init__(self, cache=None):
        """
        Args:
            cache: The underlying semantic cache. The original never set
                ``self.cache`` at all, so ``lookup``/``store`` raised
                ``AttributeError``; accepting it here (optionally, to stay
                backward compatible) fixes that.
        """
        if cache is not None:
            self.cache = cache
        self.version_key = "cache:version"
        # Load the persisted version instead of always resetting to 1 —
        # resetting would resurrect entries invalidated by other processes.
        stored = client.get(self.version_key)
        self.version = int(stored) if stored else 1

    def _get_versioned_prefix(self) -> str:
        """Key prefix of the current cache generation, e.g. ``cache:v3``."""
        return f"cache:v{self.version}"

    async def lookup(self, query: str):
        """Look up *query* under the current version prefix."""
        versioned_query = f"{self._get_versioned_prefix()}:{query}"
        return await self.cache.lookup(versioned_query)

    async def store(self, query: str, response: str):
        """Store the pair under the current version prefix."""
        versioned_query = f"{self._get_versioned_prefix()}:{query}"
        return await self.cache.store(versioned_query, response)

    async def invalidate_all(self):
        """Increment the version, logically invalidating every entry."""
        self.version += 1
        client.set(self.version_key, self.version)
        print(f"Cache version incremented to {self.version}")
        # Old entries will naturally expire via TTL

4. Semantic Drift Detection

Detect when cached responses may be stale.

class SemanticDriftDetector:
    """Flags cached documents whose content has drifted semantically.

    Keeps one reference embedding per document id; new content is compared
    against it, and a cosine-similarity drop beyond the threshold counts
    as drift.
    """

    def __init__(self, drift_threshold: float = 0.1):
        self.drift_threshold = drift_threshold
        self.reference_embeddings = {}

    async def check_drift(self, doc_id: str, new_content: str) -> bool:
        """Return True when *new_content* has drifted from the stored reference."""

        reference = self.reference_embeddings.get(doc_id)
        if reference is None:
            # No baseline to compare against — treat as drifted so the
            # caller invalidates and establishes one.
            return True

        # Embed the candidate content and measure distance to the baseline.
        candidate = await embeddings.embed(new_content)
        drift = 1 - cosine_similarity(candidate, reference)

        if drift <= self.drift_threshold:
            return False

        print(f"Semantic drift detected for {doc_id}: {drift:.3f}")
        return True

    async def update_reference(self, doc_id: str, content: str):
        """Replace the stored reference embedding for *doc_id*."""
        self.reference_embeddings[doc_id] = await embeddings.embed(content)

    async def on_content_change(self, doc_id: str, new_content: str):
        """On a content change: invalidate cache and refresh the baseline if drifted."""

        if not await self.check_drift(doc_id, new_content):
            return

        # Drop stale cache entries, then record the new baseline.
        await invalidate_by_document(doc_id)
        await self.update_reference(doc_id, new_content)

5. Tag-Based Invalidation

Invalidate by tags/categories.

class TaggedCache:
    """Cache with tag-based invalidation.

    Each stored entry's key is recorded in a per-tag set
    (``cache:tags:<tag>``), so a whole tag can be invalidated at once.
    """

    async def store(self, query: str, response: str, tags: list[str]):
        """Store an entry and index its key under every tag.

        Returns:
            The cache key (the original discarded it), so callers can
            track the entry.
        """
        key = await self.cache.store(query, response)

        # Track key by tags
        for tag in tags:
            client.sadd(f"cache:tags:{tag}", key)
        return key

    async def invalidate_by_tag(self, tag: str):
        """Invalidate all entries with this tag."""

        tag_key = f"cache:tags:{tag}"
        keys = client.smembers(tag_key)

        if keys:
            client.delete(*keys)
            client.delete(tag_key)
            print(f"Invalidated {len(keys)} entries with tag '{tag}'")

    async def invalidate_by_tags(self, tags: list[str]):
        """Invalidate all entries with ANY of these tags."""

        all_keys = set()
        for tag in tags:
            all_keys.update(client.smembers(f"cache:tags:{tag}"))

        if all_keys:
            client.delete(*all_keys)
            print(f"Invalidated {len(all_keys)} entries")

        # Also drop the tag index sets themselves; the original leaked them
        # here while the single-tag path above cleaned up.
        for tag in tags:
            client.delete(f"cache:tags:{tag}")

6. LRU + Size-Based Eviction

Configure Valkey memory policy for automatic eviction.

# valkey.conf
maxmemory 2gb
maxmemory-policy allkeys-lru  # Evict least recently used

Options:

  • volatile-lru: LRU among keys with TTL
  • allkeys-lru: LRU among all keys
  • volatile-lfu: LFU among keys with TTL
  • allkeys-lfu: LFU among all keys
  • volatile-random: Random among keys with TTL
  • allkeys-random: Random among all keys

7. Scheduled Invalidation

Refresh cache on a schedule.

import asyncio
import time
from datetime import datetime, timedelta

class ScheduledInvalidator:
    """Periodically refreshes cache entries older than a threshold."""

    def __init__(self, cache, refresh_interval_hours: int = 24):
        """
        Args:
            cache: The semantic cache to refresh.
            refresh_interval_hours: Sleep between refresh passes; also the
                age beyond which an entry counts as stale.
        """
        self.cache = cache
        self.interval = timedelta(hours=refresh_interval_hours)

    async def run(self):
        """Run the refresh loop forever; cancel the task to stop it."""
        while True:
            await asyncio.sleep(self.interval.total_seconds())
            await self.refresh_stale_entries()

    async def refresh_stale_entries(self):
        """Find entries older than the interval and regenerate them."""

        threshold = time.time() - self.interval.total_seconds()

        # Entries whose created_at timestamp predates the threshold.
        results = client.ft(index_name).search(
            Query(f"@created_at:[-inf {threshold}]")
            .return_fields("query")
        )

        for doc in results.docs:
            # Re-generate the answer and overwrite the cached entry.
            query = doc.query
            new_response = await llm.generate(query)
            # Fixed: the original wrote through a module-level `cache`
            # instead of the instance's own cache stored in __init__.
            await self.cache.store(query, new_response)

        print(f"Refreshed {len(results.docs)} stale entries")

Best Practices

  1. Combine Strategies: Use TTL as baseline + event-driven for important updates
  2. Monitor Staleness: Track cache age and freshness metrics
  3. Graceful Degradation: On cache miss, serve from LLM and update cache
  4. Audit Trail: Log invalidation events for debugging
  5. Test Invalidation: Include invalidation in integration tests

Next Steps