diff --git a/Quick_Deploy/Customization/vLLM/.gitignore b/Quick_Deploy/Customization/vLLM/.gitignore
new file mode 100644
index 00000000..82559cc4
--- /dev/null
+++ b/Quick_Deploy/Customization/vLLM/.gitignore
@@ -0,0 +1,6 @@
+Miniconda*
+miniconda
+model_repository/vllm/vllm_env.tar.gz
+model_repository/vllm/triton_python_backend_stub
+python_backend
+results.txt
diff --git a/Quick_Deploy/Customization/vLLM/README.md b/Quick_Deploy/Customization/vLLM/README.md
new file mode 100644
index 00000000..739c4a19
--- /dev/null
+++ b/Quick_Deploy/Customization/vLLM/README.md
@@ -0,0 +1,252 @@
+
+
+
+# Deploying a vLLM model in Triton
+
+The following tutorial demonstrates how to deploy a simple
+[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model on
+Triton Inference Server using Triton's
+[Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends)
+[vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main)
+backend.
+
+*NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations).
+
+
+## Step 1: Prepare Triton vllm_backend
+
+[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) has been released
+as [xx.yy-vllm-python-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags) on Triton NGC,
+where `xx.yy` is the Triton release version, such as `23.10-vllm-python-py3`.
+
+You can simply pull the vllm_backend Docker image listed above.
+
+## Step 2: Prepare your model repository
+
+To use the Triton vllm_backend, we need to build a model repository. A sample model
+repository for deploying `facebook/opt-125m` using the Triton vllm_backend is
+included with this demo as the `model_repository` directory.
+
+The model repository should look like this:
+```
+model_repository/
+└── vllm_model
+    ├── 1
+    └── config.pbtxt
+```
+
+The vLLM engine arguments (EngineArgs) are configured via `parameters` entries in `config.pbtxt`:
+
+```
+parameters {
+  key: "model"
+  value: {
+    string_value: "facebook/opt-125m"
+  }
+}
+
+parameters {
+  key: "disable_log_requests"
+  value: {
+    string_value: "true"
+  }
+}
+
+parameters {
+  key: "gpu_memory_utilization"
+  value: {
+    string_value: "0.5"
+  }
+}
+```
+
+This file can be modified to provide further settings to the vLLM engine. See vLLM
+[AsyncEngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L165)
+and
+[EngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L11)
+for supported key-value pairs. In-flight batching and paged attention are handled
+by the vLLM engine.
+
+For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in
+[`config.pbtxt`](model_repository/vllm_model/config.pbtxt), as sketched below.
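+
+For illustration only (this block is not part of the sample
+[`config.pbtxt`](model_repository/vllm_model/config.pbtxt)), a deployment that
+shards the model across two visible GPUs could add a parameter block like the
+following; `model.py` casts the string value to an integer before passing it to
+the engine:
+
+```
+parameters {
+  key: "tensor_parallel_size"
+  value: {
+    string_value: "2"
+  }
+}
+```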
+
+*Note*: vLLM greedily consumes up to 90% of the GPU's memory under default settings.
+This tutorial changes that behavior by setting `gpu_memory_utilization` to 50%.
+You can tweak this behavior using fields like `gpu_memory_utilization` and other settings
+in [`config.pbtxt`](model_repository/vllm_model/config.pbtxt).
+
+Read through the documentation in [`model.py`](model.py) to understand how
+to configure this sample for your use case.
+
+## Step 3: Launch Triton Inference Server
+
+Once you have the model repository set up, it is time to launch the Triton server.
+Starting with the 23.10 release, a dedicated container with vLLM pre-installed
+is available on [NGC.](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver)
+To use this container to launch Triton, you can use the docker command below.
+```
+docker run -idt -p 8000:8000 -p 8001:8001 -p 8002:8002 --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v ${PWD}:/model_repository nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-python-py3 /bin/sh
+```
+Throughout the tutorial, `<xx.yy>` is the version of Triton
+that you want to use. Please note that Triton's vLLM
+container was first published in the 23.10 release, so any prior version
+will not work.
+
+Now, get the `CONTAINER ID` (for example, via `docker ps`) and enter the container like this:
+```
+docker exec -it CONTAINER_ID /bin/bash
+```
+
+Inside the container, you can find the mounted model repository:
+```
+model_repository/
+└── vllm_model
+    ├── 1
+    └── config.pbtxt
+```
+
+The container ships its own vllm_backend implementation at
+`/opt/tritonserver/backends/vllm/model.py`. Replace that file with the
+customized [`model.py`](model.py) from this demo, which is available inside the
+container at `/model_repository/model.py` because of the volume mount above.
+
+If you want to capture these changes as a new Docker image, you can commit the container like this:
+```
+docker commit CONTAINER_ID nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-new-python-py3
+```
+
+Then start Triton inside the container. Because the host directory was mounted
+at `/model_repository`, the sample model repository sits one level below the
+mount point:
+```
+/opt/tritonserver/bin/tritonserver --model-store=/model_repository/model_repository
+```
+
+After you start Triton you will see output on the console showing
+the server starting up and loading the model. When you see output
+like the following, Triton is ready to accept inference requests.
+
+```
+I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001
+I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000
+I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002
+```
+
+## Step 4: Use a Triton Client to Send Your First Inference Request
+
+In this tutorial, we will show how to send an inference request to the
+[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model in two ways:
+
+* [Using the generate endpoint](#using-the-generate-endpoint)
+* [Using the gRPC asyncio client](#using-the-grpc-asyncio-client)
+
+### Using the Generate Endpoint
+After you start Triton with the sample model_repository,
+you can quickly run your first inference request with the
+[generate](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md)
+endpoint.
+
+Start Triton's SDK container with the following command:
+```
+docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash
+```
+
+Now, let's send an inference request:
+```
+curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}'
+```
+
+Upon success, you should see a response from the server like this one:
+```
+{"model_name":"vllm_model","model_version":"1","text_output":"What is Triton Inference Server?\n\nTriton Inference Server is a server that is used by many"}
+```
+
+### Using the gRPC Asyncio Client
+Now, we will see how to run the client within Triton's SDK container
+to issue multiple async requests using the
+[gRPC asyncio client](https://github.com/triton-inference-server/client/blob/main/src/python/library/tritonclient/grpc/aio/__init__.py)
+library.
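+
+At its core, such a client builds the `text_input`, `stream`, and (optionally)
+`sampling_parameters` tensors declared in `config.pbtxt` and then consumes the
+decoupled response stream. The condensed sketch below illustrates the idea; it
+is an approximation of the sample `client.py` (not a drop-in replacement) and
+assumes the server's gRPC endpoint is reachable at `localhost:8001`:
+
+```
+import asyncio
+import json
+
+import numpy as np
+import tritonclient.grpc.aio as grpcclient
+
+
+def create_request(prompt, request_id, model_name="vllm_model"):
+    # Inputs must match the names and types declared in config.pbtxt.
+    inputs = [grpcclient.InferInput("text_input", [1], "BYTES")]
+    inputs[-1].set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))
+    inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
+    inputs[-1].set_data_from_numpy(np.array([False], dtype=bool))
+    # Sampling parameters are passed as a serialized JSON string.
+    inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
+    sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
+    inputs[-1].set_data_from_numpy(
+        np.array([json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_)
+    )
+    return {
+        "model_name": model_name,
+        "inputs": inputs,
+        "outputs": [grpcclient.InferRequestedOutput("text_output")],
+        "request_id": str(request_id),
+    }
+
+
+async def main():
+    async with grpcclient.InferenceServerClient(url="localhost:8001") as client:
+
+        async def request_iterator():
+            for i, prompt in enumerate(["What is Triton Inference Server?"]):
+                yield create_request(prompt, i)
+
+        # The model is decoupled, so responses arrive over a bidirectional stream.
+        async for result, error in client.stream_infer(inputs_iterator=request_iterator()):
+            if error:
+                raise error
+            print(result.as_numpy("text_output"))
+
+
+asyncio.run(main())
+```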
+
+This method requires a
+[client.py](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py)
+script and a set of
+[prompts](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt),
+which are provided in the
+[samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples)
+folder of the
+[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main)
+repository.
+
+Use the following command to download `client.py` and `prompts.txt` to your
+current directory:
+```
+wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client.py
+wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt
+```
+
+Now, we are ready to start Triton's SDK container:
+```
+docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash
+```
+
+Within the container, run
+[`client.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py)
+with:
+```
+python3 client.py
+```
+
+The client reads prompts from the
+[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt)
+file, sends them to the Triton server for
+inference, and stores the results in a file named `results.txt` by default.
+
+The output of the client should look like this:
+
+```
+Loading inputs from `prompts.txt`...
+Storing results into `results.txt`...
+PASS: vLLM example
+```
+
+You can inspect the contents of `results.txt` for the responses
+from the server. The `--iterations` flag can be used with the client
+to increase the load on the server by looping through the list of
+provided prompts in
+[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt).
+
+When you run the client in verbose mode with the `--verbose` flag,
+the client will print more details about the request/response transactions.
+
+## Limitations
+
+- We use the decoupled streaming protocol even if there is exactly one response for each request.
+- The asyncio implementation is exposed to model.py.
+- Providing a specific subset of GPUs to be used is not supported.
+- If you are running multiple instances of Triton server with
+a Python-based vLLM backend, you need to specify a different
+`shm-region-prefix-name` for each server. See
+[here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server)
+for more information.
diff --git a/Quick_Deploy/Customization/vLLM/model.py b/Quick_Deploy/Customization/vLLM/model.py
new file mode 100644
index 00000000..fd5dd727
--- /dev/null
+++ b/Quick_Deploy/Customization/vLLM/model.py
@@ -0,0 +1,281 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import asyncio
+import json
+import threading
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.utils import random_uuid
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.logger = pb_utils.Logger
+        self.model_config = json.loads(args["model_config"])
+
+        # Assert we are in decoupled mode. Currently, Triton needs to use the
+        # decoupled policy for asynchronously forwarding requests to the
+        # vLLM engine.
+        self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+            self.model_config
+        )
+        assert (
+            self.using_decoupled
+        ), "vLLM Triton backend must be configured to use decoupled model transaction policy"
+
+        self.model_name = args["model_name"]
+        assert (
+            self.model_name
+        ), "Parameter [name] must be configured and can not be empty in config.pbtxt"
+
+        # Create an AsyncLLMEngine from the engine arguments parsed out of the model config
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            AsyncEngineArgs(**self.handle_initializing_config())
+        )
+
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, "text_output"
+        )
+        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+
+        # Counter to keep track of ongoing request counts
+        self.ongoing_request_count = 0
+
+        # Starting the asyncio event loop to process the received requests asynchronously.
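+        # The loop runs on a dedicated thread (started below) so that execute()
+        # can return immediately while generation continues in the background
+        # on the vLLM engine.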
+        self._loop = asyncio.get_event_loop()
+        self._loop_thread = threading.Thread(
+            target=self.engine_loop, args=(self._loop,)
+        )
+        self._shutdown_event = asyncio.Event()
+        self._loop_thread.start()
+
+    def handle_initializing_config(self):
+        model_params = self.model_config.get("parameters", {})
+        model_engine_args = {}
+        for key, value in model_params.items():
+            model_engine_args[key] = value["string_value"]
+
+        # config.pbtxt parameters are always strings; cast them to the types
+        # expected by vLLM's AsyncEngineArgs.
+        bool_keys = [
+            "trust_remote_code",
+            "use_np_weights",
+            "use_dummy_weights",
+            "worker_use_ray",
+            "disable_log_stats",
+        ]
+        for k in bool_keys:
+            if k in model_engine_args:
+                # Note: bool("false") would be True, so compare the string instead.
+                model_engine_args[k] = model_engine_args[k].lower() == "true"
+
+        float_keys = ["gpu_memory_utilization"]
+        for k in float_keys:
+            if k in model_engine_args:
+                model_engine_args[k] = float(model_engine_args[k])
+
+        int_keys = [
+            "seed",
+            "pipeline_parallel_size",
+            "tensor_parallel_size",
+            "block_size",
+            "swap_space",
+            "max_num_batched_tokens",
+            "max_num_seqs",
+        ]
+        for k in int_keys:
+            if k in model_engine_args:
+                model_engine_args[k] = int(model_engine_args[k])
+
+        # Check the necessary parameter configuration in the model config
+        model_param = model_engine_args["model"]
+        assert (
+            model_param
+        ), "Parameter [model] must be configured and can not be empty in config.pbtxt"
+
+        self.logger.log_info(f"Initialize engine args: {model_engine_args}")
+        return model_engine_args
+
+    def create_task(self, coro):
+        """
+        Creates a task on the engine's event loop which is running on a separate thread.
+        """
+        assert (
+            self._shutdown_event.is_set() is False
+        ), "Cannot create tasks after shutdown has been requested"
+
+        return asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+    def engine_loop(self, loop):
+        """
+        Runs the engine's event loop on a separate thread.
+        """
+        asyncio.set_event_loop(loop)
+        self._loop.run_until_complete(self.await_shutdown())
+
+    async def await_shutdown(self):
+        """
+        Primary coroutine running on the engine event loop. This coroutine is responsible for
+        keeping the engine alive until a shutdown is requested.
+        """
+        # First, await the shutdown signal
+        while self._shutdown_event.is_set() is False:
+            await asyncio.sleep(5)
+
+        # Wait for the ongoing requests to complete
+        while self.ongoing_request_count > 0:
+            self.logger.log_info(
+                "[vllm] Awaiting remaining {} requests".format(
                    self.ongoing_request_count
+                )
+            )
+            await asyncio.sleep(5)
+
+        # Cancel any remaining tasks on the engine loop
+        for task in asyncio.all_tasks(loop=self._loop):
+            if task is not asyncio.current_task():
+                task.cancel()
+
+        self.logger.log_info("[vllm] Shutdown complete")
+
+    def get_sampling_params_dict(self, params_json):
+        """
+        This function parses the dictionary values into their
+        expected format.
+        """
+
+        params_dict = json.loads(params_json)
+
+        # Special parsing for the supported sampling parameters
+        bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"]
+        for k in bool_keys:
+            if k in params_dict:
+                params_dict[k] = bool(params_dict[k])
+
+        float_keys = [
+            "frequency_penalty",
+            "length_penalty",
+            "presence_penalty",
+            "temperature",
+            "top_p",
+        ]
+        for k in float_keys:
+            if k in params_dict:
+                params_dict[k] = float(params_dict[k])
+
+        int_keys = ["best_of", "max_tokens", "n", "top_k"]
+        for k in int_keys:
+            if k in params_dict:
+                params_dict[k] = int(params_dict[k])
+
+        return params_dict
+
+    def create_response(self, vllm_output):
+        """
+        Parses the output from the vLLM engine into a Triton
+        response.
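+        Note that the original prompt is prepended to each generated completion.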
+ """ + prompt = vllm_output.prompt + text_outputs = [ + (prompt + output.text).encode("utf-8") for output in vllm_output.outputs + ] + triton_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(text_outputs, dtype=self.output_dtype) + ) + return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + + async def generate(self, request): + """ + Forwards single request to LLM engine and returns responses. + """ + response_sender = request.get_response_sender() + self.ongoing_request_count += 1 + try: + request_id = random_uuid() + prompt = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # Request parameters are not yet supported via + # BLS. Provide an optional mechanism to receive serialized + # parameters as an input tensor until support is added + + parameters_input_tensor = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if parameters_input_tensor: + parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") + else: + parameters = request.parameters() + + sampling_params_dict = self.get_sampling_params_dict(parameters) + sampling_params = SamplingParams(**sampling_params_dict) + + last_output = None + async for output in self.llm_engine.generate( + prompt, sampling_params, request_id + ): + if stream: + response_sender.send(self.create_response(output)) + else: + last_output = output + + if not stream: + response_sender.send(self.create_response(last_output)) + + except Exception as e: + self.logger.log_info(f"[vllm] Error generating stream: {e}") + error = pb_utils.TritonError(f"Error generating stream: {e}") + triton_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(["N/A"], dtype=self.output_dtype) + ) + response = pb_utils.InferenceResponse( + output_tensors=[triton_output_tensor], error=error + ) + response_sender.send(response) + raise e + finally: + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.ongoing_request_count -= 1 + + def execute(self, requests): + """ + Triton core issues requests to the backend via this method. + + When this method returns, new requests can be issued to the backend. Blocking + this function would prevent the backend from pulling additional requests from + Triton into the vLLM engine. This can be done if the kv cache within vLLM engine + is too loaded. + We are pushing all the requests on vllm and let it handle the full traffic. + """ + for request in requests: + self.create_task(self.generate(request)) + return None + + def finalize(self): + """ + Triton virtual method; called when the model is unloaded. + """ + self.logger.log_info("[vllm] Issuing finalize to vllm backend") + self._shutdown_event.set() + if self._loop_thread is not None: + self._loop_thread.join() + self._loop_thread = None diff --git a/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt new file mode 100644 index 00000000..377142af --- /dev/null +++ b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt @@ -0,0 +1,98 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# The model name must match the model repository directory name ("vllm_model").
+name: "vllm_model"
+backend: "python"
+
+# Disable batching in Triton and let vLLM handle the batching on its own.
+max_batch_size: 0
+
+# We need to use the decoupled transaction policy for saturating the
+# vLLM engine for max throughput.
+# TODO [DLIS:5233]: Allow asynchronous execution to lift this
+# restriction for cases where there is exactly a single response to
+# a single request.
+model_transaction_policy {
+  decoupled: True
+}
+
+input [
+  {
+    name: "text_input"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  },
+  {
+    name: "stream"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: true
+  },
+  {
+    name: "sampling_parameters"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+    optional: true
+  }
+]
+
+output [
+  {
+    name: "text_output"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+
+# The choice of device is deferred to the vLLM engine
+instance_group [
+  {
+    count: 1
+    kind: KIND_MODEL
+  }
+]
+
+# The configuration of the vLLM engine arguments (EngineArgs)
+parameters {
+  key: "model"
+  value: {
+    string_value: "facebook/opt-125m"
+  }
+}
+
+parameters {
+  key: "disable_log_requests"
+  value: {
+    string_value: "true"
+  }
+}
+
+parameters {
+  key: "gpu_memory_utilization"
+  value: {
+    string_value: "0.5"
+  }
+}
diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md
index ee48f2af..54f39327 100644
--- a/Quick_Deploy/vLLM/README.md
+++ b/Quick_Deploy/vLLM/README.md
@@ -214,4 +214,4 @@ the client will print more details about the request/response transactions.
 a Python-based vLLM backend, you need to specify a different
 `shm-region-prefix-name` for each server. See
 [here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server)
-for more information.
+for more information.
\ No newline at end of file