Skip to content

Commit dd4de13

Browse files
committed
Added patch
1 parent cba4196 commit dd4de13

File tree

1 file changed

+238
-0
lines changed

1 file changed

+238
-0
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
diff --git a/src/model.py b/src/model.py
2+
index 3f6e23b..d4228d2 100644
3+
--- a/src/model.py
4+
+++ b/src/model.py
5+
@@ -42,6 +42,7 @@ from vllm.sampling_params import SamplingParams
6+
from vllm.utils import random_uuid
7+
8+
from utils.metrics import VllmStatLogger
9+
+from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig
10+
11+
_VLLM_ENGINE_ARGS_FILENAME = "model.json"
12+
_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
13+
@@ -130,6 +131,8 @@ class TritonPythonModel:
14+
)
15+
self._shutdown_event = asyncio.Event()
16+
self._event_thread.start()
17+
+ config = SemanticCPUCacheConfig()
18+
+ self.semantic_cache = SemanticCPUCache(config=config)
19+
20+
def init_engine(self):
21+
# Currently, Triton needs to use decoupled policy for asynchronously
22+
@@ -407,6 +410,18 @@ class TritonPythonModel:
23+
raise ValueError(
24+
"When streaming, `exclude_input_in_output` = False is not allowed."
25+
)
26+
+ cache_hit = self.semantic_cache.get(prompt)
27+
+ if cache_hit:
28+
+ try:
29+
+ response_sender.send(
30+
+ self.create_response(cache_hit, prepend_input),
31+
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
32+
+ )
33+
+ if decrement_ongoing_request_count:
34+
+ self.ongoing_request_count -= 1
35+
+ except Exception as err:
36+
+ print(f"Unexpected {err=} for prompt {prompt}")
37+
+ return None
38+
39+
# Request parameters are not yet supported via
40+
# BLS. Provide an optional mechanism to receive serialized
41+
@@ -481,6 +496,7 @@ class TritonPythonModel:
42+
self.create_response(last_output, prepend_input),
43+
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
44+
)
45+
+ self.semantic_cache.set(prompt, last_output)
46+
47+
except Exception as e:
48+
self.logger.log_error(f"[vllm] Error generating stream: {e}")
49+
diff --git a/src/utils/semantic_caching.py b/src/utils/semantic_caching.py
50+
new file mode 100644
51+
index 0000000..457a163
52+
--- /dev/null
53+
+++ b/src/utils/semantic_caching.py
54+
@@ -0,0 +1,184 @@
55+
import itertools
from dataclasses import dataclass, field
from typing import Any, Dict, Hashable, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from theine import Cache
63+
+
64+
+
65+
class KeyMapper:
    """
    A class to manage bidirectional mapping between hashable keys and integer IDs.
    """

    def __init__(self):
        # Forward (key -> id) and reverse (id -> key) maps, kept in sync.
        self.hk_map: Dict[Hashable, int] = {}
        self.kh_map: Dict[int, Hashable] = {}
        # Monotonic ID source; IDs are never reused after removal.
        self.counter = itertools.count()

    def add_key(self, key: Hashable) -> Optional[int]:
        """
        Add a new key to the mapper and return its assigned ID.

        Args:
            key (Hashable): The key to be added.

        Returns:
            Optional[int]: The assigned ID for the key, or None if the key
            was already present (no new ID is assigned).
        """
        if key in self.hk_map:  # membership test on the dict itself, not .keys()
            return None
        key_id = next(self.counter)  # renamed from `id` to avoid shadowing the builtin
        self.hk_map[key] = key_id
        self.kh_map[key_id] = key
        return key_id

    def remove_key(self, key: Hashable) -> Optional[int]:
        """
        Remove key from the mapper and return its ID.

        Args:
            key (Hashable): The key to be removed.

        Returns:
            Optional[int]: The ID for the removed key, or None if the key
            was not present.
        """
        key_id = self.hk_map.pop(key, None)
        if key_id is not None:
            self.kh_map.pop(key_id, None)
        return key_id

    def get_key(self, id: int) -> Optional[Hashable]:
        """
        Retrieve the key associated with the given ID.

        Args:
            id (int): The ID to look up.

        Returns:
            Optional[Hashable]: The associated key, or None if not found.
        """
        return self.kh_map.get(id)

    def get_id(self, key: Hashable) -> Optional[int]:
        """
        Retrieve the ID associated with the given key.

        Args:
            key (Hashable): The key to look up.

        Returns:
            Optional[int]: The associated ID, or None if not found.
        """
        return self.hk_map.get(key)
