@@ -1007,15 +1007,15 @@ def get_total_q_block_count_bwd(
10071007 batch_idx ,
10081008 head_idx ,
10091009 n_block ,
1010- subtile_factor : cutlass .Constexpr = 1 ,
1010+ q_subtile_factor : cutlass .Constexpr = 1 ,
10111011 m_block_max : int = 0 ,
10121012):
10131013 """Count total tile iterations for given n_block (KV tile) in backward."""
10141014 q_block_cnt , _ , full_block_cnt , _ , * _ = blocksparse_tensors
10151015 total = q_block_cnt [batch_idx , head_idx , n_block ]
10161016 if const_expr (full_block_cnt is not None ):
10171017 total = total + full_block_cnt [batch_idx , head_idx , n_block ]
1018- return total * subtile_factor
1018+ return total * q_subtile_factor
10191019
10201020
10211021@cute .jit
@@ -1050,7 +1050,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10501050 should_load_Q : cutlass .Constexpr ,
10511051 should_load_dO : cutlass .Constexpr ,
10521052 # Subtiling factor and bounds
1053- subtile_factor : cutlass .Constexpr = 1 ,
1053+ q_subtile_factor : cutlass .Constexpr = 1 ,
10541054 m_block_max : int = 0 ,
10551055):
10561056 """SM100 backward block sparse loading with subtiling.
@@ -1065,7 +1065,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10651065 curr_full_idx ,
10661066 loop_count ,
10671067 ) = get_block_sparse_iteration_info_bwd (
1068- blocksparse_tensors , batch_idx , head_idx , n_block , subtile_factor , m_block_max
1068+ blocksparse_tensors , batch_idx , head_idx , n_block , q_subtile_factor , m_block_max
10691069 )
10701070
10711071 for iter_idx in cutlass .range (loop_count , unroll = 1 ):
@@ -1075,7 +1075,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10751075 curr_q_idx ,
10761076 curr_full_cnt ,
10771077 curr_full_idx ,
1078- subtile_factor ,
1078+ q_subtile_factor ,
10791079 m_block_max ,
10801080 )
10811081 m_block_safe = m_block
@@ -1148,7 +1148,7 @@ def get_block_sparse_iteration_info_bwd(
11481148 batch_idx ,
11491149 head_idx ,
11501150 n_block ,
1151- subtile_factor : cutlass .Constexpr = 1 ,
1151+ q_subtile_factor : cutlass .Constexpr = 1 ,
11521152 m_block_max : int = 0 ,
11531153):
11541154 """Extract block-sparse iteration info for backward pass.
@@ -1169,7 +1169,7 @@ def get_block_sparse_iteration_info_bwd(
11691169 sparse_block_count = curr_q_cnt
11701170 if const_expr (full_cnt is not None ):
11711171 sparse_block_count = sparse_block_count + curr_full_cnt
1172- total_count = sparse_block_count * subtile_factor
1172+ total_count = sparse_block_count * q_subtile_factor
11731173
11741174 return curr_q_cnt , curr_q_idx , curr_full_cnt , curr_full_idx , total_count
11751175
@@ -1181,7 +1181,7 @@ def get_m_block_from_iter_bwd(
11811181 curr_q_idx : cute .Tensor ,
11821182 curr_full_cnt ,
11831183 curr_full_idx : Optional [cute .Tensor ],
1184- subtile_factor : cutlass .Constexpr = 1 ,
1184+ q_subtile_factor : cutlass .Constexpr = 1 ,
11851185 m_block_max : int = 0 ,
11861186):
11871187 """Derive m_block index and is_full_block flag from iteration index.
@@ -1190,8 +1190,8 @@ def get_m_block_from_iter_bwd(
11901190 - m_block: The actual Q-tile block index
11911191 - is_full_block: True if this is a full block (no mask_mod needed)
11921192 """
1193- sparse_iter_idx = iter_idx // subtile_factor
1194- subtile_offset = iter_idx % subtile_factor
1193+ sparse_iter_idx = iter_idx // q_subtile_factor
1194+ subtile_offset = iter_idx % q_subtile_factor
11951195
11961196 sparse_m_block = Int32 (0 )
11971197 is_full_block = False
@@ -1204,7 +1204,7 @@ def get_m_block_from_iter_bwd(
12041204 else :
12051205 sparse_m_block = curr_q_idx [sparse_iter_idx ]
12061206
1207- return sparse_m_block * subtile_factor + subtile_offset , is_full_block
1207+ return sparse_m_block * q_subtile_factor + subtile_offset , is_full_block
12081208
12091209
12101210@cute .jit
@@ -1269,7 +1269,7 @@ def produce_block_sparse_q_loads_bwd_sm90(
12691269 tma_copy_bytes_K ,
12701270 tma_copy_bytes_V ,
12711271 Q_stage_eq_dO_stage : cutlass .Constexpr ,
1272- subtile_factor : cutlass .Constexpr ,
1272+ q_subtile_factor : cutlass .Constexpr ,
12731273 m_block_max : int ,
12741274):
12751275 """SM90 backward block sparse loading with separate partial/full loops.
@@ -1292,10 +1292,10 @@ def produce_block_sparse_q_loads_bwd_sm90(
12921292
12931293 kv_loaded = False
12941294
1295- for iter_idx in cutlass .range (curr_q_cnt * subtile_factor , unroll = 1 ):
1296- sparse_idx = iter_idx // subtile_factor
1297- subtile_offset = iter_idx % subtile_factor
1298- m_block = curr_q_idx [sparse_idx ] * subtile_factor + subtile_offset
1295+ for iter_idx in cutlass .range (curr_q_cnt * q_subtile_factor , unroll = 1 ):
1296+ sparse_idx = iter_idx // q_subtile_factor
1297+ subtile_offset = iter_idx % q_subtile_factor
1298+ m_block = curr_q_idx [sparse_idx ] * q_subtile_factor + subtile_offset
12991299
13001300 if m_block < m_block_max :
13011301 producer_state_Q , producer_state_dO = _load_q_do_block_sm90 (
@@ -1318,10 +1318,10 @@ def produce_block_sparse_q_loads_bwd_sm90(
13181318 kv_loaded = True
13191319
13201320 if const_expr (full_cnt is not None ):
1321- for iter_idx in cutlass .range (curr_full_cnt * subtile_factor , unroll = 1 ):
1322- sparse_idx = iter_idx // subtile_factor
1323- subtile_offset = iter_idx % subtile_factor
1324- m_block = curr_full_idx [sparse_idx ] * subtile_factor + subtile_offset
1321+ for iter_idx in cutlass .range (curr_full_cnt * q_subtile_factor , unroll = 1 ):
1322+ sparse_idx = iter_idx // q_subtile_factor
1323+ subtile_offset = iter_idx % q_subtile_factor
1324+ m_block = curr_full_idx [sparse_idx ] * q_subtile_factor + subtile_offset
13251325
13261326 if m_block < m_block_max :
13271327 producer_state_Q , producer_state_dO = _load_q_do_block_sm90 (
@@ -1362,7 +1362,7 @@ def consume_block_sparse_mma_bwd_sm90(
13621362 thr_mma_SdP ,
13631363 score_mod_fn = None ,
13641364 score_mod_bwd_fn = None ,
1365- subtile_factor : cutlass .Constexpr = 1 ,
1365+ q_subtile_factor : cutlass .Constexpr = 1 ,
13661366 m_block_max : int = 0 ,
13671367 aux_data : AuxData = AuxData (),
13681368 fastdiv_mods = (None , None ),
@@ -1414,10 +1414,10 @@ def consume_block_sparse_mma_bwd_sm90(
14141414 fastdiv_mods = fastdiv_mods ,
14151415 )
14161416
1417- for iter_idx in cutlass .range (curr_q_cnt * subtile_factor , unroll = 1 ):
1418- sparse_idx = iter_idx // subtile_factor
1419- subtile_offset = iter_idx % subtile_factor
1420- m_block = curr_q_idx [sparse_idx ] * subtile_factor + subtile_offset
1417+ for iter_idx in cutlass .range (curr_q_cnt * q_subtile_factor , unroll = 1 ):
1418+ sparse_idx = iter_idx // q_subtile_factor
1419+ subtile_offset = iter_idx % q_subtile_factor
1420+ m_block = curr_q_idx [sparse_idx ] * q_subtile_factor + subtile_offset
14211421
14221422 if m_block < m_block_max :
14231423 consumer_state_Q , consumer_state_dO = mma_one_m_block_fn (
@@ -1432,10 +1432,10 @@ def consume_block_sparse_mma_bwd_sm90(
14321432 dKV_accumulate = True
14331433
14341434 if const_expr (full_cnt is not None ):
1435- for iter_idx in cutlass .range (curr_full_cnt * subtile_factor , unroll = 1 ):
1436- sparse_idx = iter_idx // subtile_factor
1437- subtile_offset = iter_idx % subtile_factor
1438- m_block = curr_full_idx [sparse_idx ] * subtile_factor + subtile_offset
1435+ for iter_idx in cutlass .range (curr_full_cnt * q_subtile_factor , unroll = 1 ):
1436+ sparse_idx = iter_idx // q_subtile_factor
1437+ subtile_offset = iter_idx % q_subtile_factor
1438+ m_block = curr_full_idx [sparse_idx ] * q_subtile_factor + subtile_offset
14391439
14401440 if m_block < m_block_max :
14411441 consumer_state_Q , consumer_state_dO = mma_one_m_block_fn (
@@ -1490,7 +1490,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
14901490 n_block ,
14911491 sdQaccum : cute .Tensor ,
14921492 gdQaccum : cute .Tensor ,
1493- subtile_factor : cutlass .Constexpr ,
1493+ q_subtile_factor : cutlass .Constexpr ,
14941494 m_block_max : int ,
14951495 num_dQ_warp_groups : cutlass .Constexpr ,
14961496 num_threads_per_warp_group : cutlass .Constexpr ,
@@ -1511,10 +1511,10 @@ def dQaccum_store_block_sparse_bwd_sm90(
15111511 curr_full_cnt = Int32 (0 )
15121512 curr_full_idx = None
15131513
1514- for iter_idx in cutlass .range (curr_q_cnt * subtile_factor , unroll = 1 ):
1515- sparse_idx = iter_idx // subtile_factor
1516- subtile_offset = iter_idx % subtile_factor
1517- m_block = curr_q_idx [sparse_idx ] * subtile_factor + subtile_offset
1514+ for iter_idx in cutlass .range (curr_q_cnt * q_subtile_factor , unroll = 1 ):
1515+ sparse_idx = iter_idx // q_subtile_factor
1516+ subtile_offset = iter_idx % q_subtile_factor
1517+ m_block = curr_q_idx [sparse_idx ] * q_subtile_factor + subtile_offset
15181518
15191519 if m_block < m_block_max :
15201520 _store_one_dQaccum_sm90 (
@@ -1527,10 +1527,10 @@ def dQaccum_store_block_sparse_bwd_sm90(
15271527 )
15281528
15291529 if const_expr (full_cnt is not None ):
1530- for iter_idx in cutlass .range (curr_full_cnt * subtile_factor , unroll = 1 ):
1531- sparse_idx = iter_idx // subtile_factor
1532- subtile_offset = iter_idx % subtile_factor
1533- m_block = curr_full_idx [sparse_idx ] * subtile_factor + subtile_offset
1530+ for iter_idx in cutlass .range (curr_full_cnt * q_subtile_factor , unroll = 1 ):
1531+ sparse_idx = iter_idx // q_subtile_factor
1532+ subtile_offset = iter_idx % q_subtile_factor
1533+ m_block = curr_full_idx [sparse_idx ] * q_subtile_factor + subtile_offset
15341534
15351535 if m_block < m_block_max :
15361536 _store_one_dQaccum_sm90 (
0 commit comments