|
3 | 3 | from datetime import datetime, timedelta, timezone |
4 | 4 | from typing import Any, ClassVar, Protocol |
5 | 5 |
|
| 6 | +from diskcache import Cache as DiskCacheClient |
6 | 7 | from mcp.types import CallToolRequestParams, ContentBlock |
7 | 8 | from pydantic import BaseModel, ConfigDict |
8 | | -from typing_extensions import Self |
| 9 | +from typing_extensions import Self, overload, runtime_checkable |
9 | 10 |
|
10 | 11 | from fastmcp.server.middleware.middleware import CallNext, Middleware, MiddlewareContext |
11 | 12 | from fastmcp.tools.tool import ToolResult |
12 | 13 |
|
13 | 14 |
|
| 15 | +ONE_HOUR_IN_SECONDS = 3600 |
| 16 | +ONE_GB_IN_BYTES = 1024 * 1024 * 1024 |
| 17 | + |
| 18 | + |
14 | 19 | class CacheEntry(BaseModel): |
15 | 20 | """A cache entry.""" |
16 | 21 |
|
@@ -48,18 +53,49 @@ def from_tool_result(cls, key: str, value: ToolResult, ttl: int) -> Self: |
48 | 53 | ) |
49 | 54 |
|
50 | 55 |
|
| 56 | +@runtime_checkable |
51 | 57 | class CacheProtocol(Protocol): |
52 | 58 | """A protocol for a cache client.""" |
53 | 59 |
|
54 | | - async def get(self, key: str) -> ToolResult | None: ... |
| 60 | + async def get(self, key: str) -> ToolResult | None: |
| 61 | + """Get a value from the cache.""" |
| 62 | + |
| 63 | + async def set(self, key: str, value: ToolResult, ttl: int) -> None: |
| 64 | + """Set a value in the cache.""" |
| 65 | + |
| 66 | + async def delete(self, key: str) -> None: |
| 67 | + """Delete a value from the cache.""" |
| 68 | + |
| 69 | + |
| 70 | +class DiskCache(CacheProtocol): |
| 71 | + """A caching client that uses the DiskCache library to cache to disk.""" |
| 72 | + |
| 73 | + @overload |
| 74 | + def __init__(self, disk_cache: DiskCacheClient): |
| 75 | + """Initialize the disk cache with a diskcache client.""" |
55 | 76 |
|
56 | | - async def set(self, key: str, value: ToolResult, ttl: int) -> None: ... |
| 77 | + @overload |
| 78 | + def __init__(self, path: str, size_limit: int = ONE_GB_IN_BYTES): |
| 79 | + """Initialize a 1GB disk cache at the provided path.""" |
57 | 80 |
|
58 | | - async def delete(self, key: str) -> None: ... |
| 81 | + def __init__( |
| 82 | + self, |
| 83 | + disk_cache: DiskCacheClient | None = None, |
| 84 | + path: str | None = None, |
| 85 | + size_limit: int = ONE_GB_IN_BYTES, |
| 86 | + ): |
| 87 | + self._cache = disk_cache or DiskCacheClient( |
| 88 | + directory=path, size_limit=size_limit |
| 89 | + ) |
59 | 90 |
|
60 | | - async def setup(self) -> None: ... |
| 91 | + async def get(self, key: str) -> ToolResult | None: |
| 92 | + return self._cache.get(key) |
61 | 93 |
|
62 | | - async def clear(self) -> None: ... |
| 94 | + async def set(self, key: str, value: ToolResult, ttl: int) -> None: |
| 95 | + self._cache.set(key, value, expire=ttl) |
| 96 | + |
| 97 | + async def delete(self, key: str) -> None: |
| 98 | + self._cache.delete(key) |
63 | 99 |
|
64 | 100 |
|
65 | 101 | class InMemoryCache(CacheProtocol): |
@@ -121,14 +157,24 @@ class ResponseCachingMiddleware(Middleware): |
121 | 157 |
|
122 | 158 | def __init__( |
123 | 159 | self, |
124 | | - cache_backend: CacheProtocol, |
| 160 | + cache_backend: CacheProtocol | None = None, |
125 | 161 | included_tools: list[str] | None = None, |
126 | 162 | excluded_tools: list[str] | None = None, |
127 | | - default_ttl: int = 3600, |
| 163 | + default_ttl: int = ONE_HOUR_IN_SECONDS, |
128 | 164 | ): |
| 165 | + """Initialize the response caching middleware. |
| 166 | +
|
| 167 | + Args: |
| 168 | + cache_backend: The cache backend to use. If None, an in-memory cache is used. |
| 169 | + included_tools: The tools to cache responses from. If None, all tools are cached. |
| 170 | + excluded_tools: The tools to not cache responses from. If None, no tools are excluded. |
| 171 | + default_ttl: The default TTL for cached responses. Defaults to one hour. |
| 172 | + """ |
129 | 173 | self._default_ttl = default_ttl |
130 | | - self._backend = cache_backend |
| 174 | + self._backend = cache_backend or InMemoryCache() |
| 175 | + |
131 | 176 | self._stats = CacheStats(hits=0, misses=0) |
| 177 | + |
132 | 178 | self._included_tools = included_tools |
133 | 179 | self._excluded_tools = excluded_tools |
134 | 180 |
|
|
0 commit comments