Skip to content

Commit 55e0886

Browse files
committed
feat: Implement response caching enhancement (Task 2)
Cache validation results to reduce redundant LLM calls. Expected savings: 20-30% cost reduction.
1 parent cf03c51 commit 55e0886

4 files changed

Lines changed: 422 additions & 13 deletions

File tree

backend/api/routers/chat_router.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2698,19 +2698,53 @@ async def _handle_validation_with_fallback(
26982698
context_quality = context.get("context_quality", None)
26992699
avg_similarity = context.get("avg_similarity_score", None)
27002700

2701-
# Run validation with context quality info
2702-
# Tier 3.5: Pass context quality, is_philosophical, and is_religion_roleplay to ValidatorChain
2703-
# CRITICAL: Pass context dict to enable foundational knowledge detection in CitationRequired
2704-
validation_result = chain.run(
2705-
raw_response,
2706-
ctx_docs,
2707-
context_quality=context_quality,
2708-
avg_similarity=avg_similarity,
2709-
is_philosophical=is_philosophical,
2710-
is_religion_roleplay=is_religion_roleplay,
2711-
user_question=chat_request.message, # Pass user question for FactualHallucinationValidator
2712-
context=context # Pass context dict for foundational knowledge detection
2713-
)
2701+
# Task 2: Response Caching Enhancement - Cache validation results
2702+
# Check cache before running expensive validation chain
2703+
try:
2704+
from backend.utils.cache_utils import get_cache_key, get_from_cache, set_to_cache
2705+
2706+
# Generate cache key from query and context
2707+
cache_key = get_cache_key("validation", chat_request.message, context)
2708+
2709+
# Check cache
2710+
cached_validation = get_from_cache(cache_key)
2711+
if cached_validation is not None:
2712+
logger.info(f"✅ Validation cache HIT for query: {chat_request.message[:50]}...")
2713+
validation_result = cached_validation
2714+
else:
2715+
# Cache miss - run validation
2716+
logger.debug(f"⏳ Validation cache MISS, running validation for: {chat_request.message[:50]}...")
2717+
2718+
# Run validation with context quality info
2719+
# Tier 3.5: Pass context quality, is_philosophical, and is_religion_roleplay to ValidatorChain
2720+
# CRITICAL: Pass context dict to enable foundational knowledge detection in CitationRequired
2721+
validation_result = chain.run(
2722+
raw_response,
2723+
ctx_docs,
2724+
context_quality=context_quality,
2725+
avg_similarity=avg_similarity,
2726+
is_philosophical=is_philosophical,
2727+
is_religion_roleplay=is_religion_roleplay,
2728+
user_question=chat_request.message, # Pass user question for FactualHallucinationValidator
2729+
context=context # Pass context dict for foundational knowledge detection
2730+
)
2731+
2732+
# Cache result (TTL: 1 hour)
2733+
set_to_cache(cache_key, validation_result, ttl=3600)
2734+
logger.debug(f"💾 Cached validation result (TTL: 3600s)")
2735+
except Exception as cache_error:
2736+
# If caching fails, just run validation normally
2737+
logger.warning(f"⚠️ Cache error, running validation without cache: {cache_error}")
2738+
validation_result = chain.run(
2739+
raw_response,
2740+
ctx_docs,
2741+
context_quality=context_quality,
2742+
avg_similarity=avg_similarity,
2743+
is_philosophical=is_philosophical,
2744+
is_religion_roleplay=is_religion_roleplay,
2745+
user_question=chat_request.message,
2746+
context=context
2747+
)
27142748

27152749
# Tier 3.5: If context quality is low, inject warning into prompt for next iteration
27162750
# For now, we'll handle this in the prompt building phase

backend/utils/cache_decorators.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Cache Decorators for StillMe
3+
4+
Provides decorators for caching expensive operations like validation.
5+
"""
6+
7+
import logging
8+
from typing import Callable, Any, Optional
9+
from functools import wraps
10+
11+
from backend.utils.cache_utils import get_cache_key, get_from_cache, set_to_cache
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
def cache_validation_result(ttl: int = 3600):
    """
    Decorator to cache validation results

    Caches the result of an async validation function. The cache key is
    built from the query, the RAG context and, when the wrapped function
    receives a ``response`` argument, a digest of that response, so that a
    verdict computed for one response is never returned for another one.
    On a cache hit the wrapped function is not executed.

    ``query``, ``context`` and ``response`` are looked up by parameter name,
    whether the caller passed them positionally or as keywords.

    Args:
        ttl: Time to live in seconds (default: 1 hour)

    Usage:
        @cache_validation_result(ttl=3600)
        async def validate_response(response: str, query: str, context: Dict):
            # Validation logic
            return validation_result
    """
    # Function-scope imports keep this decorator self-contained.
    import hashlib
    import inspect

    def decorator(func: Callable) -> Callable:
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no signature; fall back to kwargs only.
            signature = None

        def _named_arguments(args, kwargs):
            """Map the call's arguments to parameter names (best effort)."""
            named = dict(kwargs)
            if signature is not None:
                try:
                    named.update(signature.bind_partial(*args, **kwargs).arguments)
                except TypeError:
                    # Invalid call: let func itself raise the real error later.
                    pass
            return named

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Resolve query/context/response regardless of how they were passed.
            # Reading kwargs only would give every positional call the same key.
            named = _named_arguments(args, kwargs)
            query = named.get('query') or ''
            context = named.get('context') or {}
            response = named.get('response')

            # Generate cache key
            cache_key = get_cache_key("validation", query, context)
            if isinstance(response, str) and response:
                # The verdict depends on the response text, so it must be
                # part of the key.
                digest = hashlib.md5(response.encode('utf-8')).hexdigest()[:16]
                cache_key = f"{cache_key}:{digest}"

            # Check cache
            cached_result = get_from_cache(cache_key)
            if cached_result is not None:
                logger.info(f"✅ Validation cache HIT for query: {query[:50]}...")
                return cached_result

            # Execute function
            logger.debug(f"⏳ Validation cache MISS, executing validation for: {query[:50]}...")
            result = await func(*args, **kwargs)

            # Cache result
            set_to_cache(cache_key, result, ttl=ttl)
            logger.debug(f"💾 Cached validation result (TTL: {ttl}s)")

            return result

        return wrapper
    return decorator
60+
61+
62+
def cache_expensive_operation(prefix: str, ttl: int = 3600, key_func: Optional[Callable] = None):
    """
    Generic decorator to cache expensive operations

    Wraps an async function so that its result is stored under a cache key
    and returned from the cache on later calls with the same key.

    Default key (when ``key_func`` is not given): the first positional
    argument, or the ``query`` keyword argument. Any further positional or
    keyword arguments are folded into the key as well, so two calls that
    differ only in those arguments do not share a cached result. When
    neither a positional argument nor ``query`` is present the call is not
    cached at all.

    Args:
        prefix: Cache key prefix
        ttl: Time to live in seconds
        key_func: Optional function to generate cache key from args/kwargs

    Usage:
        @cache_expensive_operation("llm_response", ttl=1800)
        async def generate_response(query: str):
            # Expensive operation
            return result
    """
    def decorator(func: Callable) -> Callable:
        def _default_key(args, kwargs):
            """Return the default cache key, or None when none can be derived."""
            if args:
                base = str(args[0])
                extra_args = args[1:]
                extra_kwargs = kwargs
            elif 'query' in kwargs:
                base = kwargs['query']
                extra_args = ()
                extra_kwargs = {k: v for k, v in kwargs.items() if k != 'query'}
            else:
                return None

            if not extra_args and not extra_kwargs:
                # Single-argument call: same key as before.
                return get_cache_key(prefix, base)

            # Keyword names are unique, so sorting never compares the values.
            extras = repr((extra_args, sorted(extra_kwargs.items())))
            return get_cache_key(prefix, f"{base}|{extras}")

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Generate cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = _default_key(args, kwargs)
                if cache_key is None:
                    # No cache key possible, skip caching
                    return await func(*args, **kwargs)

            # Check cache (a cached None is indistinguishable from a miss)
            cached_result = get_from_cache(cache_key)
            if cached_result is not None:
                logger.debug(f"Cache HIT: {cache_key[:50]}...")
                return cached_result

            # Execute function
            result = await func(*args, **kwargs)

            # Cache result
            set_to_cache(cache_key, result, ttl=ttl)

            return result

        return wrapper
    return decorator
109+

backend/utils/cache_utils.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
Cache Utilities for StillMe
3+
4+
Provides utilities for caching validation results and other expensive operations.
5+
Uses Redis if available, falls back to in-memory cache.
6+
"""
7+
8+
import logging
9+
import hashlib
10+
import json
11+
from typing import Dict, Any, Optional
12+
from functools import wraps
13+
14+
logger = logging.getLogger(__name__)
15+
16+
# Try to import Redis cache
17+
try:
18+
from backend.services.redis_cache import get_cache_service
19+
REDIS_AVAILABLE = True
20+
except ImportError:
21+
REDIS_AVAILABLE = False
22+
get_cache_service = None
23+
24+
# In-memory fallback cache
25+
_in_memory_cache: Dict[str, Dict[str, Any]] = {}
26+
27+
28+
def hash_query(query: str) -> str:
    """
    Generate hash for query (normalized)

    The query is lowercased and stripped of surrounding whitespace before
    hashing, so queries that differ only in case or padding share a hash.
    A query that is empty, None, or whitespace-only yields the fixed
    marker "empty".

    Args:
        query: User query text

    Returns:
        "empty" for a blank query, otherwise the first 16 hex characters of
        the MD5 digest of the normalized query
    """
    # Normalize first, so that a whitespace-only query is treated like an
    # empty one instead of being hashed as the empty string.
    normalized = query.lower().strip() if query else ""
    if not normalized:
        return "empty"

    return hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]
44+
45+
46+
def hash_context(context: Dict[str, Any]) -> str:
    """
    Generate hash for context (document IDs + similarities)

    Each entry of ``context["knowledge_docs"]`` contributes one
    (id, similarity) pair. The pairs are sorted, so the hash does not depend
    on the order of the documents, but it does depend on which similarity
    belongs to which document.

    Args:
        context: RAG context dictionary

    Returns:
        "no_context" when context is empty or None, "no_docs" when it has no
        "knowledge_docs", otherwise the first 16 hex characters of an MD5
        digest
    """
    if not context:
        return "no_context"

    # Extract document IDs and similarities
    docs = context.get("knowledge_docs", [])
    if not docs:
        return "no_docs"

    pairs = []
    for doc in docs:
        doc_id = doc.get("id")
        similarity = doc.get("similarity")
        pairs.append((
            # Keep falsy-but-real ids such as 0; only a missing id maps to "".
            "" if doc_id is None else str(doc_id),
            # A missing or None similarity counts as 0.0 instead of making
            # round() raise.
            round(0.0 if similarity is None else similarity, 3),
        ))

    # Sort the pairs as units for consistent hashing. Sorting ids and
    # similarities separately would make different pairings collide.
    pairs.sort()

    context_str = json.dumps({"docs": pairs}, sort_keys=True)
    return hashlib.md5(context_str.encode('utf-8')).hexdigest()[:16]
71+
72+
73+
def get_cache_key(prefix: str, query: str, context: Optional[Dict[str, Any]] = None) -> str:
    """
    Build the cache key for a query / context pair

    The key has the form "<prefix>:<query hash>:<context hash>".

    Args:
        prefix: Cache key prefix (e.g., "validation")
        query: User query
        context: Optional RAG context

    Returns:
        Cache key string
    """
    query_part = hash_query(query)

    # An absent or empty context gets a fixed marker instead of a hash.
    if context:
        context_part = hash_context(context)
    else:
        context_part = "no_context"

    return "{}:{}:{}".format(prefix, query_part, context_part)
88+
89+
90+
def get_from_cache(cache_key: str) -> Optional[Any]:
    """
    Get value from cache (Redis or in-memory)

    Redis is consulted first when it is available. Any Redis error, or a
    Redis miss, falls through to the in-memory fallback cache. Expired
    in-memory entries are removed when they are encountered.

    Args:
        cache_key: Cache key

    Returns:
        Cached value or None if not found
    """
    import time

    # Try Redis first
    if REDIS_AVAILABLE:
        try:
            cache_service = get_cache_service()
            if cache_service:
                cached = cache_service.get(cache_key)
                # Compare against None so that falsy cached values
                # (0, "", [], False, ...) still count as hits.
                # NOTE(review): assumes the cache service returns None on a
                # miss - confirm against backend.services.redis_cache.
                if cached is not None:
                    logger.debug(f"Cache HIT (Redis): {cache_key[:50]}...")
                    return cached
        except Exception as e:
            logger.debug(f"Redis cache error (falling back to memory): {e}")

    # Fallback to in-memory
    cached_data = _in_memory_cache.get(cache_key)
    if cached_data is not None:
        # Check TTL (simple implementation)
        if time.time() < cached_data.get("expires_at", 0):
            logger.debug(f"Cache HIT (Memory): {cache_key[:50]}...")
            return cached_data.get("value")
        # Expired, remove it (pop tolerates the key having gone already)
        _in_memory_cache.pop(cache_key, None)

    logger.debug(f"Cache MISS: {cache_key[:50]}...")
    return None
126+
127+
128+
def set_to_cache(cache_key: str, value: Any, ttl: int = 3600) -> None:
    """
    Set value to cache (Redis or in-memory)

    The value is written to Redis when it is available. If Redis is
    unavailable or the write raises, the value is stored in the in-memory
    fallback cache together with its expiry time. Each in-memory write also
    drops entries that have already expired, so the fallback cache does not
    grow without bound.

    Args:
        cache_key: Cache key
        value: Value to cache
        ttl: Time to live in seconds (default: 1 hour)
    """
    # Try Redis first
    if REDIS_AVAILABLE:
        try:
            cache_service = get_cache_service()
            if cache_service:
                cache_service.set(cache_key, value, ttl=ttl)
                logger.debug(f"Cached (Redis): {cache_key[:50]}... (TTL: {ttl}s)")
                return
        except Exception as e:
            logger.debug(f"Redis cache error (falling back to memory): {e}")

    # Fallback to in-memory
    import time
    now = time.time()

    # Reads only evict the key being read, so entries that are never
    # requested again would otherwise stay in memory forever.
    expired_keys = [
        key for key, entry in list(_in_memory_cache.items())
        if entry.get("expires_at", 0) <= now
    ]
    for key in expired_keys:
        _in_memory_cache.pop(key, None)

    _in_memory_cache[cache_key] = {
        "value": value,
        "expires_at": now + ttl
    }
    logger.debug(f"Cached (Memory): {cache_key[:50]}... (TTL: {ttl}s)")
155+
156+
157+
def clear_cache(cache_key: Optional[str] = None) -> None:
    """
    Clear cache (specific key or all)

    Clearing is best effort: a Redis failure is logged and does not prevent
    the in-memory fallback cache from being cleared.

    Args:
        cache_key: Specific key to clear, or None to clear all
    """
    # Only None means "clear everything". A truthiness test would make an
    # empty-string key wipe the whole cache.
    clear_all = cache_key is None

    if REDIS_AVAILABLE:
        try:
            cache_service = get_cache_service()
            if cache_service:
                if clear_all:
                    cache_service.clear()
                else:
                    cache_service.delete(cache_key)
        except Exception as e:
            # Log instead of silently swallowing, but keep going.
            logger.warning(f"Redis cache clear failed: {e}")

    if clear_all:
        _in_memory_cache.clear()
    else:
        _in_memory_cache.pop(cache_key, None)
187+
188+
189+
def get_cache_stats() -> Dict[str, Any]:
    """
    Report the current state of the caches

    Always includes whether Redis is available, the number of in-memory
    entries and up to ten in-memory keys. When Redis is available, the stats
    returned by the cache service are merged in. A failure while fetching
    them is reported under "redis_error".

    Returns:
        Dictionary with cache stats
    """
    report = {
        "redis_available": REDIS_AVAILABLE,
        "in_memory_size": len(_in_memory_cache),
        "in_memory_keys": list(_in_memory_cache.keys())[:10]  # First 10 keys
    }

    # Nothing more to collect without Redis.
    if not REDIS_AVAILABLE:
        return report

    try:
        service = get_cache_service()
        if service:
            report.update(service.get_stats())
    except Exception as exc:
        report["redis_error"] = str(exc)

    return report
212+

0 commit comments

Comments
 (0)