
Usage:
    cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)  # 4 GB

    # Store embedding
    cache.set("abc123", embedding_tensor)

    # Retrieve embedding
    tensor = cache.get("abc123")  # Returns None if not found
"""

import logging
from collections import OrderedDict
from typing import Optional

import torch

logger = logging.getLogger(__name__)

class EncoderCacheManager:
    """
    LRU cache for encoder embeddings.

    Stores tensors keyed by content hash with automatic eviction
    when capacity is exceeded.

    Thread Safety:
        This class is NOT thread-safe. It is designed to run within a single
        thread (e.g., an asyncio event loop). All access must be from the same
        thread.
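
        If access from multiple threads is ever required, one option is a
        caller-side lock around every operation. A hypothetical sketch (this
        guard is not provided by the class itself):

            lock = threading.Lock()
            with lock:  # serialize all cache access externally
                cache.set(key, tensor)
    """
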
    def __init__(self, capacity_bytes: int):
        """
        Initialize the encoder cache.

        Args:
            capacity_bytes: Maximum cache capacity in bytes.
        """
        if capacity_bytes <= 0:
            raise ValueError("capacity_bytes must be positive")

        self._cache: OrderedDict[str, torch.Tensor] = OrderedDict()
        self._capacity_bytes = capacity_bytes
        self._current_bytes = 0

        # Stats
        self._hits = 0
        self._misses = 0

        logger.info(
            f"EncoderCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f} GB"
        )

    @staticmethod
    def _tensor_size(tensor: torch.Tensor) -> int:
        """Calculate tensor size in bytes.

        Args:
            tensor: Must be a contiguous tensor.

        Returns:
            Size of the tensor in bytes.

        Raises:
            AssertionError: If tensor is not contiguous.
        """
        assert (
            tensor.is_contiguous()
        ), "Tensor must be contiguous for accurate size calculation"
        return tensor.element_size() * tensor.numel()
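
    # Why contiguity matters above (illustrative note, not from the original
    # file): element_size() * numel() matches the bytes actually held only for
    # contiguous tensors; a non-contiguous view can alias a larger buffer:
    #
    #     x = torch.zeros(4, 4)[:, 0]  # numel() == 4, but the view shares
    #     x.is_contiguous()            # a 16-element buffer -> False
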
    def get(self, key: str) -> Optional[torch.Tensor]:
        """
        Get a tensor from the cache.

        If found, the entry is moved to the end (most recently used).

        Args:
            key: Cache key (typically content hash).

        Returns:
            The cached tensor, or None if not found.
        """
        if key not in self._cache:
            self._misses += 1
            return None

        # Move to end (most recently used)
        self._cache.move_to_end(key)
        self._hits += 1
        return self._cache[key]
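
    # The LRU bookkeeping relies on OrderedDict ordering: move_to_end(key)
    # marks an entry most-recently-used, while popitem(last=False) (used in
    # set() below) pops the least-recently-used one. A standalone sketch,
    # separate from this class:
    #
    #     od = OrderedDict([("a", 1), ("b", 2)])
    #     od.move_to_end("a")       # order is now b, a
    #     od.popitem(last=False)    # -> ("b", 2), the LRU entry
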
    def set(self, key: str, tensor: torch.Tensor) -> bool:
        """
        Store a tensor in the cache.

        If the key already exists, the old value is replaced.
        If adding the tensor would exceed capacity, LRU entries are evicted.
        If the tensor itself is larger than capacity, it is not stored.

        Args:
            key: Cache key (typically content hash).
            tensor: Tensor to cache.

        Returns:
            True if the tensor was stored, False if it was too large.
        """
        size = self._tensor_size(tensor)

        # Don't cache if single tensor exceeds capacity
        if size > self._capacity_bytes:
            logger.warning(
                f"Tensor too large to cache: {size / 1024**2:.1f} MB > "
                f"{self._capacity_bytes / 1024**3:.2f} GB capacity"
            )
            return False

        # If key exists, remove old entry first
        if key in self._cache:
            old_tensor = self._cache.pop(key)
            self._current_bytes -= self._tensor_size(old_tensor)

        # Evict LRU entries until we have space
        while self._current_bytes + size > self._capacity_bytes and self._cache:
            evicted_key, evicted_tensor = self._cache.popitem(last=False)
            evicted_size = self._tensor_size(evicted_tensor)
            self._current_bytes -= evicted_size
            logger.debug(
                f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f} MB"
            )

        # Store new entry
        self._cache[key] = tensor
        self._current_bytes += size

        logger.debug(
            f"Cached key={key[:16] if len(key) > 16 else key}, "
            f"size={size / 1024**2:.2f} MB, "
            f"total={self._current_bytes / 1024**2:.2f} MB"
        )
        return True

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats including entries, memory usage,
            hit/miss counts, and hit rate.
        """
        total_requests = self._hits + self._misses
        hit_rate = self._hits / total_requests if total_requests > 0 else 0.0

        return {
            "entries": len(self._cache),
            "current_bytes": self._current_bytes,
            "capacity_bytes": self._capacity_bytes,
            "utilization": self._current_bytes / self._capacity_bytes
            if self._capacity_bytes > 0
            else 0,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
        }