-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathblock_sparse_utils.py
More file actions
1453 lines (1307 loc) · 52.5 KB
/
block_sparse_utils.py
File metadata and controls
1453 lines (1307 loc) · 52.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Block-sparse runtime utilities for CUTE DSL kernels.
This module contains runtime execution functions for block-sparse attention kernels.
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
"""
from typing import Callable, Optional
from functools import partial
import math
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
# Import data structures from block_sparsity
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
from flash_attn.cute.named_barrier import NamedBarrierBwd
@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Issue K/V (and optionally Q) loads for one sparse block list.

    Blocks are visited in reverse order of ``block_indices``: entry
    ``block_count - 1`` first, entry 0 last. Nothing is issued when
    ``block_count`` is 0.

    Without ``intra_wg_overlap``, each block's K and V are loaded on the same
    pipeline stage, and the producer state advances once per block.

    With ``intra_wg_overlap``, the K of the next block is issued on the next
    stage before the V of the previous block is issued on its (previous) stage,
    so K and V loads overlap. The V of the last visited block (entry 0) is left
    un-issued, and the returned state still points at that block's stage. The
    caller must finish it, e.g. with ``finish_overlap_v_load``, or bridge it
    into the next list.

    Args:
        block_indices: per-tile list of n_block indices; only the first
            ``block_count`` entries are read.
        block_count: number of valid entries in ``block_indices``.
        load_q_with_first: when set together with ``use_tma_q``, Q is loaded on
            the first block's K barrier, and ``tma_q_bytes`` is added to that
            barrier's transaction count.
        first_block_preloaded: (overlap path only) the caller has already issued
            the first block's K load, so it is skipped here.

    Returns:
        Updated kv_producer_state after processing the block list.
    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Peel the first iteration: the first block may need to load Q alongside K,
            # with Q's TMA bytes added to the K barrier's expected transaction count.
            n_block_first = block_indices[block_count - 1]
            extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
            pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
            if const_expr(load_q_with_first and use_tma_q):
                # Q completes on the same barrier as the first K.
                load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
            load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            pipeline_v.producer_acquire(kv_producer_state)
            load_V(src_idx=n_block_first, producer_state=kv_producer_state)
            kv_producer_state.advance()
            # Remaining blocks: K and V share a stage; one advance per block.
            for offset in cutlass.range(1, block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                extra_tx = (
                    tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
                )
                pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
                if const_expr(load_q_with_first and use_tma_q):
                    load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
                load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            # Each step: K(next block) on the new stage, then V(previous block) on the
            # stage saved before advancing. The final V is intentionally left pending.
            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
    return kv_producer_state
@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Issue the V load that an overlapped K/V block list left pending.

    Block lists are walked from their last entry down to entry 0, so entry 0 is
    the block whose V is still outstanding. Its V is loaded on the current
    producer stage, and then the state is advanced. Does nothing for an empty
    list.

    Returns:
        The (possibly advanced) kv_producer_state.
    """
    if block_count > 0:
        pending_v_block = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=pending_v_block, producer_state=kv_producer_state)
        kv_producer_state.advance()
    return kv_producer_state
@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
):
    """Translate an m_block index into the row used by the block-sparse tensors.

    With pack-GQA (``qhead_per_kvhead`` != 1), ``m_block`` counts tiles in the
    packed query space, so it is divided by the pack factor. Otherwise it is
    returned unchanged.
    """
    if const_expr(qhead_per_kvhead == 1):
        # No packing: the tile index already matches the sparse tensor layout.
        return m_block
    return m_block // qhead_per_kvhead
