@@ -42,9 +42,14 @@ def filter_sdpa_kernels(
4242 for kernel in sdpa_kernels :
4343 if kernel == SDPBackend .FLASH_ATTENTION and not can_use_flash_attention (params ):
4444 continue
45- elif kernel == SDPBackend .EFFICIENT_ATTENTION and not can_use_efficient_attention (params ):
45+ elif (
46+ kernel == SDPBackend .EFFICIENT_ATTENTION
47+ and not can_use_efficient_attention (params )
48+ ):
4649 continue
47- elif kernel == SDPBackend .CUDNN_ATTENTION and not can_use_cudnn_attention (params ):
50+ elif kernel == SDPBackend .CUDNN_ATTENTION and not can_use_cudnn_attention (
51+ params
52+ ):
4853 continue
4954 new_kernels .append (kernel )
5055 return new_kernels
@@ -202,11 +207,21 @@ def mask_slice_bool(
202207 q_per_kv = n_head // n_query_groups
203208 assert n_head == n_query_groups * q_per_kv and q_per_kv >= 1
204209 if q_per_kv > 1 :
205- token_positions = token_positions .unsqueeze (2 ).expand (
206- - 1 , - 1 , q_per_kv , - 1 ,
207- ).reshape (batch_size , n_head , - 1 )
210+ token_positions = (
211+ token_positions .unsqueeze (2 )
212+ .expand (
213+ - 1 ,
214+ - 1 ,
215+ q_per_kv ,
216+ - 1 ,
217+ )
218+ .reshape (batch_size , n_head , - 1 )
219+ )
208220 token_positions = token_positions .unsqueeze (2 ).expand (
209- - 1 , - 1 , num , - 1 ,
221+ - 1 ,
222+ - 1 ,
223+ num ,
224+ - 1 ,
210225 )
211226 kwargs = dict (device = token_positions .device , dtype = token_positions .dtype )
212227 bool_mask = (
@@ -276,7 +291,7 @@ def build_mask_slice(
276291
277292
278293# Maximum number of `float32` entries for `tmp_array` for GB
279- ENTRIES_PER_GB = 2 ** 28
294+ ENTRIES_PER_GB = 2 ** 28
280295
281296# Maximum size of `tmp_array` in GB
282297DEFAULT_TMP_ARRAY_LIMIT_GB = 3
@@ -324,7 +339,9 @@ def create_temp_array(
324339 else :
325340 tmp_len = tmp_array_max_num_entries // factor
326341 if tmp_len < 1 :
327- raise ValueError (f"batch_size={ batch_size } , n_head={ n_head } , kv_len={ kv_len } too large. Their product must be <= { tmp_array_max_num_entries } " )
342+ raise ValueError (
343+ f"batch_size={ batch_size } , n_head={ n_head } , kv_len={ kv_len } too large. Their product must be <= { tmp_array_max_num_entries } "
344+ )
328345 num_splits = int (math .ceil (q_len / tmp_len ))
329346 shape = (batch_size , n_head , tmp_len , kv_len )
330347 kwargs = dict (device = device , dtype = torch .float32 )
@@ -388,7 +405,10 @@ def sdpa_attention_weights(
388405 _ , n_query_groups , kv_len , _ = key .shape
389406 # Compute attention weights f(S)
390407 attention_compute_scores (
391- query = query , key = key , out = tmp_array , scale_factor = scale_factor ,
408+ query = query ,
409+ key = key ,
410+ out = tmp_array ,
411+ scale_factor = scale_factor ,
392412 )
393413 # Attention masking
394414 if token_positions is None :
@@ -422,17 +442,21 @@ def sample_token_positions(
422442) -> torch .Tensor :
423443 index_kwargs = dict (dtype = torch .int64 , device = device )
424444 token_positions = torch .zeros (
425- (batch_size , n_query_groups , kv_len ), ** index_kwargs ,
445+ (batch_size , n_query_groups , kv_len ),
446+ ** index_kwargs ,
426447 )
427448 for bs in range (batch_size ):
428449 for nq in range (n_query_groups ):
429450 token_positions [bs , nq , :] = torch .randperm (
430- input_pos , ** index_kwargs ,
451+ input_pos ,
452+ ** index_kwargs ,
431453 )[:kv_len ]
432454 # Ensure that `input_pos:(input_pos + q_len)` is present
433455 index = torch .randperm (kv_len , ** index_kwargs )[:q_len ]
434456 token_positions [bs , nq , index ] = torch .arange (
435- input_pos , input_pos + q_len , ** index_kwargs ,
457+ input_pos ,
458+ input_pos + q_len ,
459+ ** index_kwargs ,
436460 )
437461 return token_positions
438462
0 commit comments