|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +# isort: off |
| 5 | +import gc |
| 6 | + |
| 7 | +import torch |
| 8 | +import triton |
| 9 | + |
| 10 | +from benchmark.src.flash_attn_interface_ import ( |
| 11 | + flash_attn_varlen_func_CalKernelTime) |
| 12 | +from benchmark.src.get_model_config import ( |
| 13 | + gen_cutlass_flash_attn_decode_correctness_configs as |
| 14 | + gen_correctness_config) |
| 15 | +from benchmark.src.get_model_config import ( |
| 16 | + gen_cutlass_flash_attn_decode_perf_configs as gen_perf_configs) |
| 17 | +from tests.flash_attn.test_flash_attn_varlen_func import ref_paged_attn |
| 18 | +from tests.utils import parse_args, seed_everything |
| 19 | +from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func |
| 20 | +# isort: on |
| 21 | + |
| 22 | +DEVICE = "xpu" |
| 23 | + |
| 24 | + |
| 25 | +def clear_xpu_cache(): |
| 26 | + torch.xpu.empty_cache() |
| 27 | + gc.collect() |
| 28 | + torch.xpu.synchronize() |
| 29 | + |
| 30 | + |
| 31 | +def calculate_memory_usage(kv_len_sum, num_kv_heads, head_size, output_dtype): |
| 32 | + # Memory for key and value caches |
| 33 | + kv_cache_memory = 2 * kv_len_sum * num_kv_heads * \ |
| 34 | + head_size * torch.tensor([], dtype=output_dtype).element_size() |
| 35 | + return kv_cache_memory / (1024**3) # Convert to GB |
| 36 | + |
| 37 | + |
| 38 | +def make_decode_with_paged_kv_input(config): |
| 39 | + seq_lens, num_heads, head_size, block_size, \ |
| 40 | + output_dtype, _, num_blocks, _, q_dtype, is_sink = config |
| 41 | + # if num_heads == (16, 1) and head_size == 256: |
| 42 | + # pytest.skip("skip test cases that may run out of SLM.") |
| 43 | + num_seqs = int(seq_lens.split(",")[0]) |
| 44 | + query_lens = list(map(lambda x: int(x), seq_lens.split(",")[1].split("+"))) |
| 45 | + kv_lens = list(map(lambda x: int(x), seq_lens.split(",")[2].split("+"))) |
| 46 | + num_query_heads = num_heads[0] |
| 47 | + num_kv_heads = num_heads[1] |
| 48 | + assert num_query_heads % num_kv_heads == 0 |
| 49 | + max_query_len = max(query_lens) |
| 50 | + max_kv_len = max(kv_lens) |
| 51 | + scale = head_size**-0.5 |
| 52 | + |
| 53 | + query = torch.randn(sum(query_lens), |
| 54 | + num_query_heads, |
| 55 | + head_size, |
| 56 | + dtype=output_dtype) |
| 57 | + key_cache = torch.randn(num_blocks, |
| 58 | + block_size, |
| 59 | + num_kv_heads, |
| 60 | + head_size, |
| 61 | + dtype=output_dtype) |
| 62 | + value_cache = torch.randn_like(key_cache) |
| 63 | + cu_query_lens = torch.tensor([0] + query_lens, |
| 64 | + dtype=torch.int32).cumsum(dim=0, |
| 65 | + dtype=torch.int32) |
| 66 | + |
| 67 | + seq_k = torch.tensor(kv_lens, dtype=torch.int32) |
| 68 | + |
| 69 | + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size |
| 70 | + block_tables = torch.randint(0, |
| 71 | + num_blocks, |
| 72 | + (num_seqs, max_num_blocks_per_seq), |
| 73 | + dtype=torch.int32) |
| 74 | + sink = None |
| 75 | + if is_sink: |
| 76 | + sink = torch.randn(num_query_heads, dtype=output_dtype) |
| 77 | + |
| 78 | + maybe_quantized_query = query |
| 79 | + maybe_quantized_key_cache = key_cache |
| 80 | + maybe_quantized_value_cache = value_cache |
| 81 | + q_descale = None #noqa: F841 |
| 82 | + k_descale = None #noqa: F841 |
| 83 | + v_descale = None #noqa: F841 |
| 84 | + if q_dtype is not None: |
| 85 | + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor |
| 86 | + maybe_quantized_query = query.to(q_dtype) |
| 87 | + maybe_quantized_key_cache = key_cache.to(q_dtype) |
| 88 | + maybe_quantized_value_cache = value_cache.to(q_dtype) |
| 89 | + |
| 90 | + scale_shape = (num_seqs, num_kv_heads) |
| 91 | + q_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 |
| 92 | + k_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 |
| 93 | + v_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 |
| 94 | + return maybe_quantized_query, maybe_quantized_key_cache, \ |
| 95 | + maybe_quantized_value_cache, max_query_len, cu_query_lens, \ |
| 96 | + max_kv_len, seq_k, scale, block_tables, sink, query, \ |
| 97 | + key_cache, value_cache, query_lens, kv_lens |
| 98 | + |
| 99 | + |
| 100 | +def calculate_diff_decode_paged_kv(config): |
| 101 | + _, _, _, _, _, _, _, _, q_dtype, _ = config |
| 102 | + maybe_quantized_query, maybe_quantized_key_cache, \ |
| 103 | + maybe_quantized_value_cache, max_query_len, cu_query_lens, \ |
| 104 | + max_kv_len, seq_k, scale, block_tables, sink, query, \ |
| 105 | + key_cache, value_cache, query_lens, kv_lens = \ |
| 106 | + make_decode_with_paged_kv_input(config) |
| 107 | + |
| 108 | + output = flash_attn_varlen_func(maybe_quantized_query, |
| 109 | + maybe_quantized_key_cache, |
| 110 | + maybe_quantized_value_cache, |
| 111 | + max_query_len, |
| 112 | + cu_query_lens, |
| 113 | + max_kv_len, |
| 114 | + seqused_k=seq_k, |
| 115 | + softmax_scale=scale, |
| 116 | + causal=False, |
| 117 | + block_table=block_tables, |
| 118 | + window_size=(-1, -1), |
| 119 | + s_aux=sink) |
| 120 | + |
| 121 | + ref_output = ref_paged_attn(query=query, |
| 122 | + key_cache=key_cache, |
| 123 | + value_cache=value_cache, |
| 124 | + query_lens=query_lens, |
| 125 | + kv_lens=kv_lens, |
| 126 | + block_tables=block_tables, |
| 127 | + scale=scale, |
| 128 | + casual=False, |
| 129 | + is_paged=True, |
| 130 | + sink=sink, |
| 131 | + window_size_left=-1, |
| 132 | + window_size_right=-1) |
| 133 | + atol, rtol = 1e-2, 1e-2 |
| 134 | + if q_dtype is not None: |
| 135 | + atol, rtol = 1.5e-1, 1.5e-1 |
| 136 | + try: |
| 137 | + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ |
| 138 | + f"{torch.max(torch.abs(output - ref_output))}" |
| 139 | + print("✅ All implementations match, ", config) |
| 140 | + except AssertionError as e: |
| 141 | + print("❌ Implementations differ, ", config, " error: ", e) |
| 142 | + |
| 143 | + |
| 144 | +def benchmark_decode_with_paged_kv(seq_lens, num_heads, head_size, block_size, |
| 145 | + output_dtype, soft_cap, num_blocks, |
| 146 | + fa_versions, q_dtype, is_sink, provider, |
| 147 | + iterations): |
| 148 | + maybe_quantized_query, maybe_quantized_key_cache, \ |
| 149 | + maybe_quantized_value_cache, max_query_len, cu_query_lens, \ |
| 150 | + max_kv_len, seq_k, scale, block_tables, sink, _, \ |
| 151 | + _, _, _, _ = make_decode_with_paged_kv_input( |
| 152 | + config=(seq_lens, num_heads, head_size, |
| 153 | + block_size, output_dtype, soft_cap, |
| 154 | + num_blocks, fa_versions, q_dtype, is_sink)) |
| 155 | + |
| 156 | + num_seqs = int(seq_lens.split(",")[0]) |
| 157 | + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size |
| 158 | + |
| 159 | + print(f"Running config: {seq_lens, num_heads, head_size, \ |
| 160 | + block_size, output_dtype, soft_cap, num_blocks, \ |
| 161 | + fa_versions, q_dtype, \ |
| 162 | + is_sink}, Provider: {provider}", |
| 163 | + flush=True) |
| 164 | + assert iterations > 5, \ |
| 165 | + "Number of iterations should be greater than 5 to account for warmup" |
| 166 | + start = torch.xpu.Event(enable_timing=True) |
| 167 | + end = torch.xpu.Event(enable_timing=True) |
| 168 | + total_latency = 0.0 |
| 169 | + ms = 0.0 |
| 170 | + queries = [ |
| 171 | + torch.rand_like(maybe_quantized_query) for _ in range(iterations) |
| 172 | + ] |
| 173 | + |
| 174 | + if provider == "flash": |
| 175 | + for index in range(iterations): |
| 176 | + block_tables = torch.randint(0, |
| 177 | + num_blocks, |
| 178 | + (num_seqs, max_num_blocks_per_seq), |
| 179 | + dtype=torch.int32) |
| 180 | + start.record() |
| 181 | + flash_attn_varlen_func(queries[index], |
| 182 | + maybe_quantized_key_cache, |
| 183 | + maybe_quantized_value_cache, |
| 184 | + max_query_len, |
| 185 | + cu_query_lens, |
| 186 | + max_kv_len, |
| 187 | + seqused_k=seq_k, |
| 188 | + softmax_scale=scale, |
| 189 | + causal=False, |
| 190 | + block_table=block_tables, |
| 191 | + window_size=(-1, -1), |
| 192 | + s_aux=sink) |
| 193 | + end.record() |
| 194 | + end.synchronize() |
| 195 | + if index >= 5: # skip the first 5 iterations for warmup |
| 196 | + total_latency += start.elapsed_time(end) |
| 197 | + else: |
| 198 | + for index in range(iterations): |
| 199 | + block_tables = torch.randint(0, |
| 200 | + num_blocks, |
| 201 | + (num_seqs, max_num_blocks_per_seq), |
| 202 | + dtype=torch.int32) |
| 203 | + flash_attn_varlen_func_CalKernelTime(queries[index], |
| 204 | + maybe_quantized_key_cache, |
| 205 | + maybe_quantized_value_cache, |
| 206 | + max_query_len, |
| 207 | + cu_query_lens, |
| 208 | + max_kv_len, |
| 209 | + seqused_k=seq_k, |
| 210 | + softmax_scale=scale, |
| 211 | + causal=False, |
| 212 | + block_table=block_tables, |
| 213 | + window_size=(-1, -1), |
| 214 | + s_aux=sink, |
| 215 | + start_event=start, |
| 216 | + end_event=end) |
| 217 | + if index >= 5: # skip the first 5 iterations for warmup |
| 218 | + total_latency += start.elapsed_time(end) |
| 219 | + if provider == "flash_memBandwidth": |
| 220 | + torch.xpu.synchronize() |
| 221 | + ms = total_latency / (iterations - 5) |
| 222 | + memory_load_GB = calculate_memory_usage(seq_k.sum().item(), |
| 223 | + num_heads[1], head_size, |
| 224 | + output_dtype) |
| 225 | + clear_xpu_cache() |
| 226 | + return memory_load_GB / (ms / 1000) |
| 227 | + torch.xpu.synchronize() |
| 228 | + ms = total_latency / (iterations - 5) |
| 229 | + clear_xpu_cache() |
| 230 | + |
| 231 | + return 1000 * ms |
| 232 | + |
| 233 | + |
| 234 | +def get_benchmark_decode_with_paged_kv(iterations=20): |
| 235 | + |
| 236 | + @triton.testing.perf_report( |
| 237 | + triton.testing.Benchmark( |
| 238 | + x_names=[ |
| 239 | + "seq_lens", "num_heads", "head_size", "block_size", |
| 240 | + "output_dtype", "soft_cap", "num_blocks", "fa_versions", |
| 241 | + "q_dtype", "is_sink" |
| 242 | + ], |
| 243 | + x_vals=[tuple(c) for c in configs], |
| 244 | + line_arg="provider", |
| 245 | + line_vals=["flash", "flash_kernelTime", "flash_memBandwidth"], |
| 246 | + line_names=[ |
| 247 | + "FlashAttention(us)", "FlashAttention_kernelTime(us)", |
| 248 | + "FlashAttention_memBandwidth(GB/s)" |
| 249 | + ], |
| 250 | + styles=[("blue", "-"), ("green", "-"), ("purple", "-")], |
| 251 | + ylabel="Latency (us)", |
| 252 | + plot_name="flash-attn-decode", |
| 253 | + args={}, |
| 254 | + )) |
| 255 | + def benchmark(seq_lens, num_heads, head_size, block_size, output_dtype, |
| 256 | + soft_cap, num_blocks, fa_versions, q_dtype, is_sink, |
| 257 | + provider): |
| 258 | + return benchmark_decode_with_paged_kv(seq_lens=seq_lens, |
| 259 | + num_heads=num_heads, |
| 260 | + head_size=head_size, |
| 261 | + block_size=block_size, |
| 262 | + output_dtype=output_dtype, |
| 263 | + soft_cap=soft_cap, |
| 264 | + num_blocks=num_blocks, |
| 265 | + fa_versions=fa_versions, |
| 266 | + q_dtype=q_dtype, |
| 267 | + is_sink=is_sink, |
| 268 | + provider=provider, |
| 269 | + iterations=iterations) |
| 270 | + |
| 271 | + return benchmark |
| 272 | + |
| 273 | + |
| 274 | +def filter_configs(configs): |
| 275 | + new_configs = [] |
| 276 | + for config in configs: |
| 277 | + if (config[1] == (16, 1) and config[2] == 256) or \ |
| 278 | + (config[3] == 128 and config[6] == 32768 and config[2] >= 192): |
| 279 | + print("Skipping config due to potential OOM: ", config) |
| 280 | + continue |
| 281 | + new_configs.append(config) |
| 282 | + return new_configs |
| 283 | + |
| 284 | + |
| 285 | +if __name__ == "__main__": |
| 286 | + |
| 287 | + args = parse_args() |
| 288 | + seed = 1234 |
| 289 | + seed_everything(seed) |
| 290 | + iterations = 20 |
| 291 | + torch.set_default_device("xpu") |
| 292 | + torch.xpu.set_device("xpu:0") |
| 293 | + |
| 294 | + configs = gen_correctness_config() |
| 295 | + configs = filter_configs(configs) |
| 296 | + for config in configs: |
| 297 | + try: |
| 298 | + calculate_diff_decode_paged_kv(config) |
| 299 | + except Exception as e: |
| 300 | + print("Error in config: ", config, " error: ", e) |
| 301 | + clear_xpu_cache() |
| 302 | + |
| 303 | + configs = gen_perf_configs() |
| 304 | + configs = filter_configs(configs) |
| 305 | + benchmark = get_benchmark_decode_with_paged_kv(iterations=iterations) |
| 306 | + # Run performance benchmark |
| 307 | + benchmark.run(print_data=True, save_path=args.save_path) |
0 commit comments