Skip to content

Commit 801385d

Browse files
authored
fix: bound _refresh_locks with LRU eviction to prevent memory leak (#3968)
1 parent 64fbc52 commit 801385d

2 files changed

Lines changed: 49 additions & 4 deletions

File tree

src/fastmcp/server/auth/oauth_proxy/proxy.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import secrets
2323
import time
2424
from base64 import urlsafe_b64encode
25+
from collections import OrderedDict
2526
from typing import Any, Literal
2627
from urllib.parse import urlencode, urlparse, urlunparse
2728

@@ -92,6 +93,8 @@
9293

9394
logger = get_logger(__name__)
9495

96+
_REFRESH_LOCK_CACHE_SIZE = 10_000
97+
9598

9699
def _normalize_resource_url(url: str) -> str:
97100
"""Normalize a resource URL by removing query parameters and trailing slashes.
@@ -576,7 +579,7 @@ def __init__(
576579
# refresh the same token within a single process. Does not protect
577580
# against cross-process races in distributed deployments — those are
578581
# handled by re-reading from storage after refresh failure.
579-
self._refresh_locks: dict[str, anyio.Lock] = {}
582+
self._refresh_locks: OrderedDict[str, anyio.Lock] = OrderedDict()
580583

581584
logger.debug(
582585
"Initialized OAuth proxy provider with upstream server %s",
@@ -647,6 +650,18 @@ def _create_upstream_oauth_client(self) -> AsyncOAuth2Client:
647650
timeout=HTTP_TIMEOUT_SECONDS,
648651
)
649652

653+
def _get_refresh_lock(self, token_id: str) -> anyio.Lock:
654+
"""Get or create a per-token refresh lock, evicting LRU entries when at capacity."""
655+
lock = self._refresh_locks.get(token_id)
656+
if lock is None:
657+
lock = anyio.Lock()
658+
self._refresh_locks[token_id] = lock
659+
if len(self._refresh_locks) > _REFRESH_LOCK_CACHE_SIZE:
660+
self._refresh_locks.popitem(last=False)
661+
else:
662+
self._refresh_locks.move_to_end(token_id)
663+
return lock
664+
650665
# -------------------------------------------------------------------------
651666
# PKCE Helper Methods
652667
# -------------------------------------------------------------------------
@@ -1656,9 +1671,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig
16561671

16571672
# Advisory lock prevents concurrent requests from racing
16581673
# to refresh the same upstream token.
1659-
if token_id not in self._refresh_locks:
1660-
self._refresh_locks[token_id] = anyio.Lock()
1661-
lock = self._refresh_locks[token_id]
1674+
lock = self._get_refresh_lock(token_id)
16621675

16631676
async with lock:
16641677
# Re-read from storage — another task may have

tests/server/auth/oauth_proxy/test_tokens.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,38 @@ async def mock_get(key: str) -> UpstreamTokenSet | None:
907907
assert result is not None
908908
assert result.token == "refreshed-upstream-access"
909909

910+
def test_refresh_lock_cache_bounded(self, proxy, monkeypatch):
911+
"""_get_refresh_lock never grows beyond _REFRESH_LOCK_CACHE_SIZE."""
912+
monkeypatch.setattr(
913+
"fastmcp.server.auth.oauth_proxy.proxy._REFRESH_LOCK_CACHE_SIZE", 3
914+
)
915+
for i in range(10):
916+
proxy._get_refresh_lock(f"token-{i}")
917+
assert len(proxy._refresh_locks) == 3
918+
919+
def test_refresh_lock_lru_evicts_least_recently_used(self, proxy, monkeypatch):
920+
"""Touching an entry promotes it; eviction removes the oldest untouched entry."""
921+
monkeypatch.setattr(
922+
"fastmcp.server.auth.oauth_proxy.proxy._REFRESH_LOCK_CACHE_SIZE", 3
923+
)
924+
proxy._get_refresh_lock("a")
925+
proxy._get_refresh_lock("b")
926+
proxy._get_refresh_lock("c")
927+
# Touch "a" to move it to MRU position
928+
proxy._get_refresh_lock("a")
929+
# Adding "d" should evict "b" (oldest untouched)
930+
proxy._get_refresh_lock("d")
931+
assert "b" not in proxy._refresh_locks
932+
assert "a" in proxy._refresh_locks
933+
assert "c" in proxy._refresh_locks
934+
assert "d" in proxy._refresh_locks
935+
936+
def test_refresh_lock_same_token_returns_same_lock(self, proxy):
937+
"""Requesting the same token ID twice returns the same lock object."""
938+
lock1 = proxy._get_refresh_lock("tok")
939+
lock2 = proxy._get_refresh_lock("tok")
940+
assert lock1 is lock2
941+
910942
async def test_upstream_claims_propagated(self, proxy):
911943
jwt = await self._setup_session_with_claims(
912944
proxy, upstream_claims={"sub": "user-123"}

0 commit comments

Comments
 (0)