|
22 | 22 | import secrets |
23 | 23 | import time |
24 | 24 | from base64 import urlsafe_b64encode |
| 25 | +from collections import OrderedDict |
25 | 26 | from typing import Any, Literal |
26 | 27 | from urllib.parse import urlencode, urlparse, urlunparse |
27 | 28 |
|
|
92 | 93 |
|
93 | 94 | logger = get_logger(__name__) |
94 | 95 |
|
| 96 | +_REFRESH_LOCK_CACHE_SIZE = 10_000 |
| 97 | + |
95 | 98 |
|
96 | 99 | def _normalize_resource_url(url: str) -> str: |
97 | 100 | """Normalize a resource URL by removing query parameters and trailing slashes. |
@@ -576,7 +579,7 @@ def __init__( |
576 | 579 | # refresh the same token within a single process. Does not protect |
577 | 580 | # against cross-process races in distributed deployments — those are |
578 | 581 | # handled by re-reading from storage after refresh failure. |
579 | | - self._refresh_locks: dict[str, anyio.Lock] = {} |
| 582 | + self._refresh_locks: OrderedDict[str, anyio.Lock] = OrderedDict() |
580 | 583 |
|
581 | 584 | logger.debug( |
582 | 585 | "Initialized OAuth proxy provider with upstream server %s", |
@@ -647,6 +650,18 @@ def _create_upstream_oauth_client(self) -> AsyncOAuth2Client: |
647 | 650 | timeout=HTTP_TIMEOUT_SECONDS, |
648 | 651 | ) |
649 | 652 |
|
| 653 | + def _get_refresh_lock(self, token_id: str) -> anyio.Lock: |
| 654 | + """Get or create a per-token refresh lock, evicting LRU entries when at capacity.""" |
| 655 | + lock = self._refresh_locks.get(token_id) |
| 656 | + if lock is None: |
| 657 | + lock = anyio.Lock() |
| 658 | + self._refresh_locks[token_id] = lock |
| 659 | + if len(self._refresh_locks) > _REFRESH_LOCK_CACHE_SIZE: |
| 660 | + self._refresh_locks.popitem(last=False) |
| 661 | + else: |
| 662 | + self._refresh_locks.move_to_end(token_id) |
| 663 | + return lock |
| 664 | + |
650 | 665 | # ------------------------------------------------------------------------- |
651 | 666 | # PKCE Helper Methods |
652 | 667 | # ------------------------------------------------------------------------- |
@@ -1656,9 +1671,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig |
1656 | 1671 |
|
1657 | 1672 | # Advisory lock prevents concurrent requests from racing |
1658 | 1673 | # to refresh the same upstream token. |
1659 | | - if token_id not in self._refresh_locks: |
1660 | | - self._refresh_locks[token_id] = anyio.Lock() |
1661 | | - lock = self._refresh_locks[token_id] |
| 1674 | + lock = self._get_refresh_lock(token_id) |
1662 | 1675 |
|
1663 | 1676 | async with lock: |
1664 | 1677 | # Re-read from storage — another task may have |
|
0 commit comments