@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
    """Issue the K/V (and Q) loads for one tile from its masked and full block lists.

    The masked (partial) list is loaded first, then the full list, each walked
    from its last entry to its first. The first K of whichever list is non-empty
    first (masked if present, otherwise full) carries the Q load when
    ``use_tma_q`` is set. If both lists are empty, no loads are issued.

    With ``intra_wg_overlap``, ``load_block_list`` leaves the V of a list's last
    block pending. When both lists are non-empty, the pending masked V is issued
    on the previous stage while the first full K is issued on the next stage. The
    full list then continues with ``first_block_preloaded=True``. Whichever list
    ran last has its pending V drained by ``finish_overlap_v_load``.

    Args:
        blocksparse_tensors: unpacked as (mask_block_cnt, mask_block_idx,
            full_block_cnt, full_block_idx). ``full_block_cnt`` may be None, in
            which case the full list is treated as empty.
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.

    Returns:
        The advanced kv_producer_state.
    """
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No full-block tensors: behave as if the full list is empty.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0
    if mask_empty:
        # No masked blocks: the full list owns the initial Q+K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )
        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            # Drain the V left pending by the overlapped full list.
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present: load Q together with the first masked K so consumers can
        # start immediately. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )
        if full_empty:
            if const_expr(intra_wg_overlap):
                # Only masked blocks: drain the masked list's pending V.
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                # The pending masked V belongs to the current stage; K of the first full
                # block goes on the next one.
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)
                # First full K is already issued above, so skip it inside the list.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )
                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally (skipping the Q
                # reload because the masked list already issued it).
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )
    return kv_producer_state
@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors ``produce_block_sparse_loads``: masked blocks are consumed first, then
    full blocks, each list from its last entry to its first, using the same sparse
    tensor indexing (including treating a missing ``full_block_cnt`` as an empty
    full list).

    Masked blocks apply ``mask_mod``. Full blocks always pass ``mask_mod=None``,
    which overrides any ``mask_mod`` the caller may have bound into ``mask_fn``.
    The first block of each list is masked with ``mask_seqlen=True``.

    Without ``intra_wg_overlap``, the tile's blocks run between a single
    ``warp_scheduler_barrier_sync``/``warp_scheduler_barrier_arrive`` pair (no
    barrier calls when both lists are empty). With overlap, the first block of
    the tile goes through ``process_first_half_block``, and the tile is finished
    with ``process_last_half_block``.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.

    Returns:
        (kv_consumer_state, O_should_accumulate, processed_any), where
        ``processed_any`` is true when the tile had at least one masked or full block.
    """
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    # The full-block tensors are optional. The producer treats None as "no full
    # blocks"; do the same here so both sides agree on the iteration count.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0
    if const_expr(not intra_wg_overlap):
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            # No full list follows: close the scheduler barrier opened above.
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # Full list starts the tile: open the scheduler barrier here.
                warp_scheduler_barrier_sync()
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            # First block of the tile only starts its work here; O is not yet written,
            # so O_should_accumulate is left unchanged.
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True
        # Finish the work left open by the last overlapped block of the tile.
        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True
    return kv_consumer_state, O_should_accumulate, processed_any
@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    m_block,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).

    Visits ``block_indices`` from entry ``block_count - 1`` down to entry 0. For
    each block, K and V are loaded on consecutive producer stages, and the state
    is advanced after each load. When ``load_q_with_first`` is set, Q block
    ``q_stage * m_block`` is loaded before the first K, plus
    ``q_stage * m_block + 1`` when ``q_stage == 2``. Nothing is loaded when
    ``block_count`` is 0.

    ``pipeline_kv`` is accepted but not used in this function.

    Returns:
        The updated kv_producer_state.
    """
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]
        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=q_stage * m_block + 0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=q_stage * m_block + 1, stage=1)
        # SM100 doesn't call producer_acquire on pipeline_kv here; the barrier
        # handling is presumably done inside the load_K / load_V callables (TODO confirm).
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        # Remaining blocks
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
    return kv_producer_state
# SM100-specific tile processor using SM100 helpers
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
):
    """SM100 producer entry point for sparse block iteration.

    Loads the masked list first (with Q), then the full list (without Q). If
    there are no masked blocks, the full list carries the Q load instead.
    Nothing is loaded when both lists are empty.

    Unlike the SM90 path, Q is not folded into the K barrier's transaction
    count. ``load_block_list_sm100`` calls ``load_Q`` separately and issues no
    ``producer_acquire`` itself. The reason given in earlier notes is that
    SM100's PipelineTmaUmma does not support ``extra_tx_count``.

    Args:
        m_block: which tile of m we are processing (packed space when
            ``qhead_per_kvhead`` != 1).
        q_stage: number of Q blocks loaded per m_block (the helper special-cases 2).
        q_producer_phase: Q producer phase; flipped once iff Q was loaded, i.e. the
            tile has at least one masked or full block.
        qhead_per_kvhead: Constexpr pack factor; m_block is divided by it before
            indexing the sparse tensors when it is not 1.

    Returns:
        (kv_producer_state, q_producer_phase).
    """
    # NB: Compute unpacked index for sparse tensor access
    if const_expr(qhead_per_kvhead != 1):
        m_block_sparse = m_block // qhead_per_kvhead
    else:
        m_block_sparse = m_block
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No full-block tensors: behave as if the full list is empty.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0
    # Tracks whether Q was actually loaded for this tile (so its phase must flip).
    q_phase_flipped = False
    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            m_block=m_block,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        # Q is only loaded when the full list is non-empty.
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            m_block=m_block,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True
        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                m_block=m_block,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )
    if q_phase_flipped:
        q_producer_phase ^= 1
    return kv_producer_state, q_producer_phase
