|
| 1 | +diff --git a/src/model.py b/src/model.py |
| 2 | +index 3f6e23b..d4228d2 100644 |
| 3 | +--- a/src/model.py |
| 4 | ++++ b/src/model.py |
| 5 | +@@ -42,6 +42,7 @@ from vllm.sampling_params import SamplingParams |
| 6 | + from vllm.utils import random_uuid |
| 7 | + |
| 8 | + from utils.metrics import VllmStatLogger |
| 9 | ++from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig |
| 10 | + |
| 11 | + _VLLM_ENGINE_ARGS_FILENAME = "model.json" |
| 12 | + _MULTI_LORA_ARGS_FILENAME = "multi_lora.json" |
| 13 | +@@ -130,6 +131,8 @@ class TritonPythonModel: |
| 14 | + ) |
| 15 | + self._shutdown_event = asyncio.Event() |
| 16 | + self._event_thread.start() |
| 17 | ++ config = SemanticCPUCacheConfig() |
| 18 | ++ self.semantic_cache = SemanticCPUCache(config=config) |
| 19 | + |
| 20 | + def init_engine(self): |
| 21 | + # Currently, Triton needs to use decoupled policy for asynchronously |
| 22 | +@@ -407,6 +410,18 @@ class TritonPythonModel: |
| 23 | + raise ValueError( |
| 24 | + "When streaming, `exclude_input_in_output` = False is not allowed." |
| 25 | + ) |
| 26 | ++ cache_hit = self.semantic_cache.get(prompt) |
| 27 | ++ if cache_hit: |
| 28 | ++ try: |
| 29 | ++ response_sender.send( |
| 30 | ++ self.create_response(cache_hit, prepend_input), |
| 31 | ++ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, |
| 32 | ++ ) |
| 33 | ++ if decrement_ongoing_request_count: |
| 34 | ++ self.ongoing_request_count -= 1 |
| 35 | ++ except Exception as err: |
| 36 | ++ print(f"Unexpected {err=} for prompt {prompt}") |
| 37 | ++ return None |
| 38 | + |
| 39 | + # Request parameters are not yet supported via |
| 40 | + # BLS. Provide an optional mechanism to receive serialized |
| 41 | +@@ -481,6 +496,7 @@ class TritonPythonModel: |
| 42 | + self.create_response(last_output, prepend_input), |
| 43 | + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, |
| 44 | + ) |
| 45 | ++ self.semantic_cache.set(prompt, last_output) |
| 46 | + |
| 47 | + except Exception as e: |
| 48 | + self.logger.log_error(f"[vllm] Error generating stream: {e}") |
| 49 | +diff --git a/src/utils/semantic_caching.py b/src/utils/semantic_caching.py |
| 50 | +new file mode 100644 |
| 51 | +index 0000000..457a163 |
| 52 | +--- /dev/null |
| 53 | ++++ b/src/utils/semantic_caching.py |
| 54 | +@@ -0,0 +1,184 @@ |
import itertools
from dataclasses import dataclass, field
from typing import Any, Dict, Hashable, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from theine import Cache
| 63 | ++ |
| 64 | ++ |
| 65 | ++class KeyMapper: |
| 66 | ++ """ |
| 67 | ++ A class to manage bidirectional mapping between hashable keys and integer IDs. |
| 68 | ++ """ |
| 69 | ++ |
| 70 | ++ def __init__(self): |
| 71 | ++ self.hk_map: Dict[Hashable, int] = {} |
| 72 | ++ self.kh_map: Dict[int, Hashable] = {} |
| 73 | ++ self.counter = itertools.count() |
| 74 | ++ |
| 75 | ++ def add_key(self, key: Hashable): |
| 76 | ++ """ |
| 77 | ++ Add a new key to the mapper and return its assigned ID. |
| 78 | ++ |
| 79 | ++ Args: |
| 80 | ++ key (Hashable): The key to be added. |
| 81 | ++ |
| 82 | ++ Returns: |
| 83 | ++ int: The assigned ID for the key. |
| 84 | ++ """ |
| 85 | ++ if key in self.hk_map.keys(): |
| 86 | ++ return None |
| 87 | ++ id = next(self.counter) |
| 88 | ++ self.hk_map[key] = id |
| 89 | ++ self.kh_map[id] = key |
| 90 | ++ return id |
| 91 | ++ |
| 92 | ++ def remove_key(self, key: Hashable): |
| 93 | ++ """ |
| 94 | ++ Remove key from the mapper and return its ID. |
| 95 | ++ |
| 96 | ++ Args: |
| 97 | ++ key (Hashable): The key to be removed. |
| 98 | ++ |
| 99 | ++ Returns: |
| 100 | ++ int: The ID for the removed key. |
| 101 | ++ """ |
| 102 | ++ id = self.hk_map.pop(key, None) |
| 103 | ++ if id is not None: |
| 104 | ++ self.kh_map.pop(id, None) |
| 105 | ++ return id |
| 106 | ++ return None |
| 107 | ++ |
| 108 | ++ def get_key(self, id: int): |
| 109 | ++ """ |
| 110 | ++ Retrieve the key associated with the given ID. |
| 111 | ++ |
| 112 | ++ Args: |
| 113 | ++ id (int): The ID to look up. |
| 114 | ++ |
| 115 | ++ Returns: |
| 116 | ++ Optional[Hashable]: The associated key, or None if not found. |
| 117 | ++ """ |
| 118 | ++ return self.kh_map.get(id) |
| 119 | ++ |
| 120 | ++ def get_id(self, key: Hashable): |
| 121 | ++ """ |
| 122 | ++ Retrieve the ID associated with the given key. |
| 123 | ++ |
| 124 | ++ Args: |
| 125 | ++ key (Hashable): The key to look up. |
| 126 | ++ |
| 127 | ++ Returns: |
| 128 | ++ Optional[int]: The associated ID, or None if not found. |
| 129 | ++ """ |
| 130 | ++ return self.hk_map.get(key) |
| 131 | ++ |
| 132 | ++ |
| 133 | ++@dataclass |
| 134 | ++class SemanticCPUCacheConfig: |
| 135 | ++ """ |
| 136 | ++ Configuration class for SemanticCPUCache. |
| 137 | ++ |
| 138 | ++ Attributes: |
| 139 | ++ cache (Any): The cache object to use. |
| 140 | ++ encoder (Any): The encoder object for embedding queries. |
| 141 | ++ index (Any): The index object for similarity search. |
| 142 | ++ threshold (float): The similarity threshold for considering a match. |
| 143 | ++ key_mapper (Any): The key mapper object for managing key-ID mappings. |
| 144 | ++ """ |
| 145 | ++ |
| 146 | ++ cache: Any = Cache(policy="lru", size=1000) |
| 147 | ++ encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") |
| 148 | ++ index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(384)) |
| 149 | ++ threshold: float = 0.25 |
| 150 | ++ key_mapper: Any = KeyMapper() |
| 151 | ++ |
| 152 | ++ |
| 153 | ++class SemanticCPUCache: |
| 154 | ++ """ |
| 155 | ++ Semantic cache implementation. |
| 156 | ++ """ |
| 157 | ++ |
| 158 | ++ def __init__(self, config: SemanticCPUCacheConfig): |
| 159 | ++ """ |
| 160 | ++ Initialize the SemanticCPUCache with the given configuration. |
| 161 | ++ |
| 162 | ++ Args: |
| 163 | ++ config (SemanticCPUCacheConfig): The configuration object. |
| 164 | ++ """ |
| 165 | ++ self.encoder = config.encoder |
| 166 | ++ self.index = config.index |
| 167 | ++ self.cache = config.cache |
| 168 | ++ self.key_map = config.key_mapper |
| 169 | ++ self.threshold = config.threshold |
| 170 | ++ |
| 171 | ++ def get(self, key: Hashable, default: Any = None) -> Any: |
| 172 | ++ """ |
| 173 | ++ Retrieve a value from the cache based on the given key. |
| 174 | ++ |
| 175 | ++ First, a similarity search is performed. If a similar key is found |
| 176 | ++ within the threshold, its associated value is returned. |
| 177 | ++ Otherwise, the default value is returned. |
| 178 | ++ |
| 179 | ++ Args: |
| 180 | ++ key (Hashable): The key to look up. |
| 181 | ++ default (Any, optional): The default value to return if no match is found. Defaults to None. |
| 182 | ++ |
| 183 | ++ Returns: |
| 184 | ++ Any: The retrieved value or the default value. |
| 185 | ++ """ |
| 186 | ++ if self.index.ntotal < 1: |
| 187 | ++ return default |
| 188 | ++ |
| 189 | ++ key_search = np.asarray([self.encoder.encode(key)]) |
| 190 | ++ dist, ind = self.index.search(key_search, 1) |
| 191 | ++ # print(dist[0][0]) |
| 192 | ++ |
| 193 | ++ if dist[0][0] > self.threshold: |
| 194 | ++ return default |
| 195 | ++ |
| 196 | ++ key_str = self.key_map.get_key(ind[0][0]) |
| 197 | ++ |
| 198 | ++ return self.cache.get(key=key_str, default=default) |
| 199 | ++ |
| 200 | ++ def set(self, key: Hashable, value: Any) -> Optional[str]: |
| 201 | ++ """ |
| 202 | ++ Set a key-value pair in the cache. |
| 203 | ++ |
| 204 | ++ This method adds the key to the key mapper, encodes the key, |
| 205 | ++ adds the encoded key to the index, and sets the value in the cache. |
| 206 | ++ |
| 207 | ++ |
| 208 | ++ Args: |
| 209 | ++ key (Hashable): The key to set. |
| 210 | ++ value (Any): The value to associate with the key. |
| 211 | ++ |
| 212 | ++ Returns: |
| 213 | ++ Optional[str]: The result of setting the value in the cache. |
| 214 | ++ |
| 215 | ++ Raises: |
| 216 | ++ AssertionError: If the key could not be added to the key mapper. |
| 217 | ++ """ |
| 218 | ++ id = self.key_map.add_key(key) |
| 219 | ++ if id is not None: |
| 220 | ++ # TODO: leaking implementation `add_with_ids`. add a layer |
| 221 | ++ self.index.add_with_ids( |
| 222 | ++ np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id]) |
| 223 | ++ ) |
| 224 | ++ |
| 225 | ++ evicted_key = self.cache.set(key, value) |
| 226 | ++ self._handle_evicted_key(evicted_key=evicted_key) |
| 227 | ++ |
| 228 | ++ return None |
| 229 | ++ |
| 230 | ++ def _handle_evicted_key(self, evicted_key: Hashable) -> None: |
| 231 | ++ if evicted_key is None: |
| 232 | ++ return None |
| 233 | ++ # TODO: extremely coupled, remove dependency on key id? |
| 234 | ++ evicted_id = self.key_map.remove_key(evicted_key) |
| 235 | ++ print(evicted_id) |
| 236 | ++ # TODO: leaking implementation `remove_ids`. add a layer |
| 237 | ++ self.index.remove_ids(np.array([evicted_id])) |
| 238 | ++ return None |
0 commit comments