131+
+
132+
+
133+
@dataclass
class SemanticCPUCacheConfig:
    """
    Configuration class for SemanticCPUCache.

    Every field uses a `default_factory` so each config instance gets its own
    cache/encoder/index/key-mapper. A plain default (`cache: Any = Cache(...)`)
    is evaluated once at class-definition time and shared by ALL instances —
    two caches built from two default configs would silently share state, and
    the SentenceTransformer model would be loaded at import time even if never
    used.

    Attributes:
        cache (Any): The cache object to use.
        encoder (Any): The encoder object for embedding queries.
        index (Any): The index object for similarity search.
        threshold (float): The similarity threshold for considering a match.
        key_mapper (Any): The key mapper object for managing key-ID mappings.
    """

    cache: Any = field(default_factory=lambda: Cache(policy="lru", size=1000))
    encoder: Any = field(
        default_factory=lambda: SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2"
        )
    )
    # 384 matches the embedding dimension of all-MiniLM-L6-v2.
    index: Any = field(
        default_factory=lambda: faiss.IndexIDMap(faiss.IndexFlatL2(384))
    )
    threshold: float = 0.25
    key_mapper: Any = field(default_factory=KeyMapper)
151+
+
152+
+
153+
class SemanticCPUCache:
    """
    Semantic cache implementation.

    Values are stored in the underlying cache keyed by the original prompt,
    while an embedding index enables lookup by semantic similarity: a `get`
    hits when the nearest stored key's embedding lies within `threshold`
    distance of the query's embedding.
    """

    def __init__(self, config: "SemanticCPUCacheConfig"):
        """
        Initialize the SemanticCPUCache with the given configuration.

        Args:
            config (SemanticCPUCacheConfig): The configuration object.
        """
        self.encoder = config.encoder  # embeds keys into fixed-size vectors
        self.index = config.index  # vector index used for nearest-neighbor search
        self.cache = config.cache  # key -> value store with an eviction policy
        self.key_map = config.key_mapper  # key <-> integer-id mapping for the index
        self.threshold = config.threshold  # max distance for a semantic hit

    def get(self, key: Hashable, default: Any = None) -> Any:
        """
        Retrieve a value from the cache based on the given key.

        First, a similarity search is performed. If a similar key is found
        within the threshold, its associated value is returned.
        Otherwise, the default value is returned.

        Args:
            key (Hashable): The key to look up.
            default (Any, optional): The default value to return if no match is found. Defaults to None.

        Returns:
            Any: The retrieved value or the default value.
        """
        # Empty index: nothing to search against.
        if self.index.ntotal < 1:
            return default

        key_search = np.asarray([self.encoder.encode(key)])
        # Nearest neighbor only (k=1); dist and ind have shape (1, 1).
        dist, ind = self.index.search(key_search, 1)

        # Closest stored key is still too far away: treat as a miss.
        if dist[0][0] > self.threshold:
            return default

        key_str = self.key_map.get_key(ind[0][0])

        return self.cache.get(key=key_str, default=default)

    def set(self, key: Hashable, value: Any) -> Optional[str]:
        """
        Set a key-value pair in the cache.

        This method adds the key to the key mapper, encodes the key,
        adds the encoded key to the index, and sets the value in the cache.

        Args:
            key (Hashable): The key to set.
            value (Any): The value to associate with the key.

        Returns:
            Optional[str]: Always None; eviction of the underlying cache's
            displaced entry is handled internally.
        """
        key_id = self.key_map.add_key(key)
        # add_key returns None for an already-known key, whose embedding is
        # already indexed — only genuinely new keys are added to the index.
        if key_id is not None:
            # TODO: leaking implementation `add_with_ids`. add a layer
            self.index.add_with_ids(
                np.expand_dims(self.encoder.encode(key), axis=0),
                np.asarray([key_id]),
            )

        # The underlying cache returns the key it evicted (if any); drop that
        # key's embedding so the index stays in sync with the cache contents.
        evicted_key = self.cache.set(key, value)
        self._handle_evicted_key(evicted_key=evicted_key)

        return None

    def _handle_evicted_key(self, evicted_key: Hashable) -> None:
        """Remove an evicted key's ID mapping and index entry, if present."""
        if evicted_key is None:
            return None
        # TODO: extremely coupled, remove dependency on key id?
        evicted_id = self.key_map.remove_key(evicted_key)
        # Guard: an unknown evicted key yields None, and passing None to
        # remove_ids would fail; only remove entries we actually track.
        if evicted_id is not None:
            # TODO: leaking implementation `remove_ids`. add a layer
            self.index.remove_ids(np.array([evicted_id]))
        return None

0 commit comments

Comments
 (0)