diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/README.md b/Conceptual_Guide/Part_7-iterative_scheduling/README.md
new file mode 100644
index 00000000..8e0c9ea2
--- /dev/null
+++ b/Conceptual_Guide/Part_7-iterative_scheduling/README.md
@@ -0,0 +1,135 @@
+
+
+# Deploying a GPT-2 Model using Python Backend and Iterative Scheduling
+
+In this tutorial, we will deploy a GPT-2 model using the Python backend and
+demonstrate the
+[iterative scheduling](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#iterative-sequences)
+feature.
+
+## Prerequisites
+
+Before getting started with this tutorial, make sure you're familiar
+with the following concepts:
+
+* [Triton-Server Quick Start](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/getting_started/quickstart.html)
+* [Python Backend](https://github.com/triton-inference-server/python_backend)
+
+## Iterative Scheduling
+
+Iterative scheduling is a technique that allows the Triton Inference Server to
+schedule the same request multiple times with the same input. This is useful for
+models that have an auto-regressive loop. Iterative scheduling enables Triton
+Server to implement in-flight batching for your models and gives you the ability
+to combine new sequences with in-flight sequences as they arrive.
+
+## Tutorial Overview
+
+In this tutorial, we deploy two models:
+
+* simple-gpt2: This model receives a batch of requests and proceeds to the next
+batch only when it is done generating tokens for the current batch.
+
+* iterative-gpt2: This model uses iterative scheduling to process
+new sequences in a batch even when it is still generating tokens for the
+previous sequences.
+
+### Demo
+
+[![asciicast](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx.svg)](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx)
+
+### Step 1: Prepare the Server Environment
+
+* First, run the Triton Inference Server container:
+
+```
+# Replace yy.mm with the year and month of the release. Please use the 24.04 release or later.
+docker run --gpus=all --name iterative-scheduling -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v ${PWD}:/workspace/ -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:yy.mm-py3 bash
+```
+
+* Next, install all the dependencies required by the models running in the
+Python backend and log in with your [HuggingFace token](https://huggingface.co/settings/tokens)
+(an account on [HuggingFace](https://huggingface.co/) is required).
+
+```
+pip install transformers[torch]
+```
+
+> [!NOTE]
+> Optional: If you want to avoid installing the dependencies each time you run the
+> container, you can run `docker commit iterative-scheduling iterative-scheduling-image` to save the container
+> and use that image for subsequent runs.
+
+Then, start the server:
+
+```
+tritonserver --model-repository=/models
+```
+
+### Step 2: Install the client-side dependencies
+
+In another terminal, install the client dependencies:
+
+```
+pip3 install tritonclient[grpc]
+pip3 install tqdm
+```
+
+### Step 3: Run the client
+
+The simple-gpt2 model doesn't use iterative scheduling and will proceed to the
+next batch only when it is done generating tokens for the current batch.
+
+Run the following command to start the client:
+
+```
+python3 client/client.py --model simple-gpt2
+```
+
+As you can see, the tokens for one request are processed first before the client
+proceeds to the next request.
+
+Press `Ctrl+C` to stop the client.
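+
+For reference, [client/client.py](client/client.py) drives two prompts over two
+gRPC streams. The sketch below is a simplified, single-prompt version of the
+same streaming flow; the prompt text and `max_tokens` value are placeholders and
+not part of the shipped client:
+
+```
+import threading
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+done = threading.Event()
+
+
+def callback(result, error):
+    if error:
+        raise error
+    # Each streamed response carries one decoded token.
+    output = result.as_numpy("text_output")
+    if output is not None:
+        print(output.flatten()[0].decode(), end="", flush=True)
+    # The response flagged as final ends the stream for this request.
+    if result.get_response().parameters.get("triton_final_response").bool_param:
+        done.set()
+
+
+client = grpcclient.InferenceServerClient("localhost:8001")
+inputs = [grpcclient.InferInput("text_input", [1, 1], "BYTES")]
+inputs[0].set_data_from_numpy(np.array([["Machine learning is"]], dtype=np.object_))
+
+client.start_stream(callback=callback)
+client.async_stream_infer(
+    model_name="simple-gpt2",
+    inputs=inputs,
+    enable_empty_final_response=True,
+    parameters={"ignore_eos": True, "max_tokens": 32},
+)
+done.wait()
+client.stop_stream()
+```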
+
+### Step 4: Run the client with iterative scheduling
+
+The iterative scheduler is able to incorporate new requests as they arrive
+at the server.
+
+Run the following command to start the client:
+
+```
+python3 client/client.py --model iterative-gpt2
+```
+
+As you can see, the tokens for both prompts are generated simultaneously.
+
+## Next Steps
+
+We plan to integrate a KV-Cache with these models for better performance. Currently,
+the main goal of this tutorial is to demonstrate how to use iterative scheduling with
+the Python backend.
diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/client/client.py b/Conceptual_Guide/Part_7-iterative_scheduling/client/client.py
new file mode 100644
index 00000000..476d43ea
--- /dev/null
+++ b/Conceptual_Guide/Part_7-iterative_scheduling/client/client.py
@@ -0,0 +1,114 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
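+
+# Overview of this client: two independent gRPC streams are opened against the
+# selected model, one per prompt. Each stream's callback advances a tqdm
+# progress bar (see print_utils.Display) for every streamed token and sets a
+# threading.Event when the response marked "triton_final_response" arrives.
+# Once both prompts finish, they are re-sent in an endless loop. With
+# simple-gpt2 the two bars advance one after the other; with iterative-gpt2
+# they advance together.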
+ +import argparse +import threading +import time +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +from print_utils import Display + + +def client1_callback(display, event, result, error): + if error: + raise error + + display.update_top() + if result.get_response().parameters.get("triton_final_response").bool_param: + event.set() + + +def client2_callback(display, event, result, error): + if error: + raise error + + display.update_bottom() + if result.get_response().parameters.get("triton_final_response").bool_param: + event.set() + + +def run_inferences(url, model_name, display, max_tokens): + # Create clients + client1 = grpcclient.InferenceServerClient(url) + client2 = grpcclient.InferenceServerClient(url) + + inputs0 = [] + prompt1 = "Programming in C++ is like" + inputs0.append(grpcclient.InferInput("text_input", [1, 1], "BYTES")) + inputs0[0].set_data_from_numpy(np.array([[prompt1]], dtype=np.object_)) + + prompt2 = "Programming in Assembly is like" + inputs1 = [] + inputs1.append(grpcclient.InferInput("text_input", [1, 1], "BYTES")) + inputs1[0].set_data_from_numpy(np.array([[prompt2]], dtype=np.object_)) + + event1 = threading.Event() + event2 = threading.Event() + client1.start_stream(callback=partial(partial(client1_callback, display), event1)) + client2.start_stream(callback=partial(partial(client2_callback, display), event2)) + + while True: + # Reset the events + event1.clear() + event2.clear() + + # Setup the display initially with the prompts + display.clear() + parameters = {"ignore_eos": True, "max_tokens": max_tokens} + + client1.async_stream_infer( + model_name=model_name, + inputs=inputs0, + enable_empty_final_response=True, + parameters=parameters, + ) + + # Add a small delay so that the two requests are not sent at the same + # time + time.sleep(0.05) + client2.async_stream_infer( + model_name=model_name, + inputs=inputs1, + enable_empty_final_response=True, + parameters=parameters, + ) + + event1.wait() + event2.wait() + time.sleep(2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="localhost:8001") + parser.add_argument("--model", type=str, default="simple-gpt2") + parser.add_argument("--max-tokens", type=int, default=128) + args = parser.parse_args() + display = Display(args.max_tokens) + + run_inferences(args.url, args.model, display, args.max_tokens) diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/client/print_utils.py b/Conceptual_Guide/Part_7-iterative_scheduling/client/print_utils.py new file mode 100644 index 00000000..daf0aac1 --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/client/print_utils.py @@ -0,0 +1,46 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from tqdm import tqdm + + +class Display: + def __init__(self, max_tokens) -> None: + self._top = tqdm(position=0, total=max_tokens, miniters=1) + self._bottom = tqdm(position=1, total=max_tokens, miniters=1) + self._max_tokens = max_tokens + + def update_top(self): + self._top.update(1) + self._top.refresh() + + def update_bottom(self): + self._bottom.update(1) + self._bottom.refresh() + + def clear(self): + self._top.reset() + self._bottom.reset() diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/input_data.json b/Conceptual_Guide/Part_7-iterative_scheduling/input_data.json new file mode 100644 index 00000000..36cd6bbf --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/input_data.json @@ -0,0 +1,8 @@ +{ + "data": + [ + { + "input": ["machine learning is"] + } + ] + } \ No newline at end of file diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/1/model.py b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/1/model.py new file mode 100644 index 00000000..b86dc205 --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/1/model.py @@ -0,0 +1,216 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
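+
+# Overview of this model: each execute() call generates exactly one new token
+# per sequence in the batch (max_new_tokens=1). Per-sequence state is kept in
+# self.state, keyed by the correlation id supplied by the sequence batcher
+# through the "correlation_id" control input; the "start" control marks the
+# first execution of a sequence. After every token the request is either
+# rescheduled (TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) so it rejoins the next
+# batch, or released with the final response flag once EOS is generated or the
+# max_tokens budget is reached. Rescheduling is what allows newly arrived
+# sequences to be batched with sequences that are still generating.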
+import json + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from transformers import GPT2LMHeadModel, GPT2Tokenizer + + +class State: + def __init__(self): + self.prompt_tokens_len = 0 + self.tokens = [] + self.max_tokens = 0 + self.ignore_eos = False + + +class TritonPythonModel: + def initialize(self, args): + self.state = {} + device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu" + device_id = args["model_instance_device_id"] + self.device = f"{device}:{device_id}" + + # Load the GPT-2 model + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @staticmethod + def auto_complete_config(config): + inputs = [ + { + "name": "text_input", + "data_type": "TYPE_STRING", + "dims": [1], + } + ] + outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [1]}] + + for input in inputs: + config.add_input(input) + for output in outputs: + config.add_output(output) + + # Enable decoupled mode + transaction_policy = {"decoupled": True} + config.set_model_transaction_policy(transaction_policy) + config.set_max_batch_size(8) + + return config + + def create_batch(self, requests): + """ + Gathers input tensors from the requests and returns processed input tensors. + + Args: + requests (list): A list of requests containing input tensors. + + Returns: + input_ids (torch.Tensor): A tensor containing the processed input IDs. + attention_mask (torch.Tensor): A tensor containing the attention mask. + """ + + input_ids = [] + for request in requests: + input_tensor = str( + pb_utils.get_input_tensor_by_name(request, "text_input") + .as_numpy() + .item(), + encoding="utf-8", + ) + correlation_id = ( + pb_utils.get_input_tensor_by_name(request, "correlation_id") + .as_numpy() + .item() + ) + start = ( + pb_utils.get_input_tensor_by_name(request, "start").as_numpy().item() + ) + if start: + state = State() + state.tokens = self.tokenizer( + input_tensor, return_tensors="pt", padding=True + )["input_ids"][0].to(self.device) + state.prompt_tokens_len = len(state.tokens) + + # Store the parameters + parameters = json.loads(request.parameters()) + state.ignore_eos = parameters["ignore_eos"] + state.max_tokens = parameters["max_tokens"] + + self.state[correlation_id] = state + + input_ids.append(self.state[correlation_id].tokens) + + # Find the max sequence length + max_len = max([len(x) for x in input_ids]) + + # Pad the input tensors. + input_ids_torch = torch.cat( + [ + torch.cat( + [ + torch.tensor( + [self.tokenizer.eos_token_id] * (max_len - len(x)), + device=self.device, + ) + ] + + [x] + ).unsqueeze(0) + for x in input_ids + ] + ) + attention_mask = torch.cat( + [ + torch.cat( + [ + torch.tensor([0] * (max_len - x.numel())), + torch.tensor([1] * x.numel()), + ] + ).unsqueeze(0) + for x in input_ids + ] + ) + return input_ids_torch.long(), attention_mask.long().to(self.device) + + def send_responses(self, requests, outputs): + """ + Scatter method for processing requests and sending responses. + + Args: + requests (list): List of Triton InferenceRequest objects. + outputs (list): List of output tensors generated by the model. 
+ + Returns: + None + """ + for i, request in enumerate(requests): + correlation_id = ( + pb_utils.get_input_tensor_by_name(request, "correlation_id") + .as_numpy() + .item() + ) + response_sender = request.get_response_sender() + # Convert scalar to a one dimensional tensor + generated_token = outputs[i][-1].reshape(1) + + ignore_eos = self.state[correlation_id].ignore_eos + + # Maximum generated token length + max_tokens = ( + self.state[correlation_id].max_tokens + + self.state[correlation_id].prompt_tokens_len + ) + + self.state[correlation_id].tokens = torch.cat( + [self.state[correlation_id].tokens, generated_token] + ) + if ( + generated_token.item() == self.tokenizer.eos_token_id and not ignore_eos + ) or len(self.state[correlation_id].tokens) >= max_tokens: + flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_ALL) + del self.state[correlation_id] + else: + request.set_release_flags( + pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE + ) + flags = 0 + + output_decoded = self.tokenizer.decode(generated_token.cpu().item()) + response = pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "text_output", np.array([output_decoded], dtype=np.object_) + ) + ] + ) + response_sender.send(response, flags=flags) + + def execute(self, requests): + pb_utils.Logger.log_verbose(f"Processing {len(requests)} request(s).") + input_ids, attention_mask = self.create_batch(requests) + + outputs = self.model.generate( + input_ids, + max_new_tokens=1, + pad_token_id=self.tokenizer.eos_token_id, + attention_mask=attention_mask, + ) + self.send_responses(requests, outputs) diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/config.pbtxt b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/config.pbtxt new file mode 100644 index 00000000..ae10d343 --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/config.pbtxt @@ -0,0 +1,66 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" +sequence_batching { + iterative_sequence: true + control_input: [{ + name: "correlation_id" + control [ + { + kind: CONTROL_SEQUENCE_CORRID + data_type: TYPE_UINT64 + } + ] + }, + { + name: "start" + control [ + { + kind: CONTROL_SEQUENCE_START + fp32_false_true: [ 0, 1 ] + } + ] + }, + { + name: "end" + control [ + { + kind: CONTROL_SEQUENCE_END + fp32_false_true: [ 0, 1 ] + } + ]} + ] + oldest {} + max_sequence_idle_microseconds: 400000000 +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/1/model.py b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/1/model.py new file mode 100644 index 00000000..323aaa73 --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/1/model.py @@ -0,0 +1,199 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
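+
+# Overview of this model: execute() builds one State per request and then loops
+# until every request in the batch has hit EOS or its max_tokens budget,
+# generating one token per request per iteration. Because the loop runs to
+# completion inside a single execute() call, requests that arrive in the
+# meantime must wait for the whole batch to finish; this is the baseline that
+# the iterative-gpt2 model improves on.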
+import json + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from transformers import GPT2LMHeadModel, GPT2Tokenizer + + +class State: + def __init__(self): + self.prompt_tokens_len = 0 + self.tokens = [] + self.max_tokens = 0 + self.ignore_eos = False + + +class TritonPythonModel: + def initialize(self, args): + self.state = {} + device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu" + device_id = args["model_instance_device_id"] + self.device = f"{device}:{device_id}" + + # Load the GPT-2 model + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @staticmethod + def auto_complete_config(config): + inputs = [ + { + "name": "text_input", + "data_type": "TYPE_STRING", + "dims": [1], + } + ] + outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [1]}] + + for input in inputs: + config.add_input(input) + for output in outputs: + config.add_output(output) + + transaction_policy = {"decoupled": True} + config.set_dynamic_batching() + config.set_max_batch_size(8) + config.set_model_transaction_policy(transaction_policy) + + return config + + def init_state(self, states, requests): + """ + Initializes the state for each request. + + Args: + states (list): A list to store the state for each request. + requests (list): A list of requests. + + Returns: + None + """ + for request in requests: + input_tensor = str( + pb_utils.get_input_tensor_by_name(request, "text_input") + .as_numpy() + .item(), + encoding="utf-8", + ) + state = State() + + parameters = json.loads(request.parameters()) + state.ignore_eos = parameters["ignore_eos"] + state.max_tokens = parameters["max_tokens"] + state.tokens = self.tokenizer( + input_tensor, return_tensors="pt", padding=True + )["input_ids"][0].to(self.device) + state.prompt_tokens_len = len(state.tokens) + + states.append(state) + + def create_batch(self, states): + # Find the max sequence length + max_len = max([len(x.tokens) for x in states]) + + # Pad the input tensors. + input_ids_torch = torch.cat( + [ + torch.cat( + [ + torch.tensor( + [self.tokenizer.eos_token_id] * (max_len - len(x.tokens)), + device=self.device, + ) + ] + + [x.tokens] + ).unsqueeze(0) + for x in states + ] + ) + attention_mask = torch.cat( + [ + torch.cat( + [ + torch.tensor([0] * (max_len - x.tokens.numel())), + torch.tensor([1] * x.tokens.numel()), + ] + ).unsqueeze(0) + for x in states + ] + ) + return input_ids_torch.long(), attention_mask.long().to(self.device) + + def send_responses(self, states, requests, outputs): + """ + Sends responses to the requests based on the model outputs. + + Args: + states (list): A list of states for each request. + requests (list): A list of requests. + outputs (torch.Tensor): A tensor containing the model outputs. + + Returns: + list: A list of requests that have not been completed. + list: A list of states for each request that have not been completed. 
+ + """ + updated_request_list = [] + updated_states = [] + + for i, request in enumerate(requests): + response_sender = request.get_response_sender() + # Convert scalar to a one dimensional tensor + generated_token = outputs[i][-1].reshape(1) + + max_tokens = states[i].max_tokens + states[i].prompt_tokens_len + states[i].tokens = torch.cat([states[i].tokens, generated_token]) + + if ( + generated_token.item() == self.tokenizer.eos_token_id + and not states[i].ignore_eos + ) or len(states[i].tokens) >= max_tokens: + flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + else: + flags = 0 + updated_states.append(states[i]) + updated_request_list.append(request) + + output_decoded = self.tokenizer.decode(generated_token.cpu().item()) + response = pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "text_output", np.array([output_decoded], dtype=np.object_) + ) + ] + ) + response_sender.send(response, flags=flags) + return updated_request_list, updated_states + + def execute(self, requests): + pb_utils.Logger.log_verbose(f"Processing {len(requests)} request(s).") + + states = [] + self.init_state(states, requests) + while requests: + input_ids, attention_mask = self.create_batch(states) + + outputs = self.model.generate( + input_ids, + max_new_tokens=1, + pad_token_id=self.tokenizer.eos_token_id, + attention_mask=attention_mask, + ) + requests, states = self.send_responses(states, requests, outputs) diff --git a/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/config.pbtxt b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/config.pbtxt new file mode 100644 index 00000000..ced4a09a --- /dev/null +++ b/Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/config.pbtxt @@ -0,0 +1,34 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file