
Commit 22dba92

feat: update v1.2 (#31)
1 parent d76f2c2 commit 22dba92

94 files changed (+310103, -977 lines)


CMakeLists.txt

Lines changed: 53 additions & 23 deletions
@@ -33,6 +33,8 @@ project(tritonfastertransformerbackend LANGUAGES C CXX)
 #
 option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
 option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
+option(BUILD_MULTI_GPU "Enable multi GPU support" ON)
+
 set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes")
 set(TRITON_PYTORCH_LIB_PATHS "" CACHE PATH "Paths to Torch libraries")
@@ -44,8 +46,6 @@ if(NOT CMAKE_BUILD_TYPE)
 set(CMAKE_BUILD_TYPE Release)
 endif()

-set(BUILD_MULTI_GPU "ON")
-message("-- Enable BUILD_MULTI_GPU")
 set(USE_TRITONSERVER_DATATYPE "ON")
 message("-- Enable USE_TRITONSERVER_DATATYPE")
@@ -56,10 +56,15 @@ find_package(Python3 REQUIRED COMPONENTS Development)

 find_package(FasterTransformer)
 find_package(CUDA 10.1 REQUIRED)
-find_package(MPI REQUIRED)
-find_package(NCCL REQUIRED)
-
-message(STATUS "Found MPI (include: ${MPI_INCLUDE_DIRS}, library: ${MPI_LIBRARIES})")
+if (BUILD_MULTI_GPU)
+message(STATUS "Enable BUILD_MULTI_GPU.")
+find_package(MPI REQUIRED)
+find_package(NCCL REQUIRED)
+message(STATUS "Found MPI (include: ${MPI_INCLUDE_DIRS}, library: ${MPI_LIBRARIES})")
+message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
+else()
+message(STATUS "Disable BUILD_MULTI_GPU.")
+endif()

 if (${CUDA_VERSION} GREATER_EQUAL 11.0)
 message(STATUS "Add DCUDA11_MODE")
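With this change, MPI and NCCL are only looked up when BUILD_MULTI_GPU is ON (its default). As a minimal sketch, assuming a standard out-of-source CMake build and leaving out any Triton- or Torch-specific flags your environment may need, a single-GPU build could be configured like this:

```bash
# Sketch only: configure without the multi-GPU dependencies.
# BUILD_MULTI_GPU comes from this commit; the other flags are generic CMake.
mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_MULTI_GPU=OFF ..
make -j"$(nproc)"
```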
@@ -97,12 +102,19 @@ FetchContent_Declare(
 GIT_TAG ${TRITON_BACKEND_REPO_TAG}
 GIT_SHALLOW ON
 )
-FetchContent_Declare(
-repo-ft
-GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
-GIT_TAG main
-GIT_SHALLOW ON
-)
+if (EXISTS ${FT_DIR})
+FetchContent_Declare(
+repo-ft
+SOURCE_DIR ${FT_DIR}
+)
+else()
+FetchContent_Declare(
+repo-ft
+GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
+GIT_TAG v5.1
+GIT_SHALLOW ON
+)
+endif()
 FetchContent_MakeAvailable(repo-common repo-core repo-backend repo-ft)

 #
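This pins the fetched FasterTransformer to the v5.1 tag and, when FT_DIR points at an existing directory, reuses that local source tree instead of cloning. A minimal sketch, assuming FT_DIR is supplied as an ordinary CMake cache variable and using a placeholder path:

```bash
# Sketch only: build against a local FasterTransformer checkout.
git clone https://github.com/NVIDIA/FasterTransformer.git /path/to/FasterTransformer
cmake -DFT_DIR=/path/to/FasterTransformer ..
```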
@@ -128,12 +140,10 @@ add_library(

 #find_package(CUDAToolkit REQUIRED)
 find_package(CUDA 10.1 REQUIRED)
-find_package(MPI REQUIRED)
-##find_package(NCCL REQUIRED)
-#if (${CUDA_VERSION} GREATER_EQUAL 11.0)
-message(STATUS "Add DCUDA11_MODE")
-add_definitions("-DCUDA11_MODE")
-#endif()
+if (${CUDA_VERSION} GREATER_EQUAL 11.0)
+message(STATUS "Add DCUDA11_MODE")
+add_definitions("-DCUDA11_MODE")
+endif()

 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
@@ -148,7 +158,6 @@ target_include_directories(
 ${CMAKE_CURRENT_SOURCE_DIR}/src
 ${TRITON_PYTORCH_INCLUDE_PATHS}
 ${Python3_INCLUDE_DIRS}
-${MPI_INCLUDE_PATH}
 ${repo-ft_SOURCE_DIR}
 ${repo-core_SOURCE_DIR}/include
 )
@@ -157,8 +166,6 @@ target_link_directories(
 triton-fastertransformer-backend
 PRIVATE
 ${CUDA_PATH}/lib64
-${MPI_Libraries}
-/usr/local/mpi/lib
 )

 target_compile_features(triton-fastertransformer-backend PRIVATE cxx_std_14)
@@ -210,14 +217,37 @@ target_link_libraries(
 triton-backend-utils # from repo-backend
 transformer-shared # from repo-ft
 ${TRITON_PYTORCH_LDFLAGS}
-${NCCL_LIBRARIES}
-${MPI_LIBRARIES}
 -lcublas
 -lcublasLt
 -lcudart
 -lcurand
 )

+if (BUILD_MULTI_GPU)
+target_compile_definitions(
+triton-fastertransformer-backend
+PUBLIC
+BUILD_MULTI_GPU
+)
+target_include_directories(
+triton-fastertransformer-backend
+PRIVATE
+${MPI_INCLUDE_PATH}
+)
+target_link_directories(
+triton-fastertransformer-backend
+PRIVATE
+${MPI_Libraries}
+/usr/local/mpi/lib
+)
+target_link_libraries(
+triton-fastertransformer-backend
+PRIVATE
+${NCCL_LIBRARIES}
+${MPI_LIBRARIES}
+)
+endif()
+
 if(${TRITON_ENABLE_GPU})
 target_link_libraries(
 triton-fastertransformer-backend
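The MPI/NCCL include paths, link directories, and libraries now sit behind the same BUILD_MULTI_GPU guard, and the target gains a BUILD_MULTI_GPU compile definition when the feature is enabled. One rough way to spot-check the two build flavours; the shared-object name and location are assumptions based on Triton's usual libtriton_<backend>.so naming, not something this commit states:

```bash
# Expect MPI/NCCL to show up only in BUILD_MULTI_GPU=ON builds.
ldd build/libtriton_fastertransformer.so | grep -Ei 'nccl|mpi' \
  || echo "no MPI/NCCL linkage (single-GPU build)"
```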

README.md

Lines changed: 70 additions & 17 deletions
@@ -28,24 +28,39 @@

 # FasterTransformer Backend

-The Triton backend for the [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder component, and it is tested and maintained by NVIDIA. In the FasterTransformer v4.0, it supports multi-gpu inference on GPT-3 model. This backend integrates FasterTransformer into Triton to use giant GPT-3 model serving by Triton. In the below example, we will show how to use the FasterTransformer backend in Triton to run inference on a GPT-3 model with 345M parameters trained by [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). In latest beta release, FasterTransformer backend supports the multi-node multi-GPU inference on T5 with the model of huggingface.
+The Triton backend for the [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder component, and it is tested and maintained by NVIDIA. In the FasterTransformer v4.0, it supports multi-gpu inference on GPT-3 model. This backend integrates FasterTransformer into Triton to use giant GPT-3 model serving by Triton. In the below example, we will show how to use the FasterTransformer backend in Triton to run inference on a GPT-3 model with 345M parameters trained by [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). In latest release, FasterTransformer backend supports the multi-node multi-GPU inference on T5 with the model of huggingface.

 Note that this is a research and prototyping tool, not a formal product or maintained framework. User can learn more about Triton backends in the [backend repo](https://github.com/triton-inference-server/backend). Ask questions or report problems on the [issues page](https://github.com/triton-inference-server/fastertransformer_backend/issues) in this FasterTransformer_backend repo.

 ## Table Of Contents

 - [FasterTransformer Backend](#fastertransformer-backend)
 - [Table Of Contents](#table-of-contents)
+- [Support matrix](#support-matrix)
 - [Introduction](#introduction)
 - [Setup](#setup)
 - [Prepare docker images](#prepare-docker-images)
 - [Rebuilding FasterTransformer backend (optional)](#rebuilding-fastertransformer-backend-optional)
 - [NCCL_LAUNCH_MODE](#nccl_launch_mode)
 - [GPUs Topology](#gpus-topology)
-- [MPI Launching with Tensor Parallel size and Pipeline Parallel Size Setting](#mpi-launching-with-tensor-parallel-size-and-pipeline-parallel-size-setting)
+- [Model-Parallism and Triton-Multiple-Model-Instances](#model-parallism-and-triton-multiple-model-instances)
+- [Run inter-node (T x P > GPUs per Node) models](#run-inter-node-t-x-p--gpus-per-node-models)
+- [Run intra-node (T x P <= GPUs per Node) models](#run-intra-node-t-x-p--gpus-per-node-models)
+- [Specify Multiple Model Instances](#specify-multiple-model-instances)
+- [Multi-Node Inference](#multi-node-inference)
 - [Request examples](#request-examples)
 - [Changelog](#changelog)

+## Support matrix
+
+| Models | FP16 | BF16 | Tensor parallel | Pipeline parallel |
+| -------- | ---- | ---- | --------------- | ----------------- |
+| GPT | Yes | Yes | Yes | Yes |
+| GPT-J | Yes | Yes | Yes | Yes |
+| T5 | Yes | Yes | Yes | Yes |
+| GPT-NeoX | Yes | Yes | Yes | Yes |
+| BERT | Yes | Yes | Yes | Yes |
+
 ## Introduction

 FasterTransformer backend hopes to integrate the FasterTransformer into Triton, leveraging the efficiency of FasterTransformer and serving capabilities of Triton. To run the GPT-3 model, we need to solve the following two issues: 1. How to run the auto-regressive model? 2. How to run the model with multi-gpu and multi-node?
@@ -84,10 +99,10 @@ For the issue of running the model with multi-gpu and multi-node, FasterTransfor
 ## Setup

 ```bash
+git clone https://github.com/triton-inference-server/fastertransformer_backend.git
+cd fastertransformer_backend
 export WORKSPACE=$(pwd)
-export SRC_MODELS_DIR=${WORKSPACE}/models
-export TRITON_MODELS_STORE=${WORKSPACE}/triton-model-store
-export CONTAINER_VERSION=22.03
+export CONTAINER_VERSION=22.07
 export TRITON_DOCKER_IMAGE=triton_with_ft:${CONTAINER_VERSION}
 ```
@@ -97,12 +112,6 @@ The current official Triton Inference Server docker image doesn't contain
 FasterTransformer backend, thus the users must prepare own docker image using below command:

 ```bash
-cd ${WORKSPACE}
-git clone https://github.com/triton-inference-server/fastertransformer_backend
-git clone https://github.com/triton-inference-server/server.git # We need some tools when we test this backend
-git clone https://github.com/NVIDIA/FasterTransformer.git # Used for convert the checkpoint and triton output
-ln -s server/qa/common .
-cd fastertransformer_backend
 docker build --rm \
 --build-arg TRITON_VERSION=${CONTAINER_VERSION} \
 -t ${TRITON_DOCKER_IMAGE} \
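For context, once built the image is normally run with the GPUs exposed to the container. This is only an illustrative launch, not part of the commit; the mounts, shared-memory size, and interactive shell are assumptions:

```bash
# Illustrative only: start the freshly built image with GPUs visible.
docker run -it --rm --gpus=all --shm-size=4g \
  -v ${WORKSPACE}:/workspace -w /workspace \
  ${TRITON_DOCKER_IMAGE} bash
```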
@@ -180,22 +189,50 @@ If your current machine/nodes are fully connected through PCIE or even across NU
 If you met timed-out or hangs, please first check the topology and try to use DGX V100 or DGX A100 with nvlink connected.


-## MPI Launching with Tensor Parallel size and Pipeline Parallel Size Setting
-
+## Model-Parallism and Triton-Multiple-Model-Instances
 We apply MPI to start single-node/multi-node servers.

 - N: Number of MPI Processes/Number of Nodes
-- T: Tensor Parallel Size. Default 8
+- T: Tensor Parallel Size. Default 1
 - P: Pipeline Parallel Size. Default 1

-`total number of gpus = num_gpus_per_node x N = T x P`
+Multiple model instances on the same GPUs will share the weights, so no redundant weight memory is allocated.
+
+### Run inter-node (T x P > GPUs per Node) models
+- `total number of GPUs = num_gpus_per_node x N = T x P`.
+- only a single model instance is supported
+
+### Run intra-node (T x P <= GPUs per Node) models
+- `total number of visible GPUs must be evenly divisible by T x P`. Note that you can control this by setting `CUDA_VISIBLE_DEVICES`.
+- `total number of visible GPUs must be <= T x P x Instance Count`. This avoids unnecessary CUDA memory allocation on unused GPUs.
+- multiple model instances can be run on the same GPU groups or different GPU groups.
+
+The backend will first try to assign different GPU groups to different model instances. If there are no empty GPUs, multiple model instances will be assigned to the same GPU groups.
+
+For example, if there are 8 GPUs and 8 model instances (T = 2, P = 1), then the model instances will be distributed to GPU groups [0, 1], [2, 3], [4, 5], [6, 7], [0, 1], [2, 3], [4, 5], [6, 7].
+- weights are shared among model instances in the same GPU groups. In the example above, instance 0 and instance 4 share the same weights, and the others are similar.

-**Note** that we currently do not support the case that different nodes have different number of GPUs.
+### Specify Multiple Model Instances
+
+Set `count` here to start multiple model instances. Note `KIND_CPU` is the only choice here as the backend needs to take full control of how to distribute multiple model instances to all the visible GPUs.
+
+```json
+instance_group [
+{
+count: 8
+kind: KIND_CPU
+}
+]
+```
+
+### Multi-Node Inference
+
+We currently do not support the case that different nodes have different numbers of GPUs.

 We start one MPI process per node. If you need to run on three nodes, then you should launch 3 Nodes with one process per node.
 Remember to change `tensor_para_size` and `pipeline_para_size` if you run on multiple nodes.

-We do suggest tensor_para_size = number of gpus in one node (e.g. 8 for DGX A100), and pipeline_para_size = number of nodes (2 for two nodes). Other model configuration in config.pbtxt should be modified as normal.
+We do suggest tensor_para_size = number of GPUs in one node (e.g. 8 for DGX A100), and pipeline_para_size = number of nodes (2 for two nodes). Other model configuration in config.pbtxt should be modified as normal.

 ## Request examples
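The intra-node rules above can be exercised by restricting the visible GPUs before starting the server. A minimal sketch for the T = 2, P = 1 case; the model-store path is a placeholder, and any extra tritonserver or mpirun flags your setup needs are omitted:

```bash
# Sketch: 4 visible GPUs, T x P = 2, so the GPU count divides evenly.
# With, say, count = 2 in instance_group, the backend can place one
# instance on GPUs [0, 1] and another on [2, 3].
CUDA_VISIBLE_DEVICES=0,1,2,3 \
mpirun -n 1 --allow-run-as-root \
  tritonserver --model-repository=/path/to/triton-model-store
```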

@@ -205,6 +242,22 @@ Specifically `tools/issue_request.py` is a simple script that sends a request co

 ## Changelog

+Aug 2022
+- Support for interactive generation
+
+July 2022
+- Support shared context optimization in GPT model
+- Support UL2
+
+June 2022
+- Support decoupled (streaming) mode.
+- Add demo of grpc protocol.
+- Support BERT
+
+May 2022
+- Support GPT-NeoX.
+- Support optional input. (triton version must be after 22.05)
+
 April 2022
 - Support bfloat16 inference in GPT model.
 - Support Nemo Megatron T5 and Megatron-LM T5 model.
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+[bert]
+model_name = bert
+position_embedding_type = absolute
+hidden_size = 768
+num_layer = 12
+head_num = 12
+size_per_head = 64
+activation_type = gelu
+inter_size = 3072
+max_position_embeddings = 512
+layer_norm_eps = 1e-12
+weight_data_type = fp32
+tensor_para_size = 1
+
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+[bert]
+model_name = bert
+position_embedding_type = absolute
+hidden_size = 768
+num_layer = 12
+head_num = 12
+size_per_head = 64
+activation_type = gelu
+inter_size = 3072
+max_position_embeddings = 512
+layer_norm_eps = 1e-12
+weight_data_type = fp32
+tensor_para_size = 2
+
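The two new INI files (their paths are not shown in this commit view) are identical except for the final key, so they read as the same BERT configuration prepared for 1-way versus 2-way tensor parallelism. A throwaway check, with placeholder paths:

```bash
# Placeholder paths; the commit view omits the real file locations.
diff /path/to/bert_tp1/config.ini /path/to/bert_tp2/config.ini
# Expected: the only difference is tensor_para_size = 1 vs tensor_para_size = 2.
```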
