Skip to content

Commit 6850047

Browse files
committed
Add disk cache
1 parent daa0eb7 commit 6850047

4 files changed

Lines changed: 90 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"pydantic[email]>=2.11.7",
1616
"pyperclip>=1.9.0",
1717
"openapi-core>=0.19.5",
18+
"diskcache>=5.6.3",
1819
]
1920

2021
requires-python = ">=3.10"

src/fastmcp/server/middleware/caching.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
from datetime import datetime, timedelta, timezone
44
from typing import Any, ClassVar, Protocol
55

6+
from diskcache import Cache as DiskCacheClient
67
from mcp.types import CallToolRequestParams, ContentBlock
78
from pydantic import BaseModel, ConfigDict
8-
from typing_extensions import Self
9+
from typing_extensions import Self, overload, runtime_checkable
910

1011
from fastmcp.server.middleware.middleware import CallNext, Middleware, MiddlewareContext
1112
from fastmcp.tools.tool import ToolResult
1213

1314

15+
ONE_HOUR_IN_SECONDS = 3600
16+
ONE_GB_IN_BYTES = 1024 * 1024 * 1024
17+
18+
1419
class CacheEntry(BaseModel):
1520
"""A cache entry."""
1621

@@ -48,18 +53,49 @@ def from_tool_result(cls, key: str, value: ToolResult, ttl: int) -> Self:
4853
)
4954

5055

56+
@runtime_checkable
5157
class CacheProtocol(Protocol):
5258
"""A protocol for a cache client."""
5359

54-
async def get(self, key: str) -> ToolResult | None: ...
60+
async def get(self, key: str) -> ToolResult | None:
61+
"""Get a value from the cache."""
62+
63+
async def set(self, key: str, value: ToolResult, ttl: int) -> None:
64+
"""Set a value in the cache."""
65+
66+
async def delete(self, key: str) -> None:
67+
"""Delete a value from the cache."""
68+
69+
70+
class DiskCache(CacheProtocol):
71+
"""A caching client that uses the DiskCache library to cache to disk."""
72+
73+
@overload
74+
def __init__(self, disk_cache: DiskCacheClient):
75+
"""Initialize the disk cache with a diskcache client."""
5576

56-
async def set(self, key: str, value: ToolResult, ttl: int) -> None: ...
77+
@overload
78+
def __init__(self, path: str, size_limit: int = ONE_GB_IN_BYTES):
79+
"""Initialize a 1GB disk cache at the provided path."""
5780

58-
async def delete(self, key: str) -> None: ...
81+
def __init__(
82+
self,
83+
disk_cache: DiskCacheClient | None = None,
84+
path: str | None = None,
85+
size_limit: int = ONE_GB_IN_BYTES,
86+
):
87+
self._cache = disk_cache or DiskCacheClient(
88+
directory=path, size_limit=size_limit
89+
)
5990

60-
async def setup(self) -> None: ...
91+
async def get(self, key: str) -> ToolResult | None:
92+
return self._cache.get(key)
6193

62-
async def clear(self) -> None: ...
94+
async def set(self, key: str, value: ToolResult, ttl: int) -> None:
95+
self._cache.set(key, value, expire=ttl)
96+
97+
async def delete(self, key: str) -> None:
98+
self._cache.delete(key)
6399

64100

65101
class InMemoryCache(CacheProtocol):
@@ -121,14 +157,24 @@ class ResponseCachingMiddleware(Middleware):
121157

122158
def __init__(
123159
self,
124-
cache_backend: CacheProtocol,
160+
cache_backend: CacheProtocol | None = None,
125161
included_tools: list[str] | None = None,
126162
excluded_tools: list[str] | None = None,
127-
default_ttl: int = 3600,
163+
default_ttl: int = ONE_HOUR_IN_SECONDS,
128164
):
165+
"""Initialize the response caching middleware.
166+
167+
Args:
168+
cache_backend: The cache backend to use. If None, an in-memory cache is used.
169+
included_tools: The tools to cache responses from. If None, all tools are cached.
170+
excluded_tools: The tools to not cache responses from. If None, no tools are excluded.
171+
default_ttl: The default TTL for cached responses. Defaults to one hour.
172+
"""
129173
self._default_ttl = default_ttl
130-
self._backend = cache_backend
174+
self._backend = cache_backend or InMemoryCache()
175+
131176
self._stats = CacheStats(hits=0, misses=0)
177+
132178
self._included_tools = included_tools
133179
self._excluded_tools = excluded_tools
134180

tests/server/middleware/test_caching.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for response caching middleware."""
22

3+
import tempfile
34
from datetime import datetime, timedelta, timezone
45
from typing import Any
56
from unittest.mock import AsyncMock, MagicMock
@@ -14,6 +15,7 @@
1415
from fastmcp.server.middleware.caching import (
1516
CacheEntry,
1617
CacheStats,
18+
DiskCache,
1719
InMemoryCache,
1820
ResponseCachingMiddleware,
1921
)
@@ -296,13 +298,30 @@ class TestResponseCachingMiddlewareIntegration:
296298
"""Integration tests with real FastMCP server."""
297299

298300
@pytest.fixture
299-
def caching_server(self, tracking_calculator: TrackingCalculator):
301+
async def disk_cache(self):
302+
with tempfile.TemporaryDirectory() as temp_dir:
303+
yield DiskCache(path=temp_dir)
304+
305+
@pytest.fixture
306+
async def in_memory_cache(self):
307+
return InMemoryCache()
308+
309+
@pytest.fixture(params=["memory", "disk"])
310+
async def caching_server(
311+
self,
312+
tracking_calculator: TrackingCalculator,
313+
request,
314+
disk_cache,
315+
in_memory_cache,
316+
):
300317
"""Create a FastMCP server for caching tests."""
301318
mcp = FastMCP("CachingTestServer")
302319

303-
mcp.add_middleware(
304-
middleware=ResponseCachingMiddleware(cache_backend=InMemoryCache())
305-
)
320+
cache = disk_cache if request.param == "disk" else in_memory_cache
321+
322+
response_caching_middleware = ResponseCachingMiddleware(cache_backend=cache)
323+
324+
mcp.add_middleware(middleware=response_caching_middleware)
306325

307326
tracking_calculator.add_tools(mcp)
308327

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)