Add initial integration of iterative scheduling (#88) #89

Merged · 1 commit · Apr 24, 2024
135 changes: 135 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/README.md
@@ -0,0 +1,135 @@
<!--
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Deploying a GPT-2 Model using Python Backend and Iterative Scheduling

In this tutorial, we will deploy a GPT-2 model using the Python backend and
demonstrate the
[iterative scheduling](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#iterative-sequences)
feature.

## Prerequisites

Before getting started with this tutorial, make sure you're familiar
with the following concepts:

* [Triton-Server Quick Start](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/getting_started/quickstart.html)
* [Python Backend](https://github.com/triton-inference-server/python_backend)

## Iterative Scheduling

Iterative scheduling is a technique that allows the Triton Inference Server to
schedule the same request multiple times with the same input. This is useful for
models that have an auto-regressive loop. Iterative scheduling enables Triton
Server to implement inflight batching for your models and lets you combine new
sequences with inflight sequences as they arrive.
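
The feature is enabled with `iterative_sequence: true` under `sequence_batching`
in the model configuration, and the model itself decides on every scheduling pass
whether a request has finished or should be rescheduled. Below is a minimal sketch
of that pattern for a decoupled Python backend model; `generate_one_token` and
`request_is_finished` are hypothetical placeholders for the real generation and
stopping logic, and the complete implementations are in the `model.py` files
included in this change.

```
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        for request in requests:
            # Produce exactly one new token for this request on this pass.
            # (generate_one_token and request_is_finished are hypothetical
            # placeholders for the real generation and stopping logic.)
            token_text = generate_one_token(request)

            if request_is_finished(request):
                # Finished: mark the response as final and release the request.
                flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_ALL)
            else:
                # Not finished: ask Triton to re-enqueue this request so the next
                # pass can batch it together with newly arrived sequences.
                flags = 0
                request.set_release_flags(
                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
                )

            response = pb_utils.InferenceResponse(
                output_tensors=[
                    pb_utils.Tensor(
                        "text_output", np.array([token_text], dtype=np.object_)
                    )
                ]
            )
            request.get_response_sender().send(response, flags=flags)
```

Requests released with `TRITONSERVER_REQUEST_RELEASE_RESCHEDULE` re-enter the
batcher and can be grouped with sequences that arrived in the meantime.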

## Tutorial Overview

In this tutorial we deploy two models:

* simple-gpt2: This model receives a batch of requests and proceeds to the next
batch only when it is done generating tokens for the current batch.

* iterative-gpt2: This model uses iterative scheduling to process
new sequences in a batch even while it is still generating tokens for
previous sequences (see the condensed sketch after this list).
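
The difference shows up directly in each model's `execute()` method. The sketch
below is condensed from the two `model.py` files included in this change (state
handling, padding, and device placement are elided, and the class names are
illustrative; both real classes are named `TritonPythonModel`):

```
class SimpleGPT2Model:  # simple-gpt2
    def execute(self, requests):
        # Stays inside a single execute() call until every request in the
        # current batch is finished, so newly arrived requests must wait.
        states = []
        self.init_state(states, requests)
        while requests:
            input_ids, attention_mask = self.create_batch(states)
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1,
                pad_token_id=self.tokenizer.eos_token_id,
                attention_mask=attention_mask,
            )
            requests, states = self.send_responses(states, requests, outputs)


class IterativeGPT2Model:  # iterative-gpt2
    def execute(self, requests):
        # Generates a single token per scheduling pass; send_responses() marks
        # unfinished requests for rescheduling, so they re-enter the batcher
        # and can be grouped with new sequences on the next pass.
        input_ids, attention_mask = self.create_batch(requests)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=1,
            pad_token_id=self.tokenizer.eos_token_id,
            attention_mask=attention_mask,
        )
        self.send_responses(requests, outputs)
```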

### Demo

[![asciicast](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx.svg)](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx)

### Step 1: Prepare the Server Environment

* First, run the Triton Inference Server Container:

```
# Replace yy.mm with the year and month of the release. Please use the 24.04 release or later.
docker run --gpus=all --name iterative-scheduling -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v ${PWD}:/workspace/ -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:yy.mm-py3 bash
```

* Next, install all the dependencies required by the models running in the
Python backend and log in with your [huggingface token](https://huggingface.co/settings/tokens)
(a [HuggingFace](https://huggingface.co/) account is required).

```
pip install transformers[torch]
# Log in with your Hugging Face token (huggingface-cli is installed together
# with the transformers dependencies)
huggingface-cli login
```

> [!NOTE]
> Optional: If you want to avoid installing the dependencies each time you run the
> container, you can run `docker commit iterative-scheduling iterative-scheduling-image` to save the container
> and use that for subsequent runs.

Then, start the server:

```
tritonserver --model-repository=/models
```

### Step 2: Install the client-side dependencies

In another terminal, install the client dependencies:

```
pip3 install tritonclient[grpc]
pip3 install tqdm
```

### Step 3: Run the client

The simple-gpt2 model doesn't use iterative scheduling and will proceed to the
next batch only when it is done generating tokens for the current batch.

Run the following command to start the client:

```
python3 client/client.py --model simple-gpt2
```

As you can see, the tokens for one request are fully generated before the server
proceeds to the next request.

Press `Ctrl+C` to stop the client.


The iterative scheduler is able to incorporate new requests as they arrive at
the server.

Run the following command to start the client:

```
python3 client/client.py --model iterative-gpt2
```

As you can see, tokens for both prompts are generated simultaneously.

## Next Steps

We plan to integrate KV-Cache support with these models for better performance.
Currently, the main goal of this tutorial is to demonstrate how to use iterative
scheduling with the Python backend.
114 changes: 114 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/client/client.py
@@ -0,0 +1,114 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import threading
import time
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from print_utils import Display


def client1_callback(display, event, result, error):
if error:
raise error

display.update_top()
if result.get_response().parameters.get("triton_final_response").bool_param:
event.set()


def client2_callback(display, event, result, error):
if error:
raise error

display.update_bottom()
if result.get_response().parameters.get("triton_final_response").bool_param:
event.set()


def run_inferences(url, model_name, display, max_tokens):
# Create clients
client1 = grpcclient.InferenceServerClient(url)
client2 = grpcclient.InferenceServerClient(url)

inputs0 = []
prompt1 = "Programming in C++ is like"
inputs0.append(grpcclient.InferInput("text_input", [1, 1], "BYTES"))
inputs0[0].set_data_from_numpy(np.array([[prompt1]], dtype=np.object_))

prompt2 = "Programming in Assembly is like"
inputs1 = []
inputs1.append(grpcclient.InferInput("text_input", [1, 1], "BYTES"))
inputs1[0].set_data_from_numpy(np.array([[prompt2]], dtype=np.object_))

event1 = threading.Event()
event2 = threading.Event()
    client1.start_stream(callback=partial(client1_callback, display, event1))
    client2.start_stream(callback=partial(client2_callback, display, event2))

while True:
# Reset the events
event1.clear()
event2.clear()

# Setup the display initially with the prompts
display.clear()
parameters = {"ignore_eos": True, "max_tokens": max_tokens}

client1.async_stream_infer(
model_name=model_name,
inputs=inputs0,
enable_empty_final_response=True,
parameters=parameters,
)

# Add a small delay so that the two requests are not sent at the same
# time
time.sleep(0.05)
client2.async_stream_infer(
model_name=model_name,
inputs=inputs1,
enable_empty_final_response=True,
parameters=parameters,
)

event1.wait()
event2.wait()
time.sleep(2)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="localhost:8001")
parser.add_argument("--model", type=str, default="simple-gpt2")
parser.add_argument("--max-tokens", type=int, default=128)
args = parser.parse_args()
display = Display(args.max_tokens)

run_inferences(args.url, args.model, display, args.max_tokens)
46 changes: 46 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/client/print_utils.py
@@ -0,0 +1,46 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from tqdm import tqdm


class Display:
def __init__(self, max_tokens) -> None:
self._top = tqdm(position=0, total=max_tokens, miniters=1)
self._bottom = tqdm(position=1, total=max_tokens, miniters=1)
self._max_tokens = max_tokens

def update_top(self):
self._top.update(1)
self._top.refresh()

def update_bottom(self):
self._bottom.update(1)
self._bottom.refresh()

def clear(self):
self._top.reset()
self._bottom.reset()
8 changes: 8 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/input_data.json
@@ -0,0 +1,8 @@
{
"data":
[
{
"input": ["machine learning is"]
}
]
}
216 changes: 216 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/1/model.py
@@ -0,0 +1,216 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class State:
def __init__(self):
self.prompt_tokens_len = 0
self.tokens = []
self.max_tokens = 0
self.ignore_eos = False


class TritonPythonModel:
def initialize(self, args):
self.state = {}
device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu"
device_id = args["model_instance_device_id"]
self.device = f"{device}:{device_id}"

# Load the GPT-2 model
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device)
self.tokenizer.pad_token = self.tokenizer.eos_token

@staticmethod
def auto_complete_config(config):
inputs = [
{
"name": "text_input",
"data_type": "TYPE_STRING",
"dims": [1],
}
]
outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [1]}]

for input in inputs:
config.add_input(input)
for output in outputs:
config.add_output(output)

# Enable decoupled mode
transaction_policy = {"decoupled": True}
config.set_model_transaction_policy(transaction_policy)
config.set_max_batch_size(8)

return config

def create_batch(self, requests):
"""
Gathers input tensors from the requests and returns processed input tensors.
Args:
requests (list): A list of requests containing input tensors.
Returns:
input_ids (torch.Tensor): A tensor containing the processed input IDs.
attention_mask (torch.Tensor): A tensor containing the attention mask.
"""

input_ids = []
for request in requests:
input_tensor = str(
pb_utils.get_input_tensor_by_name(request, "text_input")
.as_numpy()
.item(),
encoding="utf-8",
)
correlation_id = (
pb_utils.get_input_tensor_by_name(request, "correlation_id")
.as_numpy()
.item()
)
start = (
pb_utils.get_input_tensor_by_name(request, "start").as_numpy().item()
)
if start:
state = State()
state.tokens = self.tokenizer(
input_tensor, return_tensors="pt", padding=True
)["input_ids"][0].to(self.device)
state.prompt_tokens_len = len(state.tokens)

# Store the parameters
parameters = json.loads(request.parameters())
state.ignore_eos = parameters["ignore_eos"]
state.max_tokens = parameters["max_tokens"]

self.state[correlation_id] = state

input_ids.append(self.state[correlation_id].tokens)

# Find the max sequence length
max_len = max([len(x) for x in input_ids])

# Pad the input tensors.
input_ids_torch = torch.cat(
[
torch.cat(
[
torch.tensor(
[self.tokenizer.eos_token_id] * (max_len - len(x)),
device=self.device,
)
]
+ [x]
).unsqueeze(0)
for x in input_ids
]
)
attention_mask = torch.cat(
[
torch.cat(
[
torch.tensor([0] * (max_len - x.numel())),
torch.tensor([1] * x.numel()),
]
).unsqueeze(0)
for x in input_ids
]
)
return input_ids_torch.long(), attention_mask.long().to(self.device)

def send_responses(self, requests, outputs):
"""
Scatter method for processing requests and sending responses.
Args:
requests (list): List of Triton InferenceRequest objects.
outputs (list): List of output tensors generated by the model.
Returns:
None
"""
for i, request in enumerate(requests):
correlation_id = (
pb_utils.get_input_tensor_by_name(request, "correlation_id")
.as_numpy()
.item()
)
response_sender = request.get_response_sender()
# Convert scalar to a one dimensional tensor
generated_token = outputs[i][-1].reshape(1)

ignore_eos = self.state[correlation_id].ignore_eos

# Maximum generated token length
max_tokens = (
self.state[correlation_id].max_tokens
+ self.state[correlation_id].prompt_tokens_len
)

self.state[correlation_id].tokens = torch.cat(
[self.state[correlation_id].tokens, generated_token]
)
if (
generated_token.item() == self.tokenizer.eos_token_id and not ignore_eos
) or len(self.state[correlation_id].tokens) >= max_tokens:
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_ALL)
del self.state[correlation_id]
else:
request.set_release_flags(
pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
)
flags = 0

output_decoded = self.tokenizer.decode(generated_token.cpu().item())
response = pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"text_output", np.array([output_decoded], dtype=np.object_)
)
]
)
response_sender.send(response, flags=flags)

def execute(self, requests):
pb_utils.Logger.log_verbose(f"Processing {len(requests)} request(s).")
input_ids, attention_mask = self.create_batch(requests)

outputs = self.model.generate(
input_ids,
max_new_tokens=1,
pad_token_id=self.tokenizer.eos_token_id,
attention_mask=attention_mask,
)
self.send_responses(requests, outputs)
66 changes: 66 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/config.pbtxt
@@ -0,0 +1,66 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

backend: "python"
sequence_batching {
  iterative_sequence: true
  control_input: [{
    name: "correlation_id"
    control [
      {
        kind: CONTROL_SEQUENCE_CORRID
        data_type: TYPE_UINT64
      }
    ]
  },
  {
    name: "start"
    control [
      {
        kind: CONTROL_SEQUENCE_START
        fp32_false_true: [ 0, 1 ]
      }
    ]
  },
  {
    name: "end"
    control [
      {
        kind: CONTROL_SEQUENCE_END
        fp32_false_true: [ 0, 1 ]
      }
    ]
  }]
  oldest {}
  max_sequence_idle_microseconds: 400000000
}
instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
199 changes: 199 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/1/model.py
@@ -0,0 +1,199 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class State:
def __init__(self):
self.prompt_tokens_len = 0
self.tokens = []
self.max_tokens = 0
self.ignore_eos = False


class TritonPythonModel:
def initialize(self, args):
self.state = {}
device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu"
device_id = args["model_instance_device_id"]
self.device = f"{device}:{device_id}"

# Load the GPT-2 model
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device)
self.tokenizer.pad_token = self.tokenizer.eos_token

@staticmethod
def auto_complete_config(config):
inputs = [
{
"name": "text_input",
"data_type": "TYPE_STRING",
"dims": [1],
}
]
outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [1]}]

for input in inputs:
config.add_input(input)
for output in outputs:
config.add_output(output)

transaction_policy = {"decoupled": True}
config.set_dynamic_batching()
config.set_max_batch_size(8)
config.set_model_transaction_policy(transaction_policy)

return config

def init_state(self, states, requests):
"""
Initializes the state for each request.
Args:
states (list): A list to store the state for each request.
requests (list): A list of requests.
Returns:
None
"""
for request in requests:
input_tensor = str(
pb_utils.get_input_tensor_by_name(request, "text_input")
.as_numpy()
.item(),
encoding="utf-8",
)
state = State()

parameters = json.loads(request.parameters())
state.ignore_eos = parameters["ignore_eos"]
state.max_tokens = parameters["max_tokens"]
state.tokens = self.tokenizer(
input_tensor, return_tensors="pt", padding=True
)["input_ids"][0].to(self.device)
state.prompt_tokens_len = len(state.tokens)

states.append(state)

def create_batch(self, states):
# Find the max sequence length
max_len = max([len(x.tokens) for x in states])

# Pad the input tensors.
input_ids_torch = torch.cat(
[
torch.cat(
[
torch.tensor(
[self.tokenizer.eos_token_id] * (max_len - len(x.tokens)),
device=self.device,
)
]
+ [x.tokens]
).unsqueeze(0)
for x in states
]
)
attention_mask = torch.cat(
[
torch.cat(
[
torch.tensor([0] * (max_len - x.tokens.numel())),
torch.tensor([1] * x.tokens.numel()),
]
).unsqueeze(0)
for x in states
]
)
return input_ids_torch.long(), attention_mask.long().to(self.device)

def send_responses(self, states, requests, outputs):
"""
Sends responses to the requests based on the model outputs.
Args:
states (list): A list of states for each request.
requests (list): A list of requests.
outputs (torch.Tensor): A tensor containing the model outputs.
Returns:
list: A list of requests that have not been completed.
list: A list of states for each request that have not been completed.
"""
updated_request_list = []
updated_states = []

for i, request in enumerate(requests):
response_sender = request.get_response_sender()
# Convert scalar to a one dimensional tensor
generated_token = outputs[i][-1].reshape(1)

max_tokens = states[i].max_tokens + states[i].prompt_tokens_len
states[i].tokens = torch.cat([states[i].tokens, generated_token])

if (
generated_token.item() == self.tokenizer.eos_token_id
and not states[i].ignore_eos
) or len(states[i].tokens) >= max_tokens:
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
else:
flags = 0
updated_states.append(states[i])
updated_request_list.append(request)

output_decoded = self.tokenizer.decode(generated_token.cpu().item())
response = pb_utils.InferenceResponse(
output_tensors=[
pb_utils.Tensor(
"text_output", np.array([output_decoded], dtype=np.object_)
)
]
)
response_sender.send(response, flags=flags)
return updated_request_list, updated_states

def execute(self, requests):
pb_utils.Logger.log_verbose(f"Processing {len(requests)} request(s).")

states = []
self.init_state(states, requests)
while requests:
input_ids, attention_mask = self.create_batch(states)

outputs = self.model.generate(
input_ids,
max_new_tokens=1,
pad_token_id=self.tokenizer.eos_token_id,
attention_mask=attention_mask,
)
requests, states = self.send_responses(states, requests, outputs)
34 changes: 34 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/config.pbtxt
@@ -0,0 +1,34 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

backend: "python"

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]