Skip to content

Commit 5067b96

Browse files
authored
feat: cache user settings reads with Redis (#863)
feat: cache user settings reads with Redis (#863)
2 parents 96124d5 + 2083ce3 commit 5067b96

6 files changed

Lines changed: 262 additions & 3 deletions

File tree

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Async Redis client singleton for application-level caching."""
2+
3+
import redis.asyncio as redis
4+
5+
from src.workers.broker import REDIS_URL
6+
7+
_client: redis.Redis | None = None
8+
9+
10+
async def get_redis_client() -> redis.Redis:
11+
"""Return a shared async Redis client instance."""
12+
global _client
13+
if _client is None:
14+
_client = redis.from_url(REDIS_URL, decode_responses=True)
15+
return _client

backend/src/repos/user_settings_repo.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,71 @@
1+
import json
12
from datetime import datetime, timezone
23
import uuid
34
from typing import Any, Optional
45

56
from src.repos.base_repo import BaseRepo
67
from src.schemas.entities.settings import PersistedContextFile, ProviderKeyStatus, SandboxType, UserSettings
78
from src.utils.security import encrypt_value, decrypt_value
9+
from src.utils.logger import logger
810
from src.constants import UserTokenKey
911

1012
# Single well-known key for the one settings record per user
1113
_SETTINGS_KEY = "default"
1214

15+
# Redis cache TTL in seconds (5 minutes safety net; active invalidation is primary)
16+
_CACHE_TTL = 300
17+
1318

1419
class UserSettingsRepo(BaseRepo):
1520
def __init__(self, user_id: str, store):
1621
super().__init__(user_id, store, "user_settings")
1722

23+
# ------------------------------------------------------------------
24+
# Cache helpers
25+
# ------------------------------------------------------------------
26+
27+
def _cache_key(self) -> str:
28+
"""Redis key for this user's cached settings."""
29+
return f"user_settings:{self.user_id}"
30+
31+
async def _get_cached(self) -> UserSettings | None:
32+
"""Try to read settings from Redis cache. Returns None on miss or error."""
33+
try:
34+
from src.common.utils.redis_cache import get_redis_client
35+
36+
client = await get_redis_client()
37+
raw = await client.get(self._cache_key())
38+
if raw is not None:
39+
logger.info("settings_cache_hit user_id=%s", self.user_id)
40+
return UserSettings.model_validate(json.loads(raw))
41+
logger.info("settings_cache_miss user_id=%s", self.user_id)
42+
except Exception:
43+
logger.warning("settings_cache_error op=get user_id=%s", self.user_id, exc_info=True)
44+
return None
45+
46+
async def _set_cached(self, settings: UserSettings) -> None:
47+
"""Write settings to Redis cache with TTL."""
48+
try:
49+
from src.common.utils.redis_cache import get_redis_client
50+
51+
client = await get_redis_client()
52+
data = json.dumps(settings.model_dump(exclude_none=True, mode="json"))
53+
await client.set(self._cache_key(), data, ex=_CACHE_TTL)
54+
logger.info("settings_cache_set user_id=%s ttl=%d", self.user_id, _CACHE_TTL)
55+
except Exception:
56+
logger.warning("settings_cache_error op=set user_id=%s", self.user_id, exc_info=True)
57+
58+
async def _invalidate_cache(self) -> None:
59+
"""Delete cached settings from Redis."""
60+
try:
61+
from src.common.utils.redis_cache import get_redis_client
62+
63+
client = await get_redis_client()
64+
await client.delete(self._cache_key())
65+
logger.info("settings_cache_invalidate user_id=%s", self.user_id)
66+
except Exception:
67+
logger.warning("settings_cache_error op=invalidate user_id=%s", self.user_id, exc_info=True)
68+
1869
# ------------------------------------------------------------------
1970
# Helpers
2071
# ------------------------------------------------------------------
@@ -26,17 +77,28 @@ def _validate_provider(self, provider: str) -> None:
2677
raise ValueError(f"Invalid provider '{provider}'. Must be one of: {valid}")
2778

2879
async def _get_or_create(self) -> UserSettings:
29-
"""Return existing settings or create an empty record."""
80+
"""Return existing settings, checking Redis cache first."""
81+
# Check cache
82+
cached = await self._get_cached()
83+
if cached is not None:
84+
return cached
85+
86+
# Cache miss -- hit database
3087
item = await self._get(_SETTINGS_KEY)
3188
if item:
32-
return UserSettings.model_validate(item.value)
89+
settings = UserSettings.model_validate(item.value)
90+
await self._set_cached(settings)
91+
return settings
92+
93+
# First time -- create empty record
3394
settings = UserSettings(
3495
id=str(uuid.uuid4()),
3596
user_id=self.user_id,
3697
created_at=datetime.now(timezone.utc),
3798
updated_at=datetime.now(timezone.utc),
3899
)
39100
await self._set(_SETTINGS_KEY, settings)
101+
await self._set_cached(settings)
40102
return settings
41103

42104
def _decrypt_keys(self, settings: UserSettings) -> dict[str, str]:
@@ -107,6 +169,7 @@ async def set_default_model(self, model: Optional[str]) -> UserSettings:
107169
settings.default_model = model
108170
settings.updated_at = datetime.now(timezone.utc)
109171
await self._set(_SETTINGS_KEY, settings)
172+
await self._invalidate_cache()
110173
return settings
111174

112175
async def set_default_sandbox(self, sandbox: Optional[str]) -> UserSettings:
@@ -118,6 +181,7 @@ async def set_default_sandbox(self, sandbox: Optional[str]) -> UserSettings:
118181
settings.default_sandbox = sandbox
119182
settings.updated_at = datetime.now(timezone.utc)
120183
await self._set(_SETTINGS_KEY, settings)
184+
await self._invalidate_cache()
121185
return settings
122186

123187
# Mapping: PATCH request key -> entity field name
@@ -150,6 +214,7 @@ async def patch_defaults(self, data: dict) -> UserSettings:
150214
setattr(settings, field, value)
151215
settings.updated_at = datetime.now(timezone.utc)
152216
await self._set(_SETTINGS_KEY, settings)
217+
await self._invalidate_cache()
153218
return settings
154219

155220
async def upsert_provider_key(self, provider: str, api_key: str) -> UserSettings:
@@ -160,6 +225,7 @@ async def upsert_provider_key(self, provider: str, api_key: str) -> UserSettings
160225
settings.encrypted_keys = self._encrypt_keys(keys)
161226
settings.updated_at = datetime.now(timezone.utc)
162227
await self._set(_SETTINGS_KEY, settings)
228+
await self._invalidate_cache()
163229
return settings
164230

165231
async def delete_provider_key(self, provider: str) -> UserSettings:
@@ -170,6 +236,7 @@ async def delete_provider_key(self, provider: str) -> UserSettings:
170236
settings.encrypted_keys = self._encrypt_keys(keys)
171237
settings.updated_at = datetime.now(timezone.utc)
172238
await self._set(_SETTINGS_KEY, settings)
239+
await self._invalidate_cache()
173240
return settings
174241

175242
async def get_all_decrypted_keys(self) -> dict[str, str]:
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Tests for Redis cache layer in UserSettingsRepo."""
2+
3+
import json
4+
import unittest
5+
from unittest.mock import AsyncMock, patch
6+
7+
import fakeredis.aioredis
8+
from langgraph.store.memory import InMemoryStore
9+
10+
from src.repos.user_settings_repo import UserSettingsRepo, _CACHE_TTL
11+
12+
TEST_USER_ID = "cache-test-user-001"
13+
14+
15+
def _mock_encrypt(value: dict) -> str:
16+
return "ENC:" + json.dumps(value, sort_keys=True)
17+
18+
19+
def _mock_decrypt(value: str) -> dict:
20+
if not value.startswith("ENC:"):
21+
raise ValueError("Bad ciphertext")
22+
return json.loads(value[4:])
23+
24+
25+
@patch("src.repos.user_settings_repo.encrypt_value", side_effect=_mock_encrypt)
26+
@patch("src.repos.user_settings_repo.decrypt_value", side_effect=_mock_decrypt)
27+
class TestUserSettingsCache(unittest.IsolatedAsyncioTestCase):
28+
"""Tests for the Redis cache behaviour in UserSettingsRepo."""
29+
30+
async def asyncSetUp(self):
31+
self.store = InMemoryStore()
32+
self.redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
33+
self.repo = UserSettingsRepo(user_id=TEST_USER_ID, store=self.store)
34+
# Patch get_redis_client to return our fake Redis
35+
self._redis_patch = patch(
36+
"src.common.utils.redis_cache.get_redis_client",
37+
return_value=self.redis,
38+
)
39+
self._redis_patch.start()
40+
41+
async def asyncTearDown(self):
42+
self._redis_patch.stop()
43+
await self.redis.aclose()
44+
45+
# ------------------------------------------------------------------
46+
# Cache population
47+
# ------------------------------------------------------------------
48+
49+
async def test_first_call_populates_cache(self, _dec, _enc):
50+
"""First _get_or_create hits DB and writes to Redis cache."""
51+
settings = await self.repo._get_or_create()
52+
cached_raw = await self.redis.get(self.repo._cache_key())
53+
self.assertIsNotNone(cached_raw)
54+
cached = json.loads(cached_raw)
55+
self.assertEqual(cached["user_id"], TEST_USER_ID)
56+
self.assertEqual(cached["id"], settings.id)
57+
58+
async def test_second_call_uses_cache(self, _dec, _enc):
59+
"""Second _get_or_create returns from cache without hitting DB."""
60+
await self.repo._get_or_create()
61+
62+
# Patch DB access to verify it's not called
63+
with patch.object(self.repo, "_get", new_callable=AsyncMock) as mock_get:
64+
settings = await self.repo._get_or_create()
65+
mock_get.assert_not_called()
66+
self.assertEqual(settings.user_id, TEST_USER_ID)
67+
68+
async def test_cache_ttl_is_set(self, _dec, _enc):
69+
"""Cached entry has a TTL set."""
70+
await self.repo._get_or_create()
71+
ttl = await self.redis.ttl(self.repo._cache_key())
72+
self.assertGreater(ttl, 0)
73+
self.assertLessEqual(ttl, _CACHE_TTL)
74+
75+
# ------------------------------------------------------------------
76+
# Cache invalidation on mutations
77+
# ------------------------------------------------------------------
78+
79+
async def test_set_default_model_invalidates(self, _dec, _enc):
80+
"""set_default_model removes the cached entry."""
81+
await self.repo._get_or_create()
82+
self.assertIsNotNone(await self.redis.get(self.repo._cache_key()))
83+
await self.repo.set_default_model("openai/gpt-4")
84+
self.assertIsNone(await self.redis.get(self.repo._cache_key()))
85+
86+
async def test_set_default_sandbox_invalidates(self, _dec, _enc):
87+
"""set_default_sandbox removes the cached entry."""
88+
await self.repo._get_or_create()
89+
await self.repo.set_default_sandbox("daytona")
90+
self.assertIsNone(await self.redis.get(self.repo._cache_key()))
91+
92+
async def test_patch_defaults_invalidates(self, _dec, _enc):
93+
"""patch_defaults removes the cached entry."""
94+
await self.repo._get_or_create()
95+
await self.repo.patch_defaults({"model": "test-model"})
96+
self.assertIsNone(await self.redis.get(self.repo._cache_key()))
97+
98+
async def test_upsert_provider_key_invalidates(self, _dec, _enc):
99+
"""upsert_provider_key removes the cached entry."""
100+
await self.repo._get_or_create()
101+
await self.repo.upsert_provider_key("OPENAI_API_KEY", "sk-test")
102+
self.assertIsNone(await self.redis.get(self.repo._cache_key()))
103+
104+
async def test_delete_provider_key_invalidates(self, _dec, _enc):
105+
"""delete_provider_key removes the cached entry."""
106+
await self.repo._get_or_create()
107+
await self.repo.delete_provider_key("OPENAI_API_KEY")
108+
self.assertIsNone(await self.redis.get(self.repo._cache_key()))
109+
110+
# ------------------------------------------------------------------
111+
# Graceful degradation
112+
# ------------------------------------------------------------------
113+
114+
async def test_redis_failure_falls_through_to_db(self, _dec, _enc):
115+
"""When Redis raises an error, _get_or_create still works via DB."""
116+
# Make Redis raise on all operations
117+
self._redis_patch.stop()
118+
broken_redis = AsyncMock()
119+
broken_redis.get = AsyncMock(side_effect=ConnectionError("Redis down"))
120+
broken_redis.set = AsyncMock(side_effect=ConnectionError("Redis down"))
121+
broken_redis.delete = AsyncMock(side_effect=ConnectionError("Redis down"))
122+
self._redis_patch = patch(
123+
"src.common.utils.redis_cache.get_redis_client",
124+
return_value=broken_redis,
125+
)
126+
self._redis_patch.start()
127+
128+
# Should still work -- falls through to DB
129+
settings = await self.repo._get_or_create()
130+
self.assertEqual(settings.user_id, TEST_USER_ID)
131+
132+
# Mutations should also work without Redis
133+
await self.repo.set_default_model("test-model")
134+
settings, _ = await self.repo.get_settings()
135+
self.assertEqual(settings.default_model, "test-model")
136+
137+
138+
if __name__ == "__main__":
139+
unittest.main()

backend/tests/unit/repos/test_user_settings_repo.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ class TestUserSettingsRepo(unittest.IsolatedAsyncioTestCase):
3333
async def asyncSetUp(self):
3434
self.store = InMemoryStore()
3535
self.repo = UserSettingsRepo(user_id=TEST_USER_ID, store=self.store)
36+
# Disable Redis cache for unit tests -- cache tests live in test_user_settings_cache.py
37+
self._cache_patches = [
38+
patch.object(self.repo, "_get_cached", return_value=None),
39+
patch.object(self.repo, "_set_cached", return_value=None),
40+
patch.object(self.repo, "_invalidate_cache", return_value=None),
41+
]
42+
for p in self._cache_patches:
43+
p.start()
44+
45+
async def asyncTearDown(self):
46+
for p in self._cache_patches:
47+
p.stop()
3648

3749
# ------------------------------------------------------------------
3850
# get_settings

backend/tests/unit/routes/test_settings_routes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
from typing import AsyncGenerator
5-
from unittest.mock import MagicMock, patch
5+
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import pytest
88
from httpx import ASGITransport, AsyncClient
@@ -75,6 +75,12 @@ async def override_auth():
7575
}
7676
)
7777

78+
# Disable Redis cache so tests use only the InMemoryStore
79+
fake_redis = AsyncMock()
80+
fake_redis.get = AsyncMock(return_value=None)
81+
fake_redis.set = AsyncMock(return_value=None)
82+
fake_redis.delete = AsyncMock(return_value=None)
83+
7884
with (
7985
patch(
8086
"src.repos.user_settings_repo.encrypt_value",
@@ -84,6 +90,10 @@ async def override_auth():
8490
"src.repos.user_settings_repo.decrypt_value",
8591
side_effect=_mock_decrypt,
8692
),
93+
patch(
94+
"src.common.utils.redis_cache.get_redis_client",
95+
return_value=fake_redis,
96+
),
8797
):
8898
transport = ASGITransport(app=app)
8999
async with AsyncClient(transport=transport, base_url="http://test") as client:

backend/tests/unit/services/test_context_files.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1+
from unittest.mock import AsyncMock, patch
2+
13
from langgraph.store.memory import InMemoryStore
24

35
import pytest
46

57
from src.repos.user_settings_repo import UserSettingsRepo
68
from src.services.context_files import resolve_context_files, select_memory_sources
79

10+
# Disable Redis cache for all tests in this module
11+
_fake_redis = AsyncMock()
12+
_fake_redis.get = AsyncMock(return_value=None)
13+
_fake_redis.set = AsyncMock(return_value=None)
14+
_fake_redis.delete = AsyncMock(return_value=None)
15+
16+
pytestmark = pytest.mark.usefixtures("_disable_redis_cache")
17+
18+
19+
@pytest.fixture(autouse=True)
20+
def _disable_redis_cache():
21+
with patch("src.common.utils.redis_cache.get_redis_client", return_value=_fake_redis):
22+
yield
23+
824

925
@pytest.mark.asyncio
1026
async def test_settings_override_memory_by_path():

0 commit comments

Comments
 (0)