
Commit 9174ccf

feat: async encoder cache impl
1 parent b03fe26 commit 9174ccf

File tree

6 files changed: +454 -65 lines changed

components/src/dynamo/common/memory/encoder_cache_manager.py

Lines changed: 37 additions & 29 deletions
@@ -9,10 +9,10 @@
 
 Usage:
     cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)  # 4GB
-
+
     # Store embedding
     cache.set("abc123", embedding_tensor)
-
+
     # Retrieve embedding
     tensor = cache.get("abc123")  # Returns None if not found
 """
@@ -29,10 +29,10 @@
 class EncoderCacheManager:
     """
     LRU cache for encoder embeddings.
-
+
     Stores tensors keyed by content hash with automatic eviction
     when capacity is exceeded.
-
+
     Thread Safety:
         This class is NOT thread-safe. It is designed to run within a single
         thread (e.g., an asyncio event loop). All access must be from the same
@@ -43,55 +43,59 @@ class EncoderCacheManager:
     def __init__(self, capacity_bytes: int):
         """
         Initialize the encoder cache.
-
+
         Args:
             capacity_bytes: Maximum cache capacity in bytes.
         """
         if capacity_bytes <= 0:
             raise ValueError("capacity_bytes must be positive")
-
+
         self._cache: OrderedDict[str, torch.Tensor] = OrderedDict()
         self._capacity_bytes = capacity_bytes
         self._current_bytes = 0
-
+
         # Stats
         self._hits = 0
         self._misses = 0
-
-        logger.info(f"EncoderCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB")
+
+        logger.info(
+            f"EncoderCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB"
+        )
 
     @staticmethod
     def _tensor_size(tensor: torch.Tensor) -> int:
         """Calculate tensor size in bytes.
-
+
         Args:
             tensor: Must be a contiguous tensor.
-
+
         Returns:
             Size of the tensor in bytes.
-
+
         Raises:
             AssertionError: If tensor is not contiguous.
         """
-        assert tensor.is_contiguous(), "Tensor must be contiguous for accurate size calculation"
+        assert (
+            tensor.is_contiguous()
+        ), "Tensor must be contiguous for accurate size calculation"
         return tensor.element_size() * tensor.numel()
 
     def get(self, key: str) -> Optional[torch.Tensor]:
         """
         Get a tensor from the cache.
-
+
         If found, the entry is moved to the end (most recently used).
-
+
         Args:
             key: Cache key (typically content hash).
-
+
         Returns:
             The cached tensor, or None if not found.
         """
         if key not in self._cache:
             self._misses += 1
             return None
-
+
         # Move to end (most recently used)
         self._cache.move_to_end(key)
         self._hits += 1
@@ -100,44 +104,46 @@ def get(self, key: str) -> Optional[torch.Tensor]:
     def set(self, key: str, tensor: torch.Tensor) -> bool:
         """
         Store a tensor in the cache.
-
+
         If the key already exists, the old value is replaced.
         If adding the tensor would exceed capacity, LRU entries are evicted.
         If the tensor itself is larger than capacity, it is not stored.
-
+
         Args:
             key: Cache key (typically content hash).
             tensor: Tensor to cache.
-
+
         Returns:
             True if the tensor was stored, False if it was too large.
         """
         size = self._tensor_size(tensor)
-
+
         # Don't cache if single tensor exceeds capacity
        if size > self._capacity_bytes:
             logger.warning(
                 f"Tensor too large to cache: {size / 1024**2:.1f}MB > "
                 f"{self._capacity_bytes / 1024**3:.2f}GB capacity"
             )
             return False
-
+
         # If key exists, remove old entry first
         if key in self._cache:
             old_tensor = self._cache.pop(key)
             self._current_bytes -= self._tensor_size(old_tensor)
-
+
         # Evict LRU entries until we have space
         while self._current_bytes + size > self._capacity_bytes and self._cache:
             evicted_key, evicted_tensor = self._cache.popitem(last=False)
             evicted_size = self._tensor_size(evicted_tensor)
             self._current_bytes -= evicted_size
-            logger.debug(f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB")
-
+            logger.debug(
+                f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB"
+            )
+
         # Store new entry
         self._cache[key] = tensor
         self._current_bytes += size
-
+
         logger.debug(
             f"Cached key={key[:16] if len(key) > 16 else key}, "
             f"size={size / 1024**2:.2f}MB, "
@@ -149,19 +155,21 @@ def set(self, key: str, tensor: torch.Tensor) -> bool:
     def stats(self) -> dict:
         """
         Get cache statistics.
-
+
         Returns:
             Dictionary with cache stats including entries, memory usage,
             hit/miss counts, and hit rate.
         """
         total_requests = self._hits + self._misses
         hit_rate = self._hits / total_requests if total_requests > 0 else 0.0
-
+
         return {
             "entries": len(self._cache),
             "current_bytes": self._current_bytes,
             "capacity_bytes": self._capacity_bytes,
-            "utilization": self._current_bytes / self._capacity_bytes if self._capacity_bytes > 0 else 0,
+            "utilization": self._current_bytes / self._capacity_bytes
+            if self._capacity_bytes > 0
+            else 0,
             "hits": self._hits,
             "misses": self._misses,
             "hit_rate": hit_rate,
components/src/dynamo/common/multimodal/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Multimodal utilities for Dynamo components."""
+
+from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
+
+__all__ = ["AsyncEncoderCache"]
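
Since the package __init__ re-exports the class, both import forms below resolve to the same object (a small sketch assuming this commit's layout):

from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache

# Equivalent, via the re-export above:
from dynamo.common.multimodal import AsyncEncoderCache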
components/src/dynamo/common/multimodal/async_encoder_cache.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Async Encoder Cache
+
+Async wrapper over EncoderCacheManager with request coalescing.
+Prevents duplicate encoding when multiple requests arrive for the same content.
+
+Usage:
+    cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)
+    async_cache = AsyncEncoderCache(cache)
+
+    # Get from cache or compute with coalescing
+    tensor = await async_cache.get_or_compute("hash123", encoder.encode)
+"""
+
+import asyncio
+import logging
+from typing import Awaitable, Callable, Dict, Optional
+
+import torch
+
+from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
+
+logger = logging.getLogger(__name__)
+
+
+def _suppress_unhandled_future_exception(future: asyncio.Future) -> None:
+    """
+    Callback to prevent 'Future exception was never retrieved' warning.
+
+    When a Future has set_exception() called but no one awaits it (e.g., a single
+    caller that gets the exception via re-raise), asyncio warns. This callback
+    retrieves the exception to suppress that warning.
+    """
+    if future.done() and not future.cancelled():
+        try:
+            future.exception()
+        except asyncio.CancelledError:
+            pass
+
+
+class AsyncEncoderCache:
+    """
+    Async wrapper with request coalescing over EncoderCacheManager.
+
+    Provides async get_or_compute that deduplicates concurrent requests
+    for the same key, ensuring only one encoding runs at a time per key.
+
+    Thread Safety:
+        This class is NOT thread-safe. It is designed to run within a single
+        asyncio event loop. All access must be from the same thread.
+    """
+
+    def __init__(self, cache: EncoderCacheManager):
+        """
+        Initialize the async encoder cache.
+
+        Args:
+            cache: Underlying EncoderCacheManager for storage.
+        """
+        self._cache = cache
+        self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}
+
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        """
+        Synchronous get from underlying cache.
+
+        Args:
+            key: Cache key.
+
+        Returns:
+            Cached tensor or None if not found.
+        """
+        return self._cache.get(key)
+
+    async def get_or_compute(
+        self,
+        key: str,
+        compute_fn: Callable[[], Awaitable[torch.Tensor]],
+    ) -> torch.Tensor:
+        """
+        Get from cache or compute with request coalescing.
+
+        If the key is in cache, returns immediately.
+        If another coroutine is already computing this key, waits for that result.
+        Otherwise, computes and caches the result.
+
+        Args:
+            key: Cache key (typically content hash).
+            compute_fn: Async function to compute the tensor if not cached.
+
+        Returns:
+            The cached or computed tensor.
+
+        Raises:
+            Exception: Re-raises any exception from compute_fn.
+        """
+        # Check cache first
+        cached = self._cache.get(key)
+        if cached is not None:
+            return cached
+
+        # Wait if already in-flight
+        if key in self._in_flight:
+            logger.debug(f"Waiting for in-flight computation: key={key[:16]}...")
+            return await self._in_flight[key]
+
+        # Compute with coalescing
+        future: asyncio.Future[torch.Tensor] = asyncio.Future()
+        future.add_done_callback(_suppress_unhandled_future_exception)
+        self._in_flight[key] = future
+        try:
+            tensor = await compute_fn()
+            self._cache.set(key, tensor)
+            future.set_result(tensor)
+            return tensor
+        except Exception as e:
+            future.set_exception(e)
+            raise
+        finally:
+            del self._in_flight[key]
+
+    @property
+    def stats(self) -> dict:
+        """
+        Get cache statistics from underlying cache.
+
+        Returns:
+            Dictionary with cache stats.
+        """
+        base_stats = self._cache.stats
+        base_stats["in_flight"] = len(self._in_flight)
+        return base_stats
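
To see the coalescing in action, a self-contained sketch: three concurrent get_or_compute() calls for one key share a single in-flight Future, so the encoder runs once. slow_encode and its call counter are illustrative, not part of this commit.

import asyncio

import torch

from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.multimodal import AsyncEncoderCache

calls = 0

async def slow_encode() -> torch.Tensor:
    global calls
    calls += 1
    await asyncio.sleep(0.1)  # stand-in for real encoder latency
    return torch.ones(4, 8)

async def main() -> None:
    async_cache = AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024**2))

    # All three awaiters resolve from the same Future; compute_fn runs once.
    results = await asyncio.gather(
        *(async_cache.get_or_compute("hash123", slow_encode) for _ in range(3))
    )
    assert calls == 1
    assert all(torch.equal(r, results[0]) for r in results)

    # Subsequent calls are plain cache hits.
    await async_cache.get_or_compute("hash123", slow_encode)
    assert calls == 1

asyncio.run(main())

Note the finally: del self._in_flight[key] in get_or_compute: a failed computation is removed from the in-flight map rather than cached, so a later caller retries the encode instead of receiving a stale error.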
