Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions mcpgateway/cache/resource_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@

# Standard
import asyncio
from collections import OrderedDict
from dataclasses import dataclass
import time
from typing import Any, Dict, Optional
from typing import Any, Optional

# First-Party
from mcpgateway.services.logging_service import LoggingService
Expand All @@ -54,7 +55,6 @@ class CacheEntry:

value: Any
expires_at: float
last_access: float


class ResourceCache:
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(self, max_size: int = 1000, ttl: int = 3600):
"""
self.max_size = max_size
self.ttl = ttl
self._cache: Dict[str, CacheEntry] = {}
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._lock = asyncio.Lock()

async def initialize(self) -> None:
Expand Down Expand Up @@ -156,8 +156,9 @@ def get(self, key: str) -> Optional[Any]:
del self._cache[key]
return None

# Update access time
entry.last_access = now
entry.expires_at = now + self.ttl
self._cache.move_to_end(key)

return entry.value

def set(self, key: str, value: Any) -> None:
Expand All @@ -175,16 +176,13 @@ def set(self, key: str, value: Any) -> None:
>>> cache.get('a')
1
"""
now = time.time()

# Check size limit
if len(self._cache) >= self.max_size:
# Remove least recently used
lru_key = min(self._cache.keys(), key=lambda k: self._cache[k].last_access)
del self._cache[lru_key]
if key in self._cache:
self._cache.move_to_end(key)
elif len(self._cache) >= self.max_size:
self._cache.popitem(last=False)

# Add new entry
self._cache[key] = CacheEntry(value=value, expires_at=now + self.ttl, last_access=now)
self._cache[key] = CacheEntry(value=value, expires_at=time.time() + self.ttl)

def delete(self, key: str) -> None:
"""
Expand Down Expand Up @@ -234,3 +232,22 @@ async def _cleanup_loop(self) -> None:
logger.error(f"Cache cleanup error: {e}")

await asyncio.sleep(60) # Run every minute

def __len__(self) -> int:
"""
Get the number of entries in cache.

Args:
None

Returns:
int: Number of entries in cache

Examples:
>>> from mcpgateway.cache.resource_cache import ResourceCache
>>> cache = ResourceCache(max_size=2, ttl=1)
>>> cache.set('a', 1)
>>> len(cache)
1
"""
return len(self._cache)
16 changes: 14 additions & 2 deletions tests/unit/mcpgateway/cache/test_resource_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_set_and_get(cache):
"""Test setting and getting a cache value."""
cache.set("foo", "bar")
assert cache.get("foo") == "bar"
assert len(cache) == 1


def test_get_missing(cache):
Expand All @@ -45,18 +46,24 @@ def test_expiration():
time.sleep(0.15)

# Entry should definitely be expired now
assert len(fast_cache) == 1
assert fast_cache.get("foo") is None
# Entry should be deleted following get operation
assert len(fast_cache) == 0


def test_lru_eviction(cache):
"""Test LRU eviction when max_size is reached."""
cache.set("a", 1)
cache.set("b", 2)
cache.set("c", 3)
# Access 'a' to update its last_access
# Access 'a' to update its position in the ordered cache
assert len(cache) == cache.max_size
assert cache.get("a") == 1
# Add another entry, should evict 'b' (least recently used)
# Add another entry, should evict 'b' (least recently used) and keep cache length
cache.set("d", 4)
assert len(cache) == cache.max_size

assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3
Expand All @@ -66,17 +73,21 @@ def test_lru_eviction(cache):
def test_delete(cache):
"""Test deleting a cache entry."""
cache.set("foo", "bar")
assert len(cache) == 1
cache.delete("foo")
assert len(cache) == 0
assert cache.get("foo") is None


def test_clear(cache):
"""Test clearing the cache."""
cache.set("foo", "bar")
cache.set("baz", "qux")
assert len(cache) == 2
cache.clear()
assert cache.get("foo") is None
assert cache.get("baz") is None
assert len(cache) == 0


@pytest.mark.asyncio
Expand All @@ -88,6 +99,7 @@ async def test_initialize_and_shutdown_logs(monkeypatch):
cache.set("foo", "bar")
await cache.shutdown()
assert cache.get("foo") is None
assert len(cache) == 0


@pytest.mark.asyncio
Expand Down
Loading