@@ -42,7 +42,7 @@ class MetaData():
     rotary_cos: Optional[torch.Tensor] = None
     rotary_interleaved: bool = False
     rotary_conjunction: bool = False
-    
+
 
     def __repr__(self) -> str:
         return (f"MetaData(\n"
@@ -161,7 +161,7 @@ def generate_varlen_tensor(
     if batch_size is None:
         valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen]
         batch_size = random.choice(valid_batch_sizes)
-    
+
     # get seqlens
     if equal_seqlens:
         seqlens = torch.full(
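Side note on the varlen convention this function feeds: all sequences are packed into one (total_seqlen, num_heads, head_dim) tensor, and `cu_seqlens` holds the cumulative boundaries. A minimal sketch of that relationship (the `make_cu_seqlens` helper is hypothetical, added here only for illustration):

```python
import torch

def make_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: prepend 0 and take the running sum so that
    # sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the
    # packed (total_seqlen, num_heads, head_dim) tensor.
    return torch.nn.functional.pad(torch.cumsum(seqlens, dim=0), (1, 0))

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = make_cu_seqlens(seqlens)  # tensor([ 0,  3,  8, 10])
max_seqlen = int(seqlens.max())        # 5
```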
@@ -241,14 +241,14 @@ def input_helper(
         TOTAL_SEQLENS_Q = BATCH * N_CTX_Q
         TOTAL_SEQLENS_K = BATCH * N_CTX_K
         equal_seqlens = False
-    
+
     # gen tensors
     # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen
     q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
-    
+
     # setup metadata
     if DEBUG_INPUT:
         sm_scale = 1
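With DEBUG_INPUT=True this hunk pins sm_scale to 1 and the upstream gradient `do` to all-ones, which makes forward/backward mismatches reproducible. The non-debug branch is not shown here; the conventional attention default is 1/sqrt(head_dim), sketched below as an assumption rather than this PR's code:

```python
import math

D_HEAD = 64  # illustrative head dimension

# Conventional attention softmax scale; assumed default for the branch
# not shown in this hunk.
sm_scale = 1.0 / math.sqrt(D_HEAD)
```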
@@ -313,7 +313,7 @@ def input_helper(
 
         return qkv, do, metadata
     else:
-        assert False, f"Unsupported packing mode: {packing}"
+        raise AssertionError(f"Unsupported packing mode: {packing}")
 
 # -------------------------------
 # Alibi
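The assert-to-raise conversions in this commit are not cosmetic: `assert` statements are stripped when Python runs with `-O`, so an `assert False` guard silently disappears, while an explicit `raise` always fires. A minimal illustration:

```python
def check_with_assert(packing: str) -> None:
    # Stripped entirely under `python -O`; bad input falls through silently.
    assert False, f"Unsupported packing mode: {packing}"

def check_with_raise(packing: str) -> None:
    # Survives -O and always surfaces the error.
    raise AssertionError(f"Unsupported packing mode: {packing}")
```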
@@ -366,21 +366,21 @@ def get_shape_from_layout(
     elif layout == 'thd':
         total_seqlen, num_heads, head_dim = x.shape
         if cu_seqlens is None:
-            raise ValueError("cu_seqlens must be provided for varlen (thd) layout") 
+            raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
         if max_seqlen is None:
             raise ValueError("max_seqlen must be provided for varlen (thd) layout")
-        
+
         batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim
     else:
-        assert False, "Got unsupported layout."
+        raise AssertionError("Got unsupported layout.")
 
     return batch, max_seqlen_final, num_heads, head_dim
 
 
 def get_shapes_from_layout(q, k, layout, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None, max_seqlen_k=None):
     batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q)
     batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k)
-    
+
     # assert
     assert batch_q == batch_k
     assert head_size_q == head_size_k
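In the 'thd' branch above, the batch size falls out of the boundary vector: N + 1 offsets delimit N packed sequences. A small worked example under that convention:

```python
import torch

# Three packed sequences of lengths 3, 5, 2 -> total_seqlen = 10.
num_heads, head_dim = 4, 64
x = torch.randn(10, num_heads, head_dim)
cu_seqlens = torch.tensor([0, 3, 8, 10])

batch = len(cu_seqlens) - 1  # 3 sequences from 4 boundary offsets
max_seqlen = 5               # must be passed in; not derivable from x alone
```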
@@ -389,13 +389,13 @@ def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = Non
 
 def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"]):
     if layout == 'thd':
-        strides = (0, x.stride(1), x.stride(0), x.stride(2)) 
+        strides = (0, x.stride(1), x.stride(0), x.stride(2))
     elif layout == 'bhsd':
         strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3))
     elif layout == 'bshd':
         strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
     else:
-        assert False, 'Got unsupported layout.'
+        raise AssertionError('Got unsupported layout.')
     return strides
 
 def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
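`get_stride_from_layout` normalizes every layout to (batch, head, seq, dim) stride order; 'thd' has no batch dimension, so its batch stride is pinned to 0 and batch offsets come from `cu_seqlens` instead. For a contiguous bshd tensor the reordering looks like this:

```python
import torch

x = torch.randn(2, 128, 8, 64)  # bshd: (batch, seq, head, dim)

# Reorder to (batch, head, seq, dim) stride order, as in the bshd branch.
strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
print(strides)  # (65536, 64, 512, 1)
```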
@@ -458,22 +458,22 @@ def write_dropout_mask(x, tensor_name = "tensor"):
     if True:
         BLOCK_M = 64
         BLOCK_N = 64
-        
+
         # Calculate number of blocks in each dimension
         m_blocks = math.ceil(seqlen_m / BLOCK_M)
         n_blocks = math.ceil(seqlen_n / BLOCK_N)
-        
+
         # Process each block
         for m_block in range(m_blocks):
             # Calculate row range for current block
             row_start = m_block * BLOCK_M
             row_end = min(row_start + BLOCK_M, seqlen_m)
-            
+
             for n_block in range(n_blocks):
                 # Calculate column range for current block
                 col_start = n_block * BLOCK_N
                 col_end = min(col_start + BLOCK_N, seqlen_n)
-                
+
                 # Extract and write the current block
                 for row_idx in range(row_start, row_end):
                     row_data = dropout_mask[row_idx][col_start:col_end]
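The dump loop walks the mask in 64x64 tiles so the text output lines up with what each kernel program instance sees. A standalone sketch of the same traversal order (names are mine, not the PR's):

```python
import math

def iter_blocks(seqlen_m: int, seqlen_n: int, BLOCK_M: int = 64, BLOCK_N: int = 64):
    # Yield (row_start, row_end, col_start, col_end) per tile, clamped at
    # the edges, in the order the loop above visits them.
    for m_block in range(math.ceil(seqlen_m / BLOCK_M)):
        row_start = m_block * BLOCK_M
        row_end = min(row_start + BLOCK_M, seqlen_m)
        for n_block in range(math.ceil(seqlen_n / BLOCK_N)):
            col_start = n_block * BLOCK_N
            col_end = min(col_start + BLOCK_N, seqlen_n)
            yield row_start, row_end, col_start, col_end
```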