# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
43"""Tests for merge_attn_states function.
54
65Run `pytest tests/test_merge_attn_states.py`.
76"""

import logging

import pytest
import torch

from tests.register_ops import merge_attn_states as merge_attn_states_xpu

logger = logging.getLogger("vllm_xpu_kernel")


# Naive PyTorch implementation of section 2.2 of
# https://www.arxiv.org/pdf/2501.01005; can be used to combine partial
# attention results (in the split-KV case).
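#
# In equations (restating what the function below computes), with
# m = max(lse_p, lse_s) taken elementwise:
#   a = exp(lse_p - m),  b = exp(lse_s - m)
#   o = (a * o_p + b * o_s) / (a + b)
#   lse = log(a + b) + m
# The scales a / (a + b) and b / (a + b) are transposed to the
# [NUM_TOKENS, NUM_HEADS, 1] layout before the weighted sum.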
def merge_attn_states_torch(
        output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
        suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
        output_lse: torch.Tensor | None = None,  # [NUM_HEADS, NUM_TOKENS]
):
    p_lse = prefix_lse
    s_lse = suffix_lse
    # inf -> -inf, so a fully masked split contributes zero weight
    p_lse[p_lse == torch.inf] = -torch.inf
    s_lse[s_lse == torch.inf] = -torch.inf
    # Subtract the elementwise max for numerical stability
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    p_lse_exp = torch.exp(p_lse)
    s_lse_exp = torch.exp(s_lse)
    out_se = p_lse_exp + s_lse_exp
    if output_lse is not None:
        output_lse = torch.log(out_se) + max_lse
    p_scale = p_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    s_scale = s_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    p_scale = torch.transpose(p_scale, 0,
                              1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    s_scale = torch.transpose(s_scale, 0,
                              1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    output = prefix_output * p_scale + suffix_output * s_scale
    return output, output_lse
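

# A minimal usage sketch (hypothetical shapes, not part of the test suite):
#   out = torch.empty(4, 2, 8)  # 4 tokens, 2 heads, head size 8
#   merged, merged_lse = merge_attn_states_torch(
#       out, torch.randn(4, 2, 8), torch.randn(2, 4),
#       torch.randn(4, 2, 8), torch.randn(2, 4),
#       output_lse=torch.empty(2, 4))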


# Test parameter grids and the per-case result accumulator (the original
# definitions are not shown here; the values below are assumed from the
# upstream vLLM merge_attn_states test).
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
NUM_QUERY_HEADS = [4, 8, 16, 32, 48, 64]
HEAD_SIZES = [32, 48, 64, 96, 128, 256]
DTYPES = [torch.float32, torch.float16, torch.bfloat16]

all_case_info: list[tuple] = []


def generate_markdown_table():
    global all_case_info
    table_header = ("| tokens | heads | headsize | dtype "
                    "| device | torch | xpu | speedup |")
    table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- |"
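    # Example row (illustrative values only):
    # | 512 | 16 | 128 | bfloat16 | XPU | 0.12345ms | 0.05678ms | 2.1740x |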

    def shortly_dtype(dtype: torch.dtype) -> str:
        return str(dtype).removeprefix("torch.")

    # shortly_device maps a device string to a display name (the original
    # body is not shown; this mapping is an assumption)
    def shortly_device(device: str) -> str:
        device_map = {"xpu": "XPU"}
        return device_map.get(device, device)

    print(table_header)
    print(table_separator)
    for info in all_case_info:
        (
            num_tokens,
            num_heads,
            head_size,
            dtype,
            device,
            avg_time_torch_kernel,
            avg_time_xpu_kernel,
            performance_improved,
        ) = info
        dtype = shortly_dtype(dtype)
        device = shortly_device(device)
        print(f"| {num_tokens} | {num_heads} | {head_size} "
              f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
              f"| {avg_time_xpu_kernel:.5f}ms "
              f"| {performance_improved:.4f}x |")


@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
                           head_size: int, output_dtype: torch.dtype):

    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size

    # prefix_lse and suffix_lse contain inf and normal values
    prefix_lse = torch.randn(NUM_HEADS,
                             NUM_TOKENS,
                             dtype=torch.float32,
                             device="xpu")
    suffix_lse = torch.randn(NUM_HEADS,
                             NUM_TOKENS,
                             dtype=torch.float32,
                             device="xpu")

    # Generate boolean masks
    mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    # Ensure that the same position is not True in both masks at once
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)

    # Place inf at the masked positions (the kernels replace them with -inf)
    prefix_lse[mask_prefix] = float("inf")
    suffix_lse[mask_suffix] = float("inf")

    # Other input tensors (they only need to be initialized;
    # no actual calculation on them is required here)
    output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                         dtype=output_dtype,
                         device="xpu")
    output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
                             dtype=torch.float32,
                             device="xpu")
    prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                                dtype=output_dtype,
                                device="xpu")
    suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
                                dtype=output_dtype,
                                device="xpu")

    warmup_times = 2
    repeat_times = 20
    # Benchmark loop (not shown here): each kernel is warmed up
    # warmup_times times and then timed over repeat_times runs, producing
    # output_torch / output_lse_torch from merge_attn_states_torch,
    # output_xpu / output_lse_xpu from merge_attn_states_xpu, and the
    # average latencies avg_time_torch_kernel and avg_time_xpu_kernel.

    # 2. Performance comparison
    performance_improved = avg_time_torch_kernel / avg_time_xpu_kernel
    # print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
    # print(f" XPU time: {avg_time_xpu_kernel:.6f}ms, "
    #       f"Performance: {performance_improved:.5f}x")
    # print("-" * 100)

    # 3. Correctness comparison
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16.
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

    # Use the torch output as the reference
    torch.testing.assert_close(output_xpu.float(),
                               output_torch.float(),
                               atol=1e-3,
                               rtol=rtol)

    torch.testing.assert_close(output_lse_xpu.float(),
                               output_lse_torch.float(),
                               atol=1e-3,
                               rtol=rtol)

    device = "xpu"
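    # Record this case; once every parametrized combination has run, the
    # accumulated results are emitted as a markdown table.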
    all_case_info.append((
        NUM_TOKENS,
        NUM_HEADS,
        HEAD_SIZE,
        output_dtype,
        device,
        avg_time_torch_kernel,
        avg_time_xpu_kernel,
        performance_improved,
    ))
    if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
                              len(NUM_QUERY_HEADS) * len(DTYPES)):
        generate_markdown_table()