@@ -14,19 +14,15 @@ def _reference_attention(query, key, value, causal=False):
1414 """Reference implementation using PyTorch SDPA."""
1515 query , key , value = (x .transpose (1 , 2 ).contiguous () for x in (query , key , value ))
1616 with torch .nn .attention .sdpa_kernel (torch .nn .attention .SDPBackend .MATH ):
17- out = torch .nn .functional .scaled_dot_product_attention (
18- query , key , value , is_causal = causal
19- )
17+ out = torch .nn .functional .scaled_dot_product_attention (query , key , value , is_causal = causal )
2018 return out .transpose (1 , 2 ).contiguous ()
2119
2220
2321def _varlen_reference_attention (q , k , v , cu_seqlens_q , cu_seqlens_k , causal = False ):
2422 """Reference implementation for variable length attention."""
2523 batch_size = cu_seqlens_q .shape [0 ] - 1
2624 total_tokens_q = q .shape [0 ]
27- out = torch .zeros (
28- (total_tokens_q , q .shape [1 ], q .shape [2 ]), device = q .device , dtype = q .dtype
29- )
25+ out = torch .zeros ((total_tokens_q , q .shape [1 ], q .shape [2 ]), device = q .device , dtype = q .dtype )
3026
3127 for b in range (batch_size ):
3228 start_q , end_q = cu_seqlens_q [b ], cu_seqlens_q [b + 1 ]
@@ -54,9 +50,7 @@ def setup_small(self):
5450 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
5551
5652 def benchmark_small (self ):
57- self .out = _extract_output (
58- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False )
59- )
53+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False ))
6054
6155 def verify_small (self ) -> torch .Tensor :
6256 return _reference_attention (self .q , self .k , self .v , causal = False )
@@ -70,9 +64,7 @@ def setup_medium(self):
7064 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
7165
7266 def benchmark_medium (self ):
73- self .out = _extract_output (
74- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False )
75- )
67+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False ))
7668
7769 def verify_medium (self ) -> torch .Tensor :
7870 return _reference_attention (self .q , self .k , self .v , causal = False )
@@ -86,9 +78,7 @@ def setup_large(self):
8678 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
8779
8880 def benchmark_large (self ):
89- self .out = _extract_output (
90- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False )
91- )
81+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = False ))
9282
9383 def verify_large (self ) -> torch .Tensor :
9484 return _reference_attention (self .q , self .k , self .v , causal = False )
@@ -106,9 +96,7 @@ def setup_small(self):
10696 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
10797
10898 def benchmark_small (self ):
109- self .out = _extract_output (
110- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True )
111- )
99+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True ))
112100
113101 def verify_small (self ) -> torch .Tensor :
114102 return _reference_attention (self .q , self .k , self .v , causal = True )
@@ -122,9 +110,7 @@ def setup_medium(self):
122110 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
123111
124112 def benchmark_medium (self ):
125- self .out = _extract_output (
126- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True )
127- )
113+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True ))
128114
129115 def verify_medium (self ) -> torch .Tensor :
130116 return _reference_attention (self .q , self .k , self .v , causal = True )
@@ -138,9 +124,7 @@ def setup_large(self):
138124 self .out = torch .empty (B , S , H , D , device = "cuda" , dtype = torch .float16 )
139125
140126 def benchmark_large (self ):
141- self .out = _extract_output (
142- self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True )
143- )
127+ self .out = _extract_output (self .kernel .flash_attn_func (self .q , self .k , self .v , causal = True ))
144128
145129 def verify_large (self ) -> torch .Tensor :
146130 return _reference_attention (self .q , self .k , self .v , causal = True )
@@ -180,9 +164,7 @@ def benchmark_small(self):
180164 )
181165
182166 def verify_small (self ) -> torch .Tensor :
183- return _varlen_reference_attention (
184- self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False
185- )
167+ return _varlen_reference_attention (self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False )
186168
187169 # Workload: medium (5 sequences, max_seqlen=256)
188170 def setup_medium (self ):
@@ -214,9 +196,7 @@ def benchmark_medium(self):
214196 )
215197
216198 def verify_medium (self ) -> torch .Tensor :
217- return _varlen_reference_attention (
218- self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False
219- )
199+ return _varlen_reference_attention (self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False )
220200
221201 # Workload: large (8 sequences, max_seqlen=512)
222202 def setup_large (self ):
@@ -248,6 +228,4 @@ def benchmark_large(self):
248228 )
249229
250230 def verify_large (self ) -> torch .Tensor :
251- return _varlen_reference_attention (
252- self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False
253- )
231+ return _varlen_reference_attention (self .q , self .k , self .v , self .cu_seqlens , self .cu_seqlens , causal = False )
0 commit comments