Commit 5307dc5

Update GQA benchmark to support bfloat16 (#26898)
Update GQA benchmark to support bfloat16 and default to testing the first configuration (fast mode).

Note that test_sparse_attention.py was removed in #23547. It is still referenced by the benchmark script, so I added it back and disabled the test in pipeline mode.

Example output from an H200 GPU:

```
prompt-sm90-Llama3-8B-b1-h32_8x128-float16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       0.781751                 0.571226
1             32.0       0.893813                 0.684198
2             64.0       1.434056                 1.589263
3            128.0       1.142192                 1.681969
4            256.0       1.503483                 2.225498
5            512.0       1.045732                 1.878660
6           1024.0       2.334924                 0.916745
7           2048.0       2.229924                 3.001290
8           4096.0       4.309678                 3.198855
9           8192.0       7.932211                 7.910411

token-sm90-Llama3-8B-b1-h32_8_d128-float16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       1.751966                 0.780081
1                  32.0       1.302806                 0.043939
2                  64.0       2.301024                 2.207282
3                 128.0       2.294556                 3.010107
4                 256.0       2.931330                 1.781768
5                 512.0       1.210220                 2.799579
6                1024.0       2.767142                 2.660434
7                2048.0       1.420229                 0.091433
8                4096.0       0.860655                 0.801022
9                8191.0       0.749525                 0.820858

prompt-sm90-Llama3-8B-b1-h32_8x128-bfloat16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       1.085427                 0.666664
1             32.0       1.714795                 0.931262
2             64.0       1.729093                 1.438733
3            128.0       1.071263                 2.486135
4            256.0       1.957349                 1.342417
5            512.0       1.159680                 1.591321
6           1024.0       0.743702                 2.035150
7           2048.0       1.452736                 1.788801
8           4096.0       4.029917                 4.041565
9           8192.0       7.934485                 7.931600

token-sm90-Llama3-8B-b1-h32_8_d128-bfloat16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       0.044354                 0.043983
1                  32.0       0.040715                 0.044061
2                  64.0       0.045586                 0.044071
3                 128.0       0.062204                 0.061418
4                 256.0       0.074764                 4.874854
5                 512.0       2.472094                 2.102259
6                1024.0       4.911269                 1.396149
7                2048.0       4.898032                 1.684034
8                4096.0       2.523432                 2.192279
9                8191.0       1.651366                 3.427370
```
1 parent db3eb22 commit 5307dc5
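
For context, here is a minimal sketch of what fast mode and the new dtype handling amount to, paraphrased from the diff below. The helper names `select_torch_dtype` and `iterate_benchmark_cases` are illustrative only and do not exist in the script; the dtype-to-torch mapping and the fast-mode pruning mirror the changed lines in benchmark_gqa.py.

```python
import torch


def select_torch_dtype(dtype: str) -> torch.dtype:
    # The benchmark passes dtype around as a string ("float16" or "bfloat16")
    # and maps it to a torch dtype when building the attention config.
    return torch.float16 if dtype == "float16" else torch.bfloat16


def iterate_benchmark_cases(configures: list, fast: bool = False):
    # Fast mode keeps only the first model configuration and shrinks the sweep;
    # both float16 and bfloat16 are always exercised.
    if fast:
        configures = configures[:1]
    batch_sizes = [1] if fast else [1, 4]
    smooth_softmax_options = [False] if fast else [False, True]
    dtypes = ["float16", "bfloat16"]
    for configure in configures:
        for batch_size in batch_sizes:
            for use_smooth_softmax in smooth_softmax_options:
                for dtype in dtypes:
                    yield configure, batch_size, use_smooth_softmax, select_torch_dtype(dtype)
```

With `fast=True` (the new default in `__main__`), only the first configuration (Llama3-8B) is benchmarked at batch size 1 without smooth softmax, once per dtype.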

File tree

2 files changed: +1229 −39 lines


onnxruntime/test/python/transformers/benchmark_gqa.py

Lines changed: 58 additions & 39 deletions
```diff
@@ -14,17 +14,19 @@
 def get_plot_algos(sm: int, local_window_size: int | None):
     # GQA with local windows only works in sm=8x
     if sm >= 80 and local_window_size:
-        return {
-            "line_vals": ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"],
-            "line_names": ["ORT-GQA-Dense", "ORT-GQA-Local", "ORT-GQA-Dense-PackedQKV", "ORT-GQA-Local-PackedQKV"],
-            "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")],
-        }
+        line_vals = ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"]
+        line_names = ["ORT-GQA-Dense", "ORT-GQA-Local", "ORT-GQA-Dense-PackedQKV", "ORT-GQA-Local-PackedQKV"]
+        styles = [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")]
     else:
-        return {
-            "line_vals": ["ort_gqa", "ort_gqa_packed"],
-            "line_names": ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"],
-            "styles": [("red", "solid"), ("blue", "dashed")],
-        }
+        line_vals = ["ort_gqa", "ort_gqa_packed"]
+        line_names = ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"]
+        styles = [("red", "solid"), ("blue", "dashed")]
+
+    return {
+        "line_vals": line_vals,
+        "line_names": line_names,
+        "styles": styles,
+    }
 
 
 def plot_prompt_performance(
@@ -37,6 +39,7 @@ def plot_prompt_performance(
     max_seq_len: int,
     local_window_size: int | None = None,
     use_smooth_softmax: bool = False,
+    dtype: str = "float16",
 ):
     import triton  # noqa: PLC0415
 
@@ -48,14 +51,15 @@
             line_arg="provider",
             ylabel="ms",
             **algos,
-            plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}x{head_size}-fp16",
+            plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}x{head_size}-{dtype}",
             args={
                 "batch_size": batch_size,
                 "num_heads": num_heads,
                 "kv_num_heads": kv_num_heads,
                 "head_size": head_size,
                 "local_window_size": local_window_size,
                 "use_smooth_softmax": use_smooth_softmax,
+                "dtype": dtype,
             },
         )
     ]
@@ -70,6 +74,7 @@ def benchmark(
         head_size: int,
         local_window_size: int | None = None,
         use_smooth_softmax: bool = False,
+        dtype: str = "float16",
         device="cuda",
     ):
         warmup = 15
@@ -86,6 +91,7 @@
             local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
             use_smooth_softmax=use_smooth_softmax,
             device=device,
+            dtype=torch.float16 if dtype == "float16" else torch.bfloat16,
             is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
         )
 
@@ -107,6 +113,7 @@ def plot_token_performance(
     max_seq_len: int,
     local_window_size: int | None = None,
     use_smooth_softmax: bool = False,
+    dtype: str = "float16",
 ):
     import triton  # noqa: PLC0415
 
@@ -118,14 +125,15 @@
             line_arg="provider",
             ylabel="ms",
             **algos,
-            plot_name=f"token-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}_d{head_size}-fp16",
+            plot_name=f"token-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}_d{head_size}-{dtype}",
             args={
                 "batch_size": batch_size,
                 "num_heads": num_heads,
                 "kv_num_heads": kv_num_heads,
                 "head_size": head_size,
                 "local_window_size": local_window_size,
                 "use_smooth_softmax": use_smooth_softmax,
+                "dtype": dtype,
             },
         )
     ]
@@ -140,6 +148,7 @@ def benchmark(
         head_size: int,
        local_window_size: int | None = None,
         use_smooth_softmax: bool = False,
+        dtype: str = "float16",
         device="cuda",
     ):
         warmup = 15
@@ -158,6 +167,7 @@ def benchmark(
             is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
             use_smooth_softmax=use_smooth_softmax,
             device=device,
+            dtype=torch.float16 if dtype == "float16" else torch.bfloat16,
         )
 
         obj = OrtGroupQueryAttention(config)
@@ -168,7 +178,7 @@
     benchmark.run(save_path=".", print_data=True)
 
 
-def run_performance_test(sm: int):
+def run_performance_test(sm: int, fast: bool = False):
     """
     Run performance tests for prompt and token generation.
 
@@ -177,7 +187,7 @@ def run_performance_test(sm: int):
     memory_in_gb = torch.cuda.get_device_properties(device_id).total_memory / (1024 * 1024 * 1024)
 
     # Note: some models use bf16.
-    # We use fp16 for all models in this test since bf16 is not supported in ORT python API.
+    # We use fp16/bf16 for all models in this test.
     configures = [
         (32, 128, 8, 8192, None, "Llama3-8B"),
         (64, 128, 8, 8192, None, "Llama3-70B"),
@@ -188,34 +198,43 @@
         (40, 128, 10, 131072, None, "Phi-3-medium-128K"),
     ]
 
+    if fast:
+        configures = configures[:1]
+    batch_sizes = [1] if fast else [1, 4]
+    smooth_softmax_options = [False] if fast else [False, True]
+    dtypes = ["float16", "bfloat16"]
+
     # Reduce max sequence length when GPU memory is not enough.
     threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768
 
     for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures:
-        for batch_size in [1, 4]:
-            for use_smooth_softmax in [False, True]:
-                plot_prompt_performance(
-                    sm=sm,
-                    batch_size=batch_size,
-                    num_heads=num_heads,
-                    kv_num_heads=kv_num_heads,
-                    head_size=head_size,
-                    max_seq_len=min(threshold, max_seq_len),
-                    local_window_size=local_window_size,
-                    use_smooth_softmax=use_smooth_softmax,
-                    model_name=model_name,
-                )
-                plot_token_performance(
-                    sm=sm,
-                    batch_size=batch_size,
-                    num_heads=num_heads,
-                    kv_num_heads=kv_num_heads,
-                    head_size=head_size,
-                    max_seq_len=min(threshold, max_seq_len),
-                    local_window_size=local_window_size,
-                    use_smooth_softmax=use_smooth_softmax,
-                    model_name=model_name,
-                )
+        for batch_size in batch_sizes:
+            for use_smooth_softmax in smooth_softmax_options:
+                for dtype in dtypes:
+                    plot_prompt_performance(
+                        sm=sm,
+                        batch_size=batch_size,
+                        num_heads=num_heads,
+                        kv_num_heads=kv_num_heads,
+                        head_size=head_size,
+                        max_seq_len=min(threshold, max_seq_len),
+                        local_window_size=local_window_size,
+                        use_smooth_softmax=use_smooth_softmax,
+                        model_name=model_name,
+                        dtype=dtype,
+                    )
+                    plot_token_performance(
+                        sm=sm,
+                        batch_size=batch_size,
+                        num_heads=num_heads,
+                        kv_num_heads=kv_num_heads,
+                        head_size=head_size,
+                        max_seq_len=min(threshold, max_seq_len),
+                        local_window_size=local_window_size,
+                        use_smooth_softmax=use_smooth_softmax,
+                        model_name=model_name,
+                        dtype=dtype,
+                    )
 
 
 if __name__ == "__main__":
@@ -224,4 +243,4 @@ def run_performance_test(sm: int):
 
     s = torch.cuda.Stream()
     with torch.cuda.stream(s), torch.no_grad():
-        run_performance_test(sm)
+        run_performance_test(sm, fast=True)
```