@cute.jit
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
):
    """Return the number of KV blocks (masked plus full) listed for one forward tile.

    ``m_block`` may be a packed pack-GQA index. When ``qhead_per_kvhead`` is not
    1, it is divided by that factor before indexing the per-tile count tensors.
    The full-block count is included only when the full-block tensors are present.
    """
    mask_cnt_tensor, _, full_cnt_tensor, _ = blocksparse_tensors
    row = m_block // qhead_per_kvhead if const_expr(qhead_per_kvhead != 1) else m_block
    if const_expr(full_cnt_tensor is None):
        return mask_cnt_tensor[batch_idx, head_idx, row]
    return (
        mask_cnt_tensor[batch_idx, head_idx, row]
        + full_cnt_tensor[batch_idx, head_idx, row]
    )
@cute.jit
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtOs: tuple[cute.Tensor],
    sO: cute.Tensor,
    mbar_ptr,
    mbar_softmax_corr_full_offset: Int32,
    mbar_softmax_corr_empty_offset: Int32,
    mbar_P_full_O_rescaled_offset: Int32,
    mbar_P_full_2_offset: Int32,
    mbar_corr_epi_full_offset: Int32,
    mbar_corr_epi_empty_offset: Int32,
    softmax_corr_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    max_offset: Float32,
    max_offset_scale: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle the block-sparse case where a tile is fully masked (no blocks to process).

    For each of the ``q_stage`` stages this:
      * seeds the row statistics in ``sScale`` and ``stats``: row_sum = 1 and
        row_max = -inf, or the learnable sink's contribution when one applies;
      * calls ``correction_epilogue`` with a 0.0 scale so the staged output is zeros;
      * waits on / arrives at the per-stage mbarriers (softmax<->correction,
        correction<->epilogue when ``gmem_tiled_copy_O`` is None, and the two
        P_full barriers), so downstream warps continue to make progress.

    The three phases are flipped once, after all stages, and returned as
    (softmax_corr_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase).
    ``o_corr_consumer_phase`` is only flipped here, never waited on.
    """
    LOG2_E = Float32(math.log2(math.e))
    for stage in cutlass.range_constexpr(q_stage):
        row_sum_value = Float32(1.0)
        # row_max is only tracked when it will be stored (LSE output or sink in use).
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                # Pack-GQA: recover this row's query head from its packed row index.
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            # With split-KV, only split 0 adds the sink contribution.
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                # As written, row_max_value is always -inf here (set just above when a sink
                # exists), so the else branch below is not taken on this path.
                if row_max_value == -Float32.inf:
                    # Stored so that row_max * softmax_scale_log2 == sink_val * LOG2_E.
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = max_offset_scale
                else:
                    row_sum_value = row_sum_value + utils.exp2f(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2 + max_offset
                    )
        if tidx < m_block_size:
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + m_block_size * 2] = row_max_value
        # Flag is true when row_sum is zero or NaN (x != x).
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)
        cute.arch.mbarrier_wait(
            mbar_ptr + mbar_softmax_corr_full_offset + stage,
            softmax_corr_consumer_phase,
        )
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage)
        if const_expr(gmem_tiled_copy_O is None):
            cute.arch.mbarrier_wait(
                mbar_ptr + mbar_corr_epi_empty_offset + stage,
                corr_epi_producer_phase,
            )
        correction_epilogue(
            thr_mma_pv,
            tOtOs[stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs
            sO[None, None, stage],
            mO_cur,
            gO,
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)
    # One phase flip per tile: every stage above used the same phase value.
    softmax_corr_consumer_phase ^= 1
    o_corr_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1
    return (
        softmax_corr_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )
@cute.jit
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    mbar_ptr,
    mbar_softmax_corr_full_offset: Int32,
    mbar_softmax_corr_empty_offset: Int32,
    mbar_P_full_O_rescaled_offset: Int32,
    mbar_P_full_2_offset: Int32,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
):
    """Run the softmax-side ``softmax_step`` calls for one tile's sparse block lists.

    Masked blocks use ``mask_fn`` and full blocks use ``mask_fn_none``, each list
    walked from its last entry to its first. The first block processed in the
    tile gets ``is_first=True`` and ``mask_seqlen=True``.

    When the tile has no blocks at all, ``softmax_step`` is not called. Instead,
    the four mbarriers for ``stage_idx`` are arrived on directly so other warps
    are not left waiting. This presumably pairs with
    ``handle_block_sparse_empty_tile_correction_sm100``, which waits on
    ``mbar_softmax_corr_full`` (verify).

    ``q_stage`` is accepted but not used in this function.

    Returns:
        (mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase,
        is_empty_tile), where is_empty_tile is ``total_block_cnt == 0``.
    """
    # Convert packed m_block to unpacked for sparse tensor indexing
    if const_expr(qhead_per_kvhead != 1):
        m_block_sparse = m_block // qhead_per_kvhead
    else:
        m_block_sparse = m_block
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No full-block tensors: behave as if the full list is empty.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt
    if total_block_cnt == 0:
        # Empty tile: signal the barriers the softmax step would normally signal.
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx)
        cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx)
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # Full list starts the tile.
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                # NOTE(review): the SM90 consumer uses mask_seqlen=True for the first full
                # block after masked blocks; this path uses False. Confirm the seqlen-boundary
                # block can never land in the full list.
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )
# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
#
@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward.

    The result is (q_block_cnt + full_block_cnt) at ``[batch_idx, head_idx,
    n_block]``, multiplied by ``subtile_factor``. The full-block term is omitted
    when that tensor is absent. ``m_block_max`` is accepted for signature
    compatibility but is not used here.
    """
    q_cnt_tensor, _, full_cnt_tensor, _ = blocksparse_tensors
    block_total = q_cnt_tensor[batch_idx, head_idx, n_block]
    if const_expr(full_cnt_tensor is not None):
        # Fully-unmasked Q blocks are counted in a separate, optional tensor.
        block_total = block_total + full_cnt_tensor[batch_idx, head_idx, n_block]
    # Each listed Q block is visited subtile_factor times.
    return block_total * subtile_factor
