Skip to content

Commit d16e381

Browse files
authored
Make q_subtile_factor default to identity (#2660)
1 parent cbbab83 commit d16e381

11 files changed

Lines changed: 97 additions & 89 deletions

flash_attn/cute/block_sparse_utils.py

Lines changed: 38 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1007,15 +1007,15 @@ def get_total_q_block_count_bwd(
10071007
batch_idx,
10081008
head_idx,
10091009
n_block,
1010-
subtile_factor: cutlass.Constexpr = 1,
1010+
q_subtile_factor: cutlass.Constexpr = 1,
10111011
m_block_max: int = 0,
10121012
):
10131013
"""Count total tile iterations for given n_block (KV tile) in backward."""
10141014
q_block_cnt, _, full_block_cnt, _, *_ = blocksparse_tensors
10151015
total = q_block_cnt[batch_idx, head_idx, n_block]
10161016
if const_expr(full_block_cnt is not None):
10171017
total = total + full_block_cnt[batch_idx, head_idx, n_block]
1018-
return total * subtile_factor
1018+
return total * q_subtile_factor
10191019

10201020

10211021
@cute.jit
@@ -1050,7 +1050,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10501050
should_load_Q: cutlass.Constexpr,
10511051
should_load_dO: cutlass.Constexpr,
10521052
# Subtiling factor and bounds
1053-
subtile_factor: cutlass.Constexpr = 1,
1053+
q_subtile_factor: cutlass.Constexpr = 1,
10541054
m_block_max: int = 0,
10551055
):
10561056
"""SM100 backward block sparse loading with subtiling.
@@ -1065,7 +1065,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10651065
curr_full_idx,
10661066
loop_count,
10671067
) = get_block_sparse_iteration_info_bwd(
1068-
blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
1068+
blocksparse_tensors, batch_idx, head_idx, n_block, q_subtile_factor, m_block_max
10691069
)
10701070

10711071
for iter_idx in cutlass.range(loop_count, unroll=1):
@@ -1075,7 +1075,7 @@ def produce_block_sparse_q_loads_bwd_sm100(
10751075
curr_q_idx,
10761076
curr_full_cnt,
10771077
curr_full_idx,
1078-
subtile_factor,
1078+
q_subtile_factor,
10791079
m_block_max,
10801080
)
10811081
m_block_safe = m_block
@@ -1148,7 +1148,7 @@ def get_block_sparse_iteration_info_bwd(
11481148
batch_idx,
11491149
head_idx,
11501150
n_block,
1151-
subtile_factor: cutlass.Constexpr = 1,
1151+
q_subtile_factor: cutlass.Constexpr = 1,
11521152
m_block_max: int = 0,
11531153
):
11541154
"""Extract block-sparse iteration info for backward pass.
@@ -1169,7 +1169,7 @@ def get_block_sparse_iteration_info_bwd(
11691169
sparse_block_count = curr_q_cnt
11701170
if const_expr(full_cnt is not None):
11711171
sparse_block_count = sparse_block_count + curr_full_cnt
1172-
total_count = sparse_block_count * subtile_factor
1172+
total_count = sparse_block_count * q_subtile_factor
11731173

11741174
return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count
11751175

@@ -1181,7 +1181,7 @@ def get_m_block_from_iter_bwd(
11811181
curr_q_idx: cute.Tensor,
11821182
curr_full_cnt,
11831183
curr_full_idx: Optional[cute.Tensor],
1184-
subtile_factor: cutlass.Constexpr = 1,
1184+
q_subtile_factor: cutlass.Constexpr = 1,
11851185
m_block_max: int = 0,
11861186
):
11871187
"""Derive m_block index and is_full_block flag from iteration index.
@@ -1190,8 +1190,8 @@ def get_m_block_from_iter_bwd(
11901190
- m_block: The actual Q-tile block index
11911191
- is_full_block: True if this is a full block (no mask_mod needed)
11921192
"""
1193-
sparse_iter_idx = iter_idx // subtile_factor
1194-
subtile_offset = iter_idx % subtile_factor
1193+
sparse_iter_idx = iter_idx // q_subtile_factor
1194+
subtile_offset = iter_idx % q_subtile_factor
11951195

11961196
sparse_m_block = Int32(0)
11971197
is_full_block = False
@@ -1204,7 +1204,7 @@ def get_m_block_from_iter_bwd(
12041204
else:
12051205
sparse_m_block = curr_q_idx[sparse_iter_idx]
12061206

1207-
return sparse_m_block * subtile_factor + subtile_offset, is_full_block
1207+
return sparse_m_block * q_subtile_factor + subtile_offset, is_full_block
12081208

12091209

12101210
@cute.jit
@@ -1269,7 +1269,7 @@ def produce_block_sparse_q_loads_bwd_sm90(
12691269
tma_copy_bytes_K,
12701270
tma_copy_bytes_V,
12711271
Q_stage_eq_dO_stage: cutlass.Constexpr,
1272-
subtile_factor: cutlass.Constexpr,
1272+
q_subtile_factor: cutlass.Constexpr,
12731273
m_block_max: int,
12741274
):
12751275
"""SM90 backward block sparse loading with separate partial/full loops.
@@ -1292,10 +1292,10 @@ def produce_block_sparse_q_loads_bwd_sm90(
12921292

12931293
kv_loaded = False
12941294

1295-
for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1296-
sparse_idx = iter_idx // subtile_factor
1297-
subtile_offset = iter_idx % subtile_factor
1298-
m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1295+
for iter_idx in cutlass.range(curr_q_cnt * q_subtile_factor, unroll=1):
1296+
sparse_idx = iter_idx // q_subtile_factor
1297+
subtile_offset = iter_idx % q_subtile_factor
1298+
m_block = curr_q_idx[sparse_idx] * q_subtile_factor + subtile_offset
12991299

13001300
if m_block < m_block_max:
13011301
producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
@@ -1318,10 +1318,10 @@ def produce_block_sparse_q_loads_bwd_sm90(
13181318
kv_loaded = True
13191319

13201320
if const_expr(full_cnt is not None):
1321-
for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1322-
sparse_idx = iter_idx // subtile_factor
1323-
subtile_offset = iter_idx % subtile_factor
1324-
m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1321+
for iter_idx in cutlass.range(curr_full_cnt * q_subtile_factor, unroll=1):
1322+
sparse_idx = iter_idx // q_subtile_factor
1323+
subtile_offset = iter_idx % q_subtile_factor
1324+
m_block = curr_full_idx[sparse_idx] * q_subtile_factor + subtile_offset
13251325

13261326
if m_block < m_block_max:
13271327
producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
@@ -1362,7 +1362,7 @@ def consume_block_sparse_mma_bwd_sm90(
13621362
thr_mma_SdP,
13631363
score_mod_fn=None,
13641364
score_mod_bwd_fn=None,
1365-
subtile_factor: cutlass.Constexpr = 1,
1365+
q_subtile_factor: cutlass.Constexpr = 1,
13661366
m_block_max: int = 0,
13671367
aux_data: AuxData = AuxData(),
13681368
fastdiv_mods=(None, None),
@@ -1414,10 +1414,10 @@ def consume_block_sparse_mma_bwd_sm90(
14141414
fastdiv_mods=fastdiv_mods,
14151415
)
14161416

1417-
for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1418-
sparse_idx = iter_idx // subtile_factor
1419-
subtile_offset = iter_idx % subtile_factor
1420-
m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1417+
for iter_idx in cutlass.range(curr_q_cnt * q_subtile_factor, unroll=1):
1418+
sparse_idx = iter_idx // q_subtile_factor
1419+
subtile_offset = iter_idx % q_subtile_factor
1420+
m_block = curr_q_idx[sparse_idx] * q_subtile_factor + subtile_offset
14211421

14221422
if m_block < m_block_max:
14231423
consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
@@ -1432,10 +1432,10 @@ def consume_block_sparse_mma_bwd_sm90(
14321432
dKV_accumulate = True
14331433

14341434
if const_expr(full_cnt is not None):
1435-
for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1436-
sparse_idx = iter_idx // subtile_factor
1437-
subtile_offset = iter_idx % subtile_factor
1438-
m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1435+
for iter_idx in cutlass.range(curr_full_cnt * q_subtile_factor, unroll=1):
1436+
sparse_idx = iter_idx // q_subtile_factor
1437+
subtile_offset = iter_idx % q_subtile_factor
1438+
m_block = curr_full_idx[sparse_idx] * q_subtile_factor + subtile_offset
14391439

14401440
if m_block < m_block_max:
14411441
consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
@@ -1490,7 +1490,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
14901490
n_block,
14911491
sdQaccum: cute.Tensor,
14921492
gdQaccum: cute.Tensor,
1493-
subtile_factor: cutlass.Constexpr,
1493+
q_subtile_factor: cutlass.Constexpr,
14941494
m_block_max: int,
14951495
num_dQ_warp_groups: cutlass.Constexpr,
14961496
num_threads_per_warp_group: cutlass.Constexpr,
@@ -1511,10 +1511,10 @@ def dQaccum_store_block_sparse_bwd_sm90(
15111511
curr_full_cnt = Int32(0)
15121512
curr_full_idx = None
15131513

1514-
for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1515-
sparse_idx = iter_idx // subtile_factor
1516-
subtile_offset = iter_idx % subtile_factor
1517-
m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1514+
for iter_idx in cutlass.range(curr_q_cnt * q_subtile_factor, unroll=1):
1515+
sparse_idx = iter_idx // q_subtile_factor
1516+
subtile_offset = iter_idx % q_subtile_factor
1517+
m_block = curr_q_idx[sparse_idx] * q_subtile_factor + subtile_offset
15181518

15191519
if m_block < m_block_max:
15201520
_store_one_dQaccum_sm90(
@@ -1527,10 +1527,10 @@ def dQaccum_store_block_sparse_bwd_sm90(
15271527
)
15281528

15291529
if const_expr(full_cnt is not None):
1530-
for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1531-
sparse_idx = iter_idx // subtile_factor
1532-
subtile_offset = iter_idx % subtile_factor
1533-
m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1530+
for iter_idx in cutlass.range(curr_full_cnt * q_subtile_factor, unroll=1):
1531+
sparse_idx = iter_idx // q_subtile_factor
1532+
subtile_offset = iter_idx % q_subtile_factor
1533+
m_block = curr_full_idx[sparse_idx] * q_subtile_factor + subtile_offset
15341534

15351535
if m_block < m_block_max:
15361536
_store_one_dQaccum_sm90(

flash_attn/cute/block_sparsity.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -391,15 +391,15 @@ def get_block_sparse_expected_shapes_bwd(
391391
seqlen_k: int,
392392
m_block_size: int,
393393
n_block_size: int,
394-
subtile_factor: int,
394+
q_subtile_factor: int,
395395
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
396396
"""Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.
397397
398398
Backward uses Q-direction indexing (transposed from forward), where shapes are
399399
indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined
400-
by subtile_factor * m_block_size.
400+
by q_subtile_factor * m_block_size.
401401
"""
402-
sparse_block_size_q = subtile_factor * m_block_size
402+
sparse_block_size_q = q_subtile_factor * m_block_size
403403
expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
404404
expected_n_blocks = ceildiv(seqlen_k, n_block_size)
405405
expected_count_shape = (batch_size, num_head, expected_n_blocks)
@@ -590,17 +590,17 @@ def normalize_block_sparse_config_bwd(
590590
seqlen_q: int,
591591
seqlen_k: int,
592592
block_size: tuple[int, int],
593-
subtile_factor: int,
593+
q_subtile_factor: int,
594594
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None]:
595595
m_block_size, n_block_size = block_size
596596
if tensors.block_size is None:
597-
sparse_block_size_q, sparse_block_size_kv = subtile_factor * m_block_size, n_block_size
597+
sparse_block_size_q, sparse_block_size_kv = q_subtile_factor * m_block_size, n_block_size
598598
else:
599599
sparse_block_size_q, sparse_block_size_kv = tensors.block_size
600-
if sparse_block_size_q != subtile_factor * m_block_size:
600+
if sparse_block_size_q != q_subtile_factor * m_block_size:
601601
raise ValueError(
602-
f"Block sparsity expects sparse_block_size_q={subtile_factor * m_block_size} "
603-
f"for subtile_factor={subtile_factor}."
602+
f"Block sparsity expects sparse_block_size_q={q_subtile_factor * m_block_size} "
603+
f"for q_subtile_factor={q_subtile_factor}."
604604
)
605605
if sparse_block_size_kv != n_block_size:
606606
raise ValueError(
@@ -613,7 +613,7 @@ def normalize_block_sparse_config_bwd(
613613
seqlen_k,
614614
m_block_size,
615615
n_block_size,
616-
subtile_factor,
616+
q_subtile_factor,
617617
)
618618
normalized_tensors = normalize_block_sparse_tensors(
619619
tensors,
@@ -623,7 +623,7 @@ def normalize_block_sparse_config_bwd(
623623
hint=lambda: (
624624
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, "
625625
f"and optionally full_q_cnt/full_q_idx). Regenerate the backward BlockMask with "
626-
f"BLOCK_SIZE=({subtile_factor * m_block_size}, {n_block_size})."
626+
f"BLOCK_SIZE=({q_subtile_factor * m_block_size}, {n_block_size})."
627627
),
628628
)
629629
return normalized_tensors, get_block_sparse_broadcast_pattern(normalized_tensors)

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def __init__(
6666
score_mod_bwd: cutlass.Constexpr | None = None,
6767
mask_mod: cutlass.Constexpr | None = None,
6868
has_aux_tensors: cutlass.Constexpr = False,
69-
subtile_factor: cutlass.Constexpr[int] = 1,
69+
q_subtile_factor: cutlass.Constexpr[int] = 1,
7070
):
7171
# padding head_dim to a multiple of 16 as k_block_size
7272
hdim_multiple_of = 16
@@ -119,7 +119,7 @@ def __init__(
119119
self.score_mod_bwd = score_mod_bwd
120120
self.mask_mod = mask_mod
121121
self.has_aux_tensors = has_aux_tensors
122-
self.subtile_factor = subtile_factor
122+
self.q_subtile_factor = q_subtile_factor
123123
# For score_mod, use vec_size=1 (like forward) to handle per-element indices
124124
if cutlass.const_expr(has_aux_tensors):
125125
self.vec_size: cutlass.Constexpr = 1
@@ -1910,7 +1910,7 @@ def load(
19101910
batch_idx,
19111911
head_idx,
19121912
n_block,
1913-
subtile_factor=self.subtile_factor,
1913+
q_subtile_factor=self.q_subtile_factor,
19141914
m_block_max=m_block_max,
19151915
)
19161916
process_tile = total_m_block_cnt > Int32(0)
@@ -1947,7 +1947,7 @@ def load(
19471947
self.tma_copy_bytes["V"],
19481948
should_load_Q=should_load_Q,
19491949
should_load_dO=should_load_dO,
1950-
subtile_factor=self.subtile_factor,
1950+
q_subtile_factor=self.q_subtile_factor,
19511951
m_block_max=m_block_max,
19521952
)
19531953
)
@@ -2366,7 +2366,7 @@ def mma(
23662366
batch_idx,
23672367
head_idx,
23682368
n_block,
2369-
subtile_factor=self.subtile_factor,
2369+
q_subtile_factor=self.q_subtile_factor,
23702370
m_block_max=m_block_max,
23712371
)
23722372
process_tile = block_iter_count > Int32(0)
@@ -3019,7 +3019,7 @@ def compute_loop(
30193019
batch_idx,
30203020
head_idx,
30213021
n_block,
3022-
subtile_factor=self.subtile_factor,
3022+
q_subtile_factor=self.q_subtile_factor,
30233023
m_block_max=m_block_max,
30243024
)
30253025
process_tile = loop_count > Int32(0)
@@ -3038,7 +3038,7 @@ def compute_loop(
30383038
curr_q_idx,
30393039
curr_full_cnt,
30403040
curr_full_idx,
3041-
subtile_factor=self.subtile_factor,
3041+
q_subtile_factor=self.q_subtile_factor,
30423042
m_block_max=m_block_max,
30433043
)
30443044
m_block_oob = m_block >= m_block_max
@@ -3445,7 +3445,7 @@ def _dq_semaphore_lock_value(
34453445
if const_expr(self.use_block_sparsity):
34463446
assert blocksparse_tensors is not None
34473447
if const_expr(blocksparse_tensors.dq_write_order is not None):
3448-
sparse_iter = iter_idx // self.subtile_factor
3448+
sparse_iter = iter_idx // self.q_subtile_factor
34493449
if sparse_iter < curr_q_cnt:
34503450
assert curr_dq_write_order is not None
34513451
lock_value = curr_dq_write_order[sparse_iter]
@@ -3554,7 +3554,7 @@ def dQacc_reduce(
35543554
batch_idx,
35553555
head_idx,
35563556
n_block,
3557-
subtile_factor=self.subtile_factor,
3557+
q_subtile_factor=self.q_subtile_factor,
35583558
m_block_max=m_block_max,
35593559
)
35603560
process_tile = loop_count > Int32(0)
@@ -3584,7 +3584,7 @@ def dQacc_reduce(
35843584
curr_q_idx,
35853585
curr_full_cnt,
35863586
curr_full_idx,
3587-
subtile_factor=self.subtile_factor,
3587+
q_subtile_factor=self.q_subtile_factor,
35883588
m_block_max=m_block_max,
35893589
)
35903590
m_block_oob_upper = m_block >= m_block_max

0 commit comments

Comments (0)