Commit bded432

update vllm kernels benchmark files for flash attn and fused moe (vllm-project#176)
1 parent f04d949 commit bded432

6 files changed

Lines changed: 1776 additions & 0 deletions
Lines changed (file shown below): 307 additions & 0 deletions
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# isort: off
import gc

import torch
import triton

from benchmark.src.flash_attn_interface_ import (
    flash_attn_varlen_func_CalKernelTime)
from benchmark.src.get_model_config import (
    gen_cutlass_flash_attn_decode_correctness_configs as
    gen_correctness_config)
from benchmark.src.get_model_config import (
    gen_cutlass_flash_attn_decode_perf_configs as gen_perf_configs)
from tests.flash_attn.test_flash_attn_varlen_func import ref_paged_attn
from tests.utils import parse_args, seed_everything
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
# isort: on

DEVICE = "xpu"


def clear_xpu_cache():
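    """Free cached XPU memory and wait for all outstanding work to finish."""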
    torch.xpu.empty_cache()
    gc.collect()
    torch.xpu.synchronize()


def calculate_memory_usage(kv_len_sum, num_kv_heads, head_size, output_dtype):
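    """Return the size in GB of the key/value cache data read per pass."""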
    # Bytes for the key and value caches: two caches, each holding
    # kv_len_sum tokens of num_kv_heads * head_size elements.
    kv_cache_memory = 2 * kv_len_sum * num_kv_heads * \
        head_size * torch.tensor([], dtype=output_dtype).element_size()
    return kv_cache_memory / (1024**3)  # convert bytes to GB


def make_decode_with_paged_kv_input(config):
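    """Build query, paged KV-cache, and metadata tensors for one config.

    `seq_lens` is a comma-separated string "num_seqs,q1+q2+...,k1+k2+..."
    giving the sequence count, the per-sequence query lengths, and the
    per-sequence KV lengths.
    """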
    seq_lens, num_heads, head_size, block_size, \
        output_dtype, _, num_blocks, _, q_dtype, is_sink = config
    # if num_heads == (16, 1) and head_size == 256:
    #     pytest.skip("skip test cases that may run out of SLM.")
    num_seqs = int(seq_lens.split(",")[0])
    query_lens = [int(x) for x in seq_lens.split(",")[1].split("+")]
    kv_lens = [int(x) for x in seq_lens.split(",")[2].split("+")]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=output_dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=output_dtype)
    value_cache = torch.randn_like(key_cache)
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)

    seq_k = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    sink = None
    if is_sink:
        sink = torch.randn(num_query_heads, dtype=output_dtype)

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None  # noqa: F841
    k_descale = None  # noqa: F841
    v_descale = None  # noqa: F841
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for an fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)  # noqa: F841
        k_descale = torch.ones(scale_shape, dtype=torch.float32)  # noqa: F841
        v_descale = torch.ones(scale_shape, dtype=torch.float32)  # noqa: F841
    return maybe_quantized_query, maybe_quantized_key_cache, \
        maybe_quantized_value_cache, max_query_len, cu_query_lens, \
        max_kv_len, seq_k, scale, block_tables, sink, query, \
        key_cache, value_cache, query_lens, kv_lens


def calculate_diff_decode_paged_kv(config):
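    """Compare flash_attn_varlen_func against ref_paged_attn for one config."""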
    _, _, _, _, _, _, _, _, q_dtype, _ = config
    maybe_quantized_query, maybe_quantized_key_cache, \
        maybe_quantized_value_cache, max_query_len, cu_query_lens, \
        max_kv_len, seq_k, scale, block_tables, sink, query, \
        key_cache, value_cache, query_lens, kv_lens = \
        make_decode_with_paged_kv_input(config)

    output = flash_attn_varlen_func(maybe_quantized_query,
                                    maybe_quantized_key_cache,
                                    maybe_quantized_value_cache,
                                    max_query_len,
                                    cu_query_lens,
                                    max_kv_len,
                                    seqused_k=seq_k,
                                    softmax_scale=scale,
                                    causal=False,
                                    block_table=block_tables,
                                    window_size=(-1, -1),
                                    s_aux=sink)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=query_lens,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                casual=False,
                                is_paged=True,
                                sink=sink,
                                window_size_left=-1,
                                window_size_right=-1)
    # fp8 inputs need looser tolerances than fp16/bf16.
    atol, rtol = 1e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    try:
        torch.testing.assert_close(
            output, ref_output, atol=atol, rtol=rtol,
            msg=f"max abs diff: {torch.max(torch.abs(output - ref_output))}")
        print("✅ All implementations match, ", config)
    except AssertionError as e:
        print("❌ Implementations differ, ", config, " error: ", e)


def benchmark_decode_with_paged_kv(seq_lens, num_heads, head_size, block_size,
                                   output_dtype, soft_cap, num_blocks,
                                   fa_versions, q_dtype, is_sink, provider,
                                   iterations):
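    """Time one provider for one config.

    "flash" times the full call with XPU events, "flash_kernelTime"
    times only the kernel via flash_attn_varlen_func_CalKernelTime, and
    "flash_memBandwidth" converts kernel time into KV-cache read
    bandwidth. Returns latency in us, or bandwidth in GB/s.
    """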
    maybe_quantized_query, maybe_quantized_key_cache, \
        maybe_quantized_value_cache, max_query_len, cu_query_lens, \
        max_kv_len, seq_k, scale, block_tables, sink, _, \
        _, _, _, _ = make_decode_with_paged_kv_input(
            config=(seq_lens, num_heads, head_size,
                    block_size, output_dtype, soft_cap,
                    num_blocks, fa_versions, q_dtype, is_sink))

    num_seqs = int(seq_lens.split(",")[0])
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size

    run_config = (seq_lens, num_heads, head_size, block_size, output_dtype,
                  soft_cap, num_blocks, fa_versions, q_dtype, is_sink)
    print(f"Running config: {run_config}, Provider: {provider}", flush=True)
    assert iterations > 5, \
        "Number of iterations should be greater than 5 to account for warmup"
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)
    total_latency = 0.0
    ms = 0.0
    # Pre-generate a fresh query tensor for every iteration.
    queries = [
        torch.rand_like(maybe_quantized_query) for _ in range(iterations)
    ]

    if provider == "flash":
        for index in range(iterations):
            block_tables = torch.randint(0,
                                         num_blocks,
                                         (num_seqs, max_num_blocks_per_seq),
                                         dtype=torch.int32)
            start.record()
            flash_attn_varlen_func(queries[index],
                                   maybe_quantized_key_cache,
                                   maybe_quantized_value_cache,
                                   max_query_len,
                                   cu_query_lens,
                                   max_kv_len,
                                   seqused_k=seq_k,
                                   softmax_scale=scale,
                                   causal=False,
                                   block_table=block_tables,
                                   window_size=(-1, -1),
                                   s_aux=sink)
            end.record()
            end.synchronize()
            if index >= 5:  # skip the first 5 iterations for warmup
                total_latency += start.elapsed_time(end)
    else:
        for index in range(iterations):
            block_tables = torch.randint(0,
                                         num_blocks,
                                         (num_seqs, max_num_blocks_per_seq),
                                         dtype=torch.int32)
            flash_attn_varlen_func_CalKernelTime(queries[index],
                                                 maybe_quantized_key_cache,
                                                 maybe_quantized_value_cache,
                                                 max_query_len,
                                                 cu_query_lens,
                                                 max_kv_len,
                                                 seqused_k=seq_k,
                                                 softmax_scale=scale,
                                                 causal=False,
                                                 block_table=block_tables,
                                                 window_size=(-1, -1),
                                                 s_aux=sink,
                                                 start_event=start,
                                                 end_event=end)
            if index >= 5:  # skip the first 5 iterations for warmup
                total_latency += start.elapsed_time(end)
    if provider == "flash_memBandwidth":
        torch.xpu.synchronize()
        ms = total_latency / (iterations - 5)
        memory_load_GB = calculate_memory_usage(seq_k.sum().item(),
                                                num_heads[1], head_size,
                                                output_dtype)
        clear_xpu_cache()
        return memory_load_GB / (ms / 1000)  # GB per second
    torch.xpu.synchronize()
    ms = total_latency / (iterations - 5)
    clear_xpu_cache()

    # elapsed_time() reports milliseconds; convert the mean to microseconds.
    return 1000 * ms


def get_benchmark_decode_with_paged_kv(iterations=20):
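    """Build a triton perf_report benchmark over the module-level configs."""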

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=[
                "seq_lens", "num_heads", "head_size", "block_size",
                "output_dtype", "soft_cap", "num_blocks", "fa_versions",
                "q_dtype", "is_sink"
            ],
            x_vals=[tuple(c) for c in configs],
            line_arg="provider",
            line_vals=["flash", "flash_kernelTime", "flash_memBandwidth"],
            line_names=[
                "FlashAttention(us)", "FlashAttention_kernelTime(us)",
                "FlashAttention_memBandwidth(GB/s)"
            ],
            styles=[("blue", "-"), ("green", "-"), ("purple", "-")],
            ylabel="Latency (us)",
            plot_name="flash-attn-decode",
            args={},
        ))
    def benchmark(seq_lens, num_heads, head_size, block_size, output_dtype,
                  soft_cap, num_blocks, fa_versions, q_dtype, is_sink,
                  provider):
        return benchmark_decode_with_paged_kv(seq_lens=seq_lens,
                                              num_heads=num_heads,
                                              head_size=head_size,
                                              block_size=block_size,
                                              output_dtype=output_dtype,
                                              soft_cap=soft_cap,
                                              num_blocks=num_blocks,
                                              fa_versions=fa_versions,
                                              q_dtype=q_dtype,
                                              is_sink=is_sink,
                                              provider=provider,
                                              iterations=iterations)

    return benchmark


def filter_configs(configs):
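    """Drop configurations that are likely to run out of device memory."""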
    new_configs = []
    for config in configs:
        if (config[1] == (16, 1) and config[2] == 256) or \
                (config[3] == 128 and config[6] == 32768 and config[2] >= 192):
            print("Skipping config due to potential OOM: ", config)
            continue
        new_configs.append(config)
    return new_configs


if __name__ == "__main__":

    args = parse_args()
    seed = 1234
    seed_everything(seed)
    iterations = 20
    torch.set_default_device("xpu")
    torch.xpu.set_device("xpu:0")

    # Correctness pass: compare each config against the reference.
    configs = gen_correctness_config()
    configs = filter_configs(configs)
    for config in configs:
        try:
            calculate_diff_decode_paged_kv(config)
        except Exception as e:
            print("Error in config: ", config, " error: ", e)
        clear_xpu_cache()

    # Run performance benchmark
    configs = gen_perf_configs()
    configs = filter_configs(configs)
    benchmark = get_benchmark_decode_with_paged_kv(iterations=iterations)
    benchmark.run(print_data=True, save_path=args.save_path)