@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
blocksparse_tensors: BlockSparseTensors,
batch_idx,
head_idx,
n_block,
# Pipeline states (will be returned after advancing)
producer_state_Q_LSE,
producer_state_dO_dPsum,
# Pipelines
pipeline_Q,
pipeline_LSE,
pipeline_dO,
pipeline_dPsum,
# Load functions
load_K,
load_V,
load_Q,
load_dO,
copy_stats,
# Global tensors for LSE/dPsum
gLSE,
sLSE,
gdPsum,
sdPsum,
# TMA copy bytes for extra_tx_count
tma_copy_bytes_K,
tma_copy_bytes_V,
# Flags for which loads to perform
should_load_Q: cutlass.Constexpr,
should_load_dO: cutlass.Constexpr,
# Subtiling factor and bounds
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
):
"""SM100 backward block sparse loading with subtiling.
Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
"""
(
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
loop_count,
) = get_block_sparse_iteration_info_bwd(
blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
)
for iter_idx in cutlass.range(loop_count, unroll=1):
m_block, _ = get_m_block_from_iter_bwd(
iter_idx,
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
subtile_factor,
m_block_max,
)
m_block_safe = m_block
if m_block_max > 0:
m_block_safe = cutlass.min(m_block, m_block_max - 1)
if iter_idx == 0:
# First block: load K/V alongside Q/dO
if const_expr(should_load_Q):
pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
pipeline_Q.producer_commit(producer_state_Q_LSE)
pipeline_LSE.producer_acquire(producer_state_Q_LSE)
with cute.arch.elect_one():
copy_stats(
gLSE[None, m_block_safe],
sLSE[None, producer_state_Q_LSE.index],
mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
)
producer_state_Q_LSE.advance()
if const_expr(should_load_dO):
pipeline_dO.producer_acquire(
producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
)
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))