Commit c156dad
Authored Apr 24, 2024
Add initial integration of iterative scheduling (#88) (#89)
1 parent cb2ca25 commit c156dad

8 files changed: +818 −0 lines changed

 
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
<!--
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Deploying a GPT-2 Model using Python Backend and Iterative Scheduling

In this tutorial, we will deploy a GPT-2 model using the Python backend and
demonstrate the
[iterative scheduling](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#iterative-sequences)
feature.

## Prerequisites

Before getting started with this tutorial, make sure you're familiar
with the following concepts:

* [Triton-Server Quick Start](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/getting_started/quickstart.html)
* [Python Backend](https://github.com/triton-inference-server/python_backend)

## Iterative Scheduling

Iterative scheduling is a technique that allows the Triton Inference Server to
schedule the same request multiple times with the same input. This is useful for
models that have an auto-regressive loop. Iterative scheduling enables Triton
Server to implement in-flight batching for your models and gives you the ability
to combine new sequences with in-flight sequences as they arrive.
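
A model opts into this behavior through its configuration. Below is a rough
sketch of the relevant `config.pbtxt` fields, based on the iterative sequences
documentation linked above rather than the exact configuration shipped with
this tutorial; decoupled mode is shown because these models stream tokens back
response-by-response:

```
model_transaction_policy {
  decoupled: true
}
sequence_batching {
  iterative_sequence : true
}
```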

## Tutorial Overview

In this tutorial we deploy two models:

* simple-gpt2: This model receives a batch of requests and proceeds to the next
batch only when it is done generating tokens for the current batch.

* iterative-gpt2: This model uses iterative scheduling to process
new sequences in a batch even while it is still generating tokens for the
previous sequences; a sketch of the model-side rescheduling loop follows this
list.
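
To make the iterative behavior concrete, here is a minimal sketch of the
model-side loop, assuming the Python backend's request rescheduling API. It is
not the tutorial's actual model code, and `generate_next_token` is a
hypothetical helper standing in for one GPT-2 decoding step:

```
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        for request in requests:
            # Hypothetical helper: runs one decoding step for this sequence.
            token, finished = generate_next_token(request)

            # Stream the new token back to the client.
            sender = request.get_response_sender()
            out = pb_utils.Tensor("text_output", np.array([token], dtype=np.object_))
            sender.send(pb_utils.InferenceResponse(output_tensors=[out]))

            if finished:
                # Close the stream; the request is released normally.
                sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            else:
                # Ask Triton to hand this same request back in a later batch,
                # where it can be grouped with newly arrived sequences.
                request.set_release_flags(
                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
                )
        # Decoupled models send responses through the sender, not a return value.
        return None
```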

### Demo

[![asciicast](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx.svg)](https://asciinema.org/a/TUZtHwZsYrJzHuZF7XCOj1Avx)

### Step 1: Prepare the Server Environment

* First, run the Triton Inference Server container:

```
# Replace yy.mm with the year and month of the release. Please use the 24.04 release or later.
docker run --gpus=all --name iterative-scheduling -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v ${PWD}:/workspace/ -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:yy.mm-py3 bash
```

* Next, install all the dependencies required by the models running in the
Python backend and log in with your [huggingface token](https://huggingface.co/settings/tokens)
(an account on [HuggingFace](https://huggingface.co/) is required).

```
pip install transformers[torch]
```
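
One way to log in is with the Hugging Face CLI (assuming the `huggingface_hub`
package pulled in by `transformers`; any method that authenticates the
`transformers` library works):

```
huggingface-cli login
```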

> [!NOTE]
> Optional: If you want to avoid installing the dependencies each time you run the
> container, you can run `docker commit iterative-scheduling iterative-scheduling-image` to save the container
> and use that image for subsequent runs.

Then, start the server:

```
tritonserver --model-repository=/models
```
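
To verify the server started successfully, you can poll Triton's standard HTTP
health endpoint from the host (assuming the default port mapping above):

```
curl -v localhost:8000/v2/health/ready
```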

### Step 2: Install the client-side dependencies

In another terminal, install the client dependencies:

```
pip3 install tritonclient[grpc]
pip3 install tqdm
```

### Step 3: Run the client

The simple-gpt2 model doesn't use iterative scheduling and will proceed to the
next batch only when it is done generating tokens for the current batch.

Run the following command to start the client:

```
python3 client/client.py --model simple-gpt2
```

As you can see, the tokens for one request are processed fully before the client
proceeds to the next request.

Press `Ctrl+C` to stop the client.

The iterative scheduler is able to incorporate new requests as they arrive
at the server.

Run the following command to start the client:

```
python3 client/client.py --model iterative-gpt2
```

As you can see, the tokens for both prompts are generated simultaneously.

## Next Steps

We plan to integrate KV-Cache with these models for better performance. Currently,
the main goal of this tutorial is to demonstrate how to use iterative scheduling with
the Python backend.
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import threading
import time
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from print_utils import Display


def client1_callback(display, event, result, error):
    if error:
        raise error

    display.update_top()
    # Signal completion when the server marks the final response of the stream.
    if result.get_response().parameters.get("triton_final_response").bool_param:
        event.set()


def client2_callback(display, event, result, error):
    if error:
        raise error

    display.update_bottom()
    if result.get_response().parameters.get("triton_final_response").bool_param:
        event.set()


def run_inferences(url, model_name, display, max_tokens):
    # Create clients
    client1 = grpcclient.InferenceServerClient(url)
    client2 = grpcclient.InferenceServerClient(url)

    inputs0 = []
    prompt1 = "Programming in C++ is like"
    inputs0.append(grpcclient.InferInput("text_input", [1, 1], "BYTES"))
    inputs0[0].set_data_from_numpy(np.array([[prompt1]], dtype=np.object_))

    prompt2 = "Programming in Assembly is like"
    inputs1 = []
    inputs1.append(grpcclient.InferInput("text_input", [1, 1], "BYTES"))
    inputs1[0].set_data_from_numpy(np.array([[prompt2]], dtype=np.object_))

    event1 = threading.Event()
    event2 = threading.Event()
    # Bind the display and the per-stream completion event into each callback.
    client1.start_stream(callback=partial(client1_callback, display, event1))
    client2.start_stream(callback=partial(client2_callback, display, event2))

    while True:
        # Reset the events
        event1.clear()
        event2.clear()

        # Setup the display initially with the prompts
        display.clear()
        parameters = {"ignore_eos": True, "max_tokens": max_tokens}

        client1.async_stream_infer(
            model_name=model_name,
            inputs=inputs0,
            enable_empty_final_response=True,
            parameters=parameters,
        )

        # Add a small delay so that the two requests are not sent at the same
        # time
        time.sleep(0.05)
        client2.async_stream_infer(
            model_name=model_name,
            inputs=inputs1,
            enable_empty_final_response=True,
            parameters=parameters,
        )

        event1.wait()
        event2.wait()
        time.sleep(2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="localhost:8001")
    parser.add_argument("--model", type=str, default="simple-gpt2")
    parser.add_argument("--max-tokens", type=int, default=128)
    args = parser.parse_args()
    display = Display(args.max_tokens)

    run_inferences(args.url, args.model, display, args.max_tokens)
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from tqdm import tqdm


class Display:
    """Render two token-count progress bars, one per concurrent prompt."""

    def __init__(self, max_tokens) -> None:
        self._top = tqdm(position=0, total=max_tokens, miniters=1)
        self._bottom = tqdm(position=1, total=max_tokens, miniters=1)
        self._max_tokens = max_tokens

    def update_top(self):
        self._top.update(1)
        self._top.refresh()

    def update_bottom(self):
        self._bottom.update(1)
        self._bottom.refresh()

    def clear(self):
        self._top.reset()
        self._bottom.reset()
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
{
    "data":
    [
        {
            "input": ["machine learning is"]
        }
    ]
}
