diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md
new file mode 100644
index 00000000..cbdf363e
--- /dev/null
+++ b/Conceptual_Guide/Part_8-semantic_caching/README.md
@@ -0,0 +1,359 @@
+
+
+# Semantic Caching
+
+When deploying large language models (LLMs) or LLM-based workflows,
+there are two key factors to consider: the performance and cost-efficiency
+of your application. Generating language model outputs requires significant
+computational resources, such as GPU time, memory usage, and other
+infrastructure costs. These resource-intensive requirements create a
+pressing need for optimization strategies that can maintain
+high-quality outputs while minimizing operational expenses.
+
+Semantic caching emerges as a powerful solution to reduce computational costs
+for LLM-based applications.
+
+## Definition and Benefits
+
+**_Semantic caching_** is a caching mechanism that takes into account
+the semantics of the incoming request, rather than just the raw data itself.
+It goes beyond simple key-value pairs and considers the content or
+context of the data.
+
+This approach offers several benefits, including but not limited to:
+
++ **Cost Optimization**
+
+  - Semantic caching can substantially reduce operational expenses associated
+  with LLM deployments. By storing and reusing responses for semantically
+  similar queries, it minimizes the number of actual LLM calls required.
+
++ **Reduced Latency**
+
+  - One of the primary benefits of semantic caching is its ability to
+  significantly improve response times. By retrieving cached responses for
+  similar queries, the system can bypass the need for full model inference,
+  resulting in reduced latency.
+
++ **Increased Throughput**
+
+  - Semantic caching allows for more efficient utilization of computational
+  resources. By serving cached responses for similar queries, it reduces the
+  load on infrastructure components. This efficiency enables the system
+  to handle a higher volume of requests with the same hardware, effectively
+  increasing throughput.
+
++ **Scalability**
+
+  - As the user base and the volume of queries grow, the probability of cache
+  hits increases, provided that there is adequate storage and resources
+  available to support this scaling. The improved resource efficiency and
+  reduced computational demands allow applications to serve more users
+  without a proportional increase in infrastructure costs.
+
++ **Consistency in Responses**
+
+  - For certain applications, maintaining consistency in responses to
+  similar queries can be beneficial. Semantic caching ensures that analogous
+  questions receive uniform answers, which can be particularly useful
+  in scenarios like customer service or educational applications.
+
+## Sample Reference Implementation
+
+In this tutorial, we provide a reference implementation for a Semantic Cache in
+[semantic_caching.py](./artifacts/semantic_caching.py). There are three key
+dependencies:
+* [SentenceTransformer](https://sbert.net/): a Python framework for computing
+dense vector representations (embeddings) of sentences, paragraphs, and images.
+  - We use this library, and `all-MiniLM-L6-v2` in particular, to convert
+  the incoming prompt into an embedding, enabling semantic comparison.
+  - Alternatives include [semantic search models](https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#semantic-search-models),
+  OpenAI Embeddings, etc.
+* [Faiss](https://github.com/facebookresearch/faiss/wiki): an open-source library
+developed by Facebook AI Research for efficient similarity search and
+clustering of dense vectors.
+  - This library is used for the embedding store and for extracting the most
+  similar embedded prompt from the cached requests (or from the index store).
+  - This is a powerful library with a wide variety of CPU- and GPU-accelerated
+  algorithms.
+  - Alternatives include [annoy](https://github.com/spotify/annoy) or
+  [cuVS](https://github.com/rapidsai/cuvs). However, note that cuVS already
+  has an integration in Faiss; more on this can be found [here](https://docs.rapids.ai/api/cuvs/nightly/integrations/faiss/).
+* [Theine](https://github.com/Yiling-J/theine): a high-performance in-memory
+cache.
+  - We will use it as our exact-match cache backend. After the most similar
+  prompt is identified, the corresponding cached response is extracted from
+  the cache. This library supports multiple eviction policies; in this
+  tutorial we use "LRU".
+  - One may also look into [Memcached](https://memcached.org/about) as a
+  potential alternative.
+
+The provided [script](./artifacts/semantic_caching.py) is heavily annotated, and we
+encourage users to look through the code to gain a better understanding of all
+the necessary stages.
+
+## Incorporating a Semantic Cache into Your Workflow
+
+For this tutorial, we'll use the [vllm backend](https://github.com/triton-inference-server/vllm_backend)
+as our example, focusing on demonstrating how to cache responses for the
+non-streaming case. The principles covered here can be extended to handle
+streaming scenarios as well.
+
+### Customizing the vLLM Backend
+
+First, let's clone Triton's vllm backend repository. This will
+provide the necessary codebase to implement our semantic caching example.
+
+```bash
+git clone https://github.com/triton-inference-server/vllm_backend.git
+cd vllm_backend
+```
+
+With the repository successfully cloned, the next step is to apply all
+necessary modifications. To simplify this process, we've prepared a
+[semantic_cache.patch](./artifacts/semantic_cache.patch)
+that consolidates all changes into a single step:
+
+```bash
+curl https://raw.githubusercontent.com/triton-inference-server/tutorials/refs/heads/main/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch | git apply -v
+```
+
+If you're eager to start using Triton with the optimized vLLM backend,
+you can skip ahead to the
+[Launching Triton with Optimized vLLM Backend](#launching-triton-with-optimized-vllm-backend)
+section. However, for those interested in understanding the specifics,
+let's explore what this patch includes.
+
+The patch introduces a new script,
+[semantic_caching.py](./artifacts/semantic_caching.py), which is added to the
+appropriate directory. This script implements the core logic for our
+semantic caching functionality.
+
+Next, the patch integrates semantic caching into the model. Let's walk through
+these changes step-by-step.
+
+First, it imports the necessary classes from
+[semantic_caching.py](./artifacts/semantic_caching.py) into the codebase:
+
+```diff
+...
+
+from utils.metrics import VllmStatLogger
++from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig
+```
+
+Next, it sets up the semantic cache during the initialization step.
+This setup will prepare your model to utilize semantic caching during
+its operations.
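+
+The `SemanticCPUCacheConfig` used in this step bundles the components described
+earlier: the Theine cache backend, the SentenceTransformer encoder, the Faiss
+index, the distance threshold, and the key mapper. The patch relies on the
+defaults defined in [semantic_caching.py](./artifacts/semantic_caching.py), but
+as a rough sketch (the values below are purely illustrative, not tuned
+recommendations), overriding them could look like this:
+
+```python
+from theine import Cache
+
+from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig
+
+# Hypothetical overrides: a larger exact-match cache and a stricter
+# distance threshold than the defaults.
+config = SemanticCPUCacheConfig(
+    cache=Cache(policy="lru", size=10000),
+    threshold=0.2,
+)
+semantic_cache = SemanticCPUCache(config=config)
+```
+
+The patch itself sticks with the defaults: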
+
+```diff
+    def initialize(self, args):
+        self.args = args
+        self.logger = pb_utils.Logger
+        self.model_config = json.loads(args["model_config"])
+        ...
+
+        # Starting asyncio event loop to process the received requests asynchronously.
+        self._loop = asyncio.get_event_loop()
+        self._event_thread = threading.Thread(
+            target=self.engine_loop, args=(self._loop,)
+        )
+        self._shutdown_event = asyncio.Event()
+        self._event_thread.start()
++       config = SemanticCPUCacheConfig()
++       self.semantic_cache = SemanticCPUCache(config=config)
+
+```
+
+Finally, the patch incorporates logic to query and update the semantic cache
+during request processing. This ensures that cached responses are efficiently
+utilized whenever possible.
+
+```diff
+    async def generate(self, request):
+        ...
+        try:
+            request_id = random_uuid()
+            prompt = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()[0]
+            ...
+
+            if prepend_input and stream:
+                raise ValueError(
+                    "When streaming, `exclude_input_in_output` = False is not allowed."
+                )
++           cache_hit = self.semantic_cache.get(prompt)
++           if cache_hit:
++               try:
++                   response_sender.send(
++                       self.create_response(cache_hit, prepend_input),
++                       flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
++                   )
++                   if decrement_ongoing_request_count:
++                       self.ongoing_request_count -= 1
++               except Exception as err:
++                   print(f"Unexpected {err=} for prompt {prompt}")
++               return None
+            ...
+
+            async for output in response_iterator:
+                ...
+
+                last_output = output
+
+            if not stream:
+                response_sender.send(
+                    self.create_response(last_output, prepend_input),
+                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+                )
++               self.semantic_cache.set(prompt, last_output)
+
+```
+
+### Launching Triton with Optimized vLLM Backend
+
+To evaluate our optimized vllm backend, let's start the vllm docker container and
+mount our implementation to `/opt/tritonserver/backends/vllm`. We'll
+also mount the sample model repository provided in
+`vllm_backend/samples/model_repository`. Feel free to set up your own.
+Use the following docker command to start Triton's vllm docker container,
+but make sure to specify the proper paths to the cloned `vllm_backend`
+repository and replace `<xx.yy>` with the latest release of Triton.
+
+```bash
+docker run --gpus all -it --net=host --rm \
+    --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 \
+    -v /path/to/vllm_backend/src/:/opt/tritonserver/backends/vllm \
+    -v /path/to/vllm_backend/samples/model_repository:/workspace/model_repository \
+    -w /workspace \
+    nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-python-py3
+```
+
+When inside the container, make sure to install the required dependencies:
+```bash
+pip install sentence_transformers faiss_gpu theine
+```
+
+Finally, let's launch Triton:
+```bash
+tritonserver --model-repository=model_repository/
+```
+
+After you start Triton, you will see output on the console showing
+the server starting up and loading the model. When you see output
+like the following, Triton is ready to accept inference requests.
+
+```
+I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001
+I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000
+I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002
+```
+
+### Evaluation
+
+After you [start Triton](#launching-triton-with-optimized-vllm-backend)
+with the sample model_repository, you can quickly run your first inference
+request with the
+[generate endpoint](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md).
+
+We'll also time this query:
+
+```bash
+time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "Tell me, how do I create model repository for Triton Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}'
+```
+
+Upon success, you should see a response from the server like this one:
+```
+{"model_name":"vllm_model","model_version":"1","text_output": }
+real 0m1.128s
+user 0m0.000s
+sys 0m0.015s
+```
+
+Now, let's try a differently worded prompt with the same semantics:
+
+```bash
+time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How do I set up model repository for Triton Inference Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}'
+```
+
+Upon success, you should see a response from the server like this one:
+```
+{"model_name":"vllm_model","model_version":"1","text_output": }
+real 0m0.038s
+user 0m0.000s
+sys 0m0.017s
+```
+
+Let's try one more:
+
+```bash
+time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How model repository should be set up for Triton Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}'
+```
+
+Upon success, you should see a response from the server like this one:
+```
+{"model_name":"vllm_model","model_version":"1","text_output": }
+real 0m0.059s
+user 0m0.016s
+sys 0m0.000s
+```
+
+Clearly, the latter two requests are semantically similar to the first one,
+which resulted in cache hits and reduced the latency of our model from
+approximately 1.1s to an average of 0.048s per request.
+
+## Current Limitations
+
+* The current implementation of the Semantic Cache only considers the prompt
+itself for cache hits, without accounting for additional request parameters
+such as `max_tokens` and `temperature`. As a result, requests that differ only
+in these parameters receive the same cached response, which may affect accuracy
+when different parameter configurations are used.
+
+* Semantic Cache effectiveness relies heavily on the choice of embedding
+model and the application context. For instance, queries like "How to set up model
+repository for Triton Inference Server?" and "How not to set up model
+repository for Triton Inference Server?" may have high cosine similarity
+despite differing semantically. This makes it challenging to set an optimal
+threshold for cache hits, as a narrow similarity range might exclude useful
+cache entries. The sketch below illustrates the issue.
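+
+You can check this behavior directly. Below is a minimal standalone sketch
+(not part of the tutorial artifacts) that uses the same `all-MiniLM-L6-v2`
+model; the exact score will vary with the embedding model, but for these two
+prompts it is likely to be high:
+
+```python
+from sentence_transformers import SentenceTransformer, util
+
+encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+embeddings = encoder.encode(
+    [
+        "How to set up model repository for Triton Inference Server?",
+        "How not to set up model repository for Triton Inference Server?",
+    ]
+)
+# The two prompts ask for opposite things, yet their embeddings tend to be
+# very close, so a purely distance-based cache may treat the second prompt
+# as a hit for the first.
+print(util.cos_sim(embeddings[0], embeddings[1]))
+```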
+
+## Interested in This Feature?
+
+While this reference implementation provides a glimpse into the potential
+of semantic caching, it's important to note that it's not an officially
+supported feature in Triton Inference Server.
+
+We value your input! If you're interested in seeing semantic caching as a
+supported feature in future releases, we invite you to join the ongoing
+[discussion](https://github.com/triton-inference-server/server/discussions/7742).
+Provide details about why you think semantic caching would
+be valuable for your use case. Your feedback helps shape our product roadmap,
+and we appreciate your contributions to making our software better for everyone.
\ No newline at end of file
diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch
new file mode 100644
index 00000000..5df4ceaf
--- /dev/null
+++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch
@@ -0,0 +1,238 @@
+diff --git a/src/model.py b/src/model.py
+index 3f6e23b..d4228d2 100644
+--- a/src/model.py
++++ b/src/model.py
+@@ -42,6 +42,7 @@ from vllm.sampling_params import SamplingParams
+ from vllm.utils import random_uuid
+
+ from utils.metrics import VllmStatLogger
++from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig
+
+ _VLLM_ENGINE_ARGS_FILENAME = "model.json"
+ _MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
+@@ -130,6 +131,8 @@ class TritonPythonModel:
+         )
+         self._shutdown_event = asyncio.Event()
+         self._event_thread.start()
++        config = SemanticCPUCacheConfig()
++        self.semantic_cache = SemanticCPUCache(config=config)
+
+     def init_engine(self):
+         # Currently, Triton needs to use decoupled policy for asynchronously
+@@ -407,6 +410,18 @@ class TritonPythonModel:
+             raise ValueError(
+                 "When streaming, `exclude_input_in_output` = False is not allowed."
+             )
++            cache_hit = self.semantic_cache.get(prompt)
++            if cache_hit:
++                try:
++                    response_sender.send(
++                        self.create_response(cache_hit, prepend_input),
++                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
++                    )
++                    if decrement_ongoing_request_count:
++                        self.ongoing_request_count -= 1
++                except Exception as err:
++                    print(f"Unexpected {err=} for prompt {prompt}")
++                return None
+
+             # Request parameters are not yet supported via
+             # BLS. Provide an optional mechanism to receive serialized
+@@ -481,6 +496,7 @@ class TritonPythonModel:
+                     self.create_response(last_output, prepend_input),
+                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+                 )
++                self.semantic_cache.set(prompt, last_output)
+
+             except Exception as e:
+                 self.logger.log_error(f"[vllm] Error generating stream: {e}")
+diff --git a/src/utils/semantic_caching.py b/src/utils/semantic_caching.py
+new file mode 100644
+index 0000000..457a163
+--- /dev/null
++++ b/src/utils/semantic_caching.py
+@@ -0,0 +1,184 @@
++import itertools
++from dataclasses import dataclass
++from typing import Any, Dict, Hashable, Optional
++
++import faiss
++import numpy as np
++from sentence_transformers import SentenceTransformer
++from theine import Cache
++
++
++class KeyMapper:
++    """
++    A class to manage bidirectional mapping between hashable keys and integer IDs.
++    """
++
++    def __init__(self):
++        self.hk_map: Dict[Hashable, int] = {}
++        self.kh_map: Dict[int, Hashable] = {}
++        self.counter = itertools.count()
++
++    def add_key(self, key: Hashable):
++        """
++        Add a new key to the mapper and return its assigned ID.
++
++        Args:
++            key (Hashable): The key to be added.
++
++        Returns:
++            int: The assigned ID for the key.
++        """
++        if key in self.hk_map.keys():
++            return None
++        id = next(self.counter)
++        self.hk_map[key] = id
++        self.kh_map[id] = key
++        return id
++
++    def remove_key(self, key: Hashable):
++        """
++        Remove key from the mapper and return its ID.
++
++        Args:
++            key (Hashable): The key to be removed.
++
++        Returns:
++            int: The ID for the removed key.
++        """
++        id = self.hk_map.pop(key, None)
++        if id is not None:
++            self.kh_map.pop(id, None)
++            return id
++        return None
++
++    def get_key(self, id: int):
++        """
++        Retrieve the key associated with the given ID.
++
++        Args:
++            id (int): The ID to look up.
++
++        Returns:
++            Optional[Hashable]: The associated key, or None if not found.
++        """
++        return self.kh_map.get(id)
++
++    def get_id(self, key: Hashable):
++        """
++        Retrieve the ID associated with the given key.
++
++        Args:
++            key (Hashable): The key to look up.
++
++        Returns:
++            Optional[int]: The associated ID, or None if not found.
++        """
++        return self.hk_map.get(key)
++
++
++@dataclass
++class SemanticCPUCacheConfig:
++    """
++    Configuration class for SemanticCPUCache.
++
++    Attributes:
++        cache (Any): The cache object to use.
++        encoder (Any): The encoder object for embedding queries.
++        index (Any): The index object for similarity search.
++        threshold (float): The similarity threshold for considering a match.
++        key_mapper (Any): The key mapper object for managing key-ID mappings.
++    """
++
++    cache: Any = Cache(policy="lru", size=1000)
++    encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
++    index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(384))
++    threshold: float = 0.25
++    key_mapper: Any = KeyMapper()
++
++
++class SemanticCPUCache:
++    """
++    Semantic cache implementation.
++    """
++
++    def __init__(self, config: SemanticCPUCacheConfig):
++        """
++        Initialize the SemanticCPUCache with the given configuration.
++
++        Args:
++            config (SemanticCPUCacheConfig): The configuration object.
++        """
++        self.encoder = config.encoder
++        self.index = config.index
++        self.cache = config.cache
++        self.key_map = config.key_mapper
++        self.threshold = config.threshold
++
++    def get(self, key: Hashable, default: Any = None) -> Any:
++        """
++        Retrieve a value from the cache based on the given key.
++
++        First, a similarity search is performed. If a similar key is found
++        within the threshold, its associated value is returned.
++        Otherwise, the default value is returned.
++
++        Args:
++            key (Hashable): The key to look up.
++            default (Any, optional): The default value to return if no match is found. Defaults to None.
++
++        Returns:
++            Any: The retrieved value or the default value.
++        """
++        if self.index.ntotal < 1:
++            return default
++
++        key_search = np.asarray([self.encoder.encode(key)])
++        dist, ind = self.index.search(key_search, 1)
++        # print(dist[0][0])
++
++        if dist[0][0] > self.threshold:
++            return default
++
++        key_str = self.key_map.get_key(ind[0][0])
++
++        return self.cache.get(key=key_str, default=default)
++
++    def set(self, key: Hashable, value: Any) -> Optional[str]:
++        """
++        Set a key-value pair in the cache.
++
++        This method adds the key to the key mapper, encodes the key,
++        adds the encoded key to the index, and sets the value in the cache.
++
++
++        Args:
++            key (Hashable): The key to set.
++            value (Any): The value to associate with the key.
++
++        Returns:
++            Optional[str]: The result of setting the value in the cache.
++
++        Raises:
++            AssertionError: If the key could not be added to the key mapper.
++        """
++        id = self.key_map.add_key(key)
++        if id is not None:
++            # TODO: leaking implementation `add_with_ids`. add a layer
++            self.index.add_with_ids(
++                np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id])
++            )
++
++        evicted_key = self.cache.set(key, value)
++        self._handle_evicted_key(evicted_key=evicted_key)
++
++        return None
++
++    def _handle_evicted_key(self, evicted_key: Hashable) -> None:
++        if evicted_key is None:
++            return None
++        # TODO: extremely coupled, remove dependency on key id?
++        evicted_id = self.key_map.remove_key(evicted_key)
++        print(evicted_id)
++        # TODO: leaking implementation `remove_ids`. add a layer
++        self.index.remove_ids(np.array([evicted_id]))
++        return None
diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py
new file mode 100644
index 00000000..b4ec5618
--- /dev/null
+++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py
@@ -0,0 +1,232 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import itertools
+from dataclasses import dataclass
+from typing import Any, Dict, Hashable, Optional
+
+import faiss
+import numpy as np
+from sentence_transformers import SentenceTransformer
+from theine import Cache
+
+
+class KeyMapper:
+    """
+    A class to manage bidirectional mapping between hashable keys and integer IDs.
+    """
+
+    def __init__(self):
+        self.hk_map: Dict[Hashable, int] = {}
+        self.kh_map: Dict[int, Hashable] = {}
+        self.counter = itertools.count()
+
+    def add_key(self, key: Hashable):
+        """
+        Add a new key to the mapper and return its assigned ID.
+
+        Args:
+            key (Hashable): The key to be added.
+
+        Returns:
+            int: The assigned ID for the key.
+        """
+        if key in self.hk_map.keys():
+            return None
+        id = next(self.counter)
+        self.hk_map[key] = id
+        self.kh_map[id] = key
+        return id
+
+    def remove_key(self, key: Hashable):
+        """
+        Remove key from the mapper and return its ID.
+
+        Args:
+            key (Hashable): The key to be removed.
+
+        Returns:
+            int: The ID for the removed key.
+        """
+        id = self.hk_map.pop(key, None)
+        if id is not None:
+            self.kh_map.pop(id, None)
+            return id
+        return None
+
+    def get_key(self, id: int):
+        """
+        Retrieve the key associated with the given ID.
+
+        Args:
+            id (int): The ID to look up.
+
+        Returns:
+            Optional[Hashable]: The associated key, or None if not found.
+        """
+        return self.kh_map.get(id)
+
+    def get_id(self, key: Hashable):
+        """
+        Retrieve the ID associated with the given key.
+
+        Args:
+            key (Hashable): The key to look up.
+
+        Returns:
+            Optional[int]: The associated ID, or None if not found.
+        """
+        return self.hk_map.get(key)
+
+
+@dataclass
+class SemanticCPUCacheConfig:
+    """
+    Configuration class for SemanticCPUCache.
+
+    Attributes:
+        cache (Any): The cache object to use.
+            Default: Cache(policy="lru", size=1000).
+        encoder (Any): The encoder object for embedding queries.
+            Default: SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+        encoder_dim (int): The encoder dimension.
+            Default: 384. The size of `all-MiniLM-L6-v2` embeddings.
+        index (Any): The index object for similarity search.
+            Default: faiss.IndexIDMap(faiss.IndexFlatL2(encoder_dim))
+        threshold (float): The similarity threshold for considering a match.
+            Default: 0.25
+        key_mapper (Any): The key mapper object for managing key-ID mappings.
+            Default: KeyMapper()
+    """
+
+    cache: Any = Cache(policy="lru", size=1000)
+    encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+    encoder_dim: int = 384
+    index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(encoder_dim))
+    threshold: float = 0.25
+    key_mapper: Any = KeyMapper()
+
+
+class SemanticCPUCache:
+    """
+    Semantic cache implementation.
+    """
+
+    def __init__(self, config: SemanticCPUCacheConfig):
+        """
+        Initialize the SemanticCPUCache with the given configuration.
+
+        Args:
+            config (SemanticCPUCacheConfig): The configuration object.
+        """
+        self.encoder = config.encoder
+        self.index = config.index
+        self.cache = config.cache
+        self.key_map = config.key_mapper
+        self.threshold = config.threshold
+
+    def get(self, key: Hashable, default: Any = None) -> Any:
+        """
+        Retrieve a value from the cache based on the given key.
+
+        First, a similarity search is performed. If a similar key is found
+        within the threshold, its associated value is returned.
+        Otherwise, the default value is returned.
+
+        Args:
+            key (Hashable): The key to look up.
+            default (Any, optional): The default value to return if no match is found. Defaults to None.
+
+        Returns:
+            Any: The retrieved value or the default value.
+        """
+        if self.index.ntotal < 1:
+            return default
+
+        key_search = np.asarray([self.encoder.encode(key)])
+        # The vector index search returns two values: the distance to the most
+        # similar embedding (the argument 1 indicates we only need the top-1
+        # result) and its numerical index.
+        dist, ind = self.index.search(key_search, 1)
+
+        # If the distance between vectors is above the set threshold, i.e. the
+        # most similar embedding is too far from the current prompt embedding,
+        # this is considered a cache miss and we return the `default`.
+        if dist[0][0] > self.threshold:
+            return default
+
+        # To retrieve the cache hit from the cache store, we need to retrieve
+        # the corresponding prompt from the key_map store, given its index.
+        key_str = self.key_map.get_key(ind[0][0])
+
+        return self.cache.get(key=key_str, default=default)
+
+    def set(self, key: Hashable, value: Any) -> Optional[str]:
+        """
+        Set a key-value pair in the cache.
+
+        This method adds the key to the key mapper, encodes the key,
+        adds the encoded key to the index, and sets the value in the cache.
+
+        Args:
+            key (Hashable): The key to set.
+            value (Any): The value to associate with the key.
+
+        Returns:
+            Optional[str]: The result of setting the value in the cache.
+
+        Raises:
+            AssertionError: If the key could not be added to the key mapper.
+        """
+        id = self.key_map.add_key(key)
+        assert id is not None, "Adding key to the key map failed, returned id is None."
+        self.index.add_with_ids(
+            np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id])
+        )
+        # Adding a new entry into the cache can evict an old entry, according
+        # to the policy in use. We need to make sure we evict the same entry
+        # from the vector index, stored in `self.index`.
+        evicted_key = self.cache.set(key, value)
+        self._handle_evicted_key(evicted_key=evicted_key)
+
+        return None
+
+    def _handle_evicted_key(self, evicted_key: Optional[Hashable]) -> None:
+        """
+        Handle the eviction of a key from the cache.
+
+        This method is called when a key is evicted from the cache. It removes
+        the evicted key from the key_map and its corresponding
+        vector embedding from the index.
+
+        Args:
+            evicted_key (Optional[Hashable]): The key that was evicted from the
+                cache.
+        """
+        if evicted_key:
+            evicted_id = self.key_map.remove_key(evicted_key)
+            self.index.remove_ids(np.array([evicted_id]))
+        return None
diff --git a/Conceptual_Guide/README.md b/Conceptual_Guide/README.md
index 115f96e9..d0a44b5c 100644
--- a/Conceptual_Guide/README.md
+++ b/Conceptual_Guide/README.md
@@ -40,3 +40,4 @@ Conceptual guides have been designed as an onboarding experience to Triton Infer
 * [Part 5: Building Model Ensembles](./Part_5-Model_Ensembles/): Models are rarely used standalone. This guide will cover "how to build a deep learning inference pipeline?"
 * [Part 6: Using the BLS API to build complex pipelines](Part_6-building_complex_pipelines/): Often times there are scenarios where the pipeline requires control flows. Learn how to work with complex pipelines with models deployed on different backends.
 * [Part 7: Iterative Scheduling Tutorial](./Part_7-iterative_scheduling): Shows how to use the Triton Iterative Scheduler with a GPT2 model using HuggingFace Transformers.
+* [Part 8: Semantic Caching](./Part_8-semantic_caching/): Shows the benefits of adding semantic caching to your LLM-based workflow.