@@ -36,7 +36,7 @@ class AttentionConfig:
3636 p_layout : gl .constexpr
3737
3838 @gluon .constexpr_function
39- def __init__ (self , SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS ):
39+ def __init__ (self , SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS , NUM_WARPS ):
4040
4141 # constants
4242 self .SEQLEN_Q = gl .constexpr (SEQLEN_Q )
@@ -46,11 +46,17 @@ def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS):
4646 self .BLOCK_N = gl .constexpr (BLOCK_N )
4747 self .NUM_BUFFERS = gl .constexpr (NUM_BUFFERS )
4848
49+ assert NUM_WARPS == 4 or NUM_WARPS == 8
50+ if NUM_WARPS == 4 :
51+ warp_bases = [[1 , 0 ], [2 , 0 ]]
52+ else :
53+ warp_bases = [[1 , 0 ], [2 , 0 ], [4 , 0 ]]
54+
4955 # operator layouts
5056 self .qk_layout = gl .constexpr (
51- gl .amd .AMDWMMALayout (3 , transposed = True , warp_bases = [[ 1 , 0 ], [ 2 , 0 ]] , instr_shape = [16 , 16 , 32 ]))
57+ gl .amd .AMDWMMALayout (3 , transposed = True , warp_bases = warp_bases , instr_shape = [16 , 16 , 32 ]))
5258 self .pv_layout = gl .constexpr (
53- gl .amd .AMDWMMALayout (3 , transposed = True , warp_bases = [[ 1 , 0 ], [ 2 , 0 ]] , instr_shape = [16 , 16 , 32 ]))
59+ gl .amd .AMDWMMALayout (3 , transposed = True , warp_bases = warp_bases , instr_shape = [16 , 16 , 32 ]))
5460
5561 # tensor layouts
5662 self .k_smem_layout = gl .constexpr (
@@ -258,7 +264,8 @@ def attn_fwd_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
258264 ):
259265
260266 NUM_BUFFERS : gl .constexpr = 1
261- cfg = AttentionConfig (SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS )
267+ NUM_WARPS : gl .constexpr = 4
268+ cfg = AttentionConfig (SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS , NUM_WARPS )
262269 pgm = AttentionProgram .initialize ( #
263270 cfg , q_ptr , k_ptr , v_ptr , out_ptr , #
264271 stride_qz , stride_qh , stride_qm , stride_qk , #
@@ -307,7 +314,8 @@ def attn_fwd_pipelined_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
307314 HEAD_SZ : gl .constexpr , #
308315 ):
309316 NUM_BUFFERS : gl .constexpr = 2
310- cfg = AttentionConfig (SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS )
317+ NUM_WARPS : gl .constexpr = 4
318+ cfg = AttentionConfig (SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS , NUM_WARPS )
311319 pgm = AttentionProgram .initialize ( #
312320 cfg , q_ptr , k_ptr , v_ptr , out_ptr , #
313321 stride_qz , stride_qh , stride_qm , stride_qk , #
@@ -502,57 +510,286 @@ def attn_fwd_pipelined_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
502510 pgm .store_output (acc )
503511
504512
513+ @gluon .jit
514+ def attn_fwd_pingpong_pipelined_kernel (q_ptr , k_ptr , v_ptr , out_ptr , #
515+ stride_qz , stride_qh , stride_qm , stride_qk , #
516+ stride_kz , stride_kh , stride_kn , stride_kk , #
517+ stride_vz , stride_vh , stride_vn , stride_vk , #
518+ stride_oz , stride_oh , stride_om , stride_on , #
519+ SM_SCALE : gl .constexpr , #
520+ SEQLEN_Q : gl .constexpr , #
521+ SEQLEN_K : gl .constexpr , #
522+ BLOCK_M : gl .constexpr , #
523+ BLOCK_N : gl .constexpr , #
524+ HEAD_SZ : gl .constexpr , #
525+ ):
526+ NUM_BUFFERS : gl .constexpr = 2
527+ NUM_WARPS : gl .constexpr = 8
528+ cfg = AttentionConfig (SEQLEN_Q , SEQLEN_K , HEAD_SZ , BLOCK_M , BLOCK_N , NUM_BUFFERS , NUM_WARPS )
529+ pgm = AttentionProgram .initialize ( #
530+ cfg , q_ptr , k_ptr , v_ptr , out_ptr , #
531+ stride_qz , stride_qh , stride_qm , stride_qk , #
532+ stride_kz , stride_kh , stride_kn , stride_kk , #
533+ stride_vz , stride_vh , stride_vn , stride_vk , #
534+ stride_oz , stride_oh , stride_om , stride_on , #
535+ SM_SCALE )
536+
537+ ITERS_IN_PROLOGUE_EPILOGUE : gl .constexpr = 3
538+ n_blocks_n = max ((SEQLEN_K + BLOCK_N - 1 ) // BLOCK_N - ITERS_IN_PROLOGUE_EPILOGUE , 1 )
539+
540+ # Since QK from the final iteration is already peeled into the epilogue,
541+ # we only need to handle the case where SEQLEN_K < ITERS_IN_PROLOGUE_EPILOGUE * BLOCK_N.
542+ has_remainder : gl .constexpr = SEQLEN_K < (ITERS_IN_PROLOGUE_EPILOGUE ) * BLOCK_N
543+ REMAINDER_PEELED_ITERS = 1
544+ if has_remainder :
545+ n_blocks_n = n_blocks_n - REMAINDER_PEELED_ITERS
546+
547+ m_i = gl .full ([BLOCK_M ], float ("-inf" ), dtype = gl .float32 , layout = gl .SliceLayout (1 , cfg .pv_layout ))
548+ l_i = gl .full ([BLOCK_M ], 1.0 , dtype = gl .float32 , layout = gl .SliceLayout (1 , cfg .pv_layout ))
549+ acc = gl .zeros ([BLOCK_M , HEAD_SZ ], dtype = gl .float32 , layout = cfg .pv_layout )
550+
551+ block_min = 0
552+ block_max = n_blocks_n * BLOCK_N
553+ """
554+ Prologue:
555+ t = i t = i+1 t = i+2
556+ [GLDS_K]
557+ [LR_K, GLDS_V], [GLDS_K]
558+ [QK, SM0], [LR_K, GLDS_V], [GLDS_K]
559+ """
560+ # GLDS_K_t0, GLDS_K_t1, GLDS_V_t0
561+ pgm .tdm_load_global_to_shared_k ([0 , 0 ], buffer_index = 0 )
562+ pgm .tdm_load_global_to_shared_k ([BLOCK_N , 0 ], buffer_index = 1 )
563+ pgm .tdm_load_global_to_shared_v ([0 , 0 ], buffer_index = 0 )
564+
565+ # LR_K_t0
566+ k = pgm .tdm_shared_load_k (0 , wait_count = 2 )
567+
568+ # QK_t0
569+ qk = pgm .compute_qk (k , 0 )
570+
571+ # SM0_t0
572+ p , alpha , m_i = pgm .softmax_part0 (qk , m_i )
573+
574+ # GLDS_V_t1, GLDS_K_t2
575+ pgm .tdm_load_global_to_shared_v ([BLOCK_N , 0 ], buffer_index = 1 )
576+ pgm .tdm_load_global_to_shared_k ([2 * BLOCK_N , 0 ], buffer_index = 0 )
577+
578+ # LR_K_t1
579+ k = pgm .tdm_shared_load_k (1 , wait_count = 3 )
580+ iter_id = 0
581+ for block_id in range (block_min , block_max , BLOCK_N ):
582+ """
583+ Steady State (Hot Loop - No Masking):
584+ t = i t = i+1 t = i+2 t = i+3
585+ [SM1, LR_V, PV], [QK, SM0], [LR_K, GLDS_V], [GLDS_K]
586+
587+ unroll_factor=2 avoids recomputing iter_id-dependent values and the
588+ register-rotation arithmetic on every iteration.
589+ """
590+ """
591+ 1/2 of unrolled loop
592+ """
593+
594+ # QK, SM1, LR_V (no mask needed - all blocks in hot loop are full)
595+ with gl .amd .warp_pipeline_stage ("stage0" , priority = 0 ):
596+ t_1 = block_id + BLOCK_N
597+ t_2 = block_id + 2 * BLOCK_N
598+ t_3 = block_id + 3 * BLOCK_N
599+ qk = pgm .compute_qk_no_mask (k )
600+
601+ gl .amd .gfx1250 .tdm .async_wait (2 )
602+ with gl .amd .warp_pipeline_stage ("stage1" , priority = 1 ):
603+ # v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=2)
604+ p , l_i , acc = pgm .softmax_part1 (p , l_i , acc , alpha )
605+ v = pgm .v_buffer .index (iter_id % NUM_BUFFERS ).load (layout = pgm .cfg .v_layout )
606+ pgm .tdm_load_global_to_shared_k ([t_3 , 0 ], (iter_id + 1 ) % NUM_BUFFERS )
607+
608+ # PV, SM0, LR_K
609+ with gl .amd .warp_pipeline_stage ("stage2" , priority = 0 ):
610+ acc = pgm .compute_pv (p , v , acc )
611+
612+ gl .amd .gfx1250 .tdm .async_wait (2 )
613+ with gl .amd .warp_pipeline_stage ("stage3" , priority = 1 ):
614+ # k = pgm.tdm_shared_load_k(iter_id % NUM_BUFFERS, wait_count=2)
615+ p , alpha , m_i = pgm .softmax_part0 (qk , m_i )
616+ k = pgm .k_buffer .index (iter_id % NUM_BUFFERS ).permute ([1 , 0 ]).load (layout = pgm .cfg .k_layout )
617+ pgm .tdm_load_global_to_shared_v ([t_2 , 0 ], iter_id % NUM_BUFFERS )
618+ iter_id += 1
619+ """
620+ Final iteration of the steady state, with masking applied (only if a remainder block exists).
621+ """
622+ if has_remainder :
623+ t_1 = iter_id * BLOCK_N + BLOCK_N
624+ t_2 = iter_id * BLOCK_N + 2 * BLOCK_N
625+ t_3 = iter_id * BLOCK_N + 3 * BLOCK_N
626+
627+ # Process the remainder block with masking
628+ qk = pgm .compute_qk (k , t_1 )
629+
630+ p , l_i , acc = pgm .softmax_part1 (p , l_i , acc , alpha )
631+
632+ v = pgm .tdm_shared_load_v (iter_id % NUM_BUFFERS , wait_count = 2 )
633+
634+ # GLDS_K
635+ pgm .tdm_load_global_to_shared_k ([t_3 , 0 ], (iter_id + 1 ) % NUM_BUFFERS )
636+
637+ # PV, SM0, LR_K
638+ acc = pgm .compute_pv (p , v , acc )
639+
640+ p , alpha , m_i = pgm .softmax_part0 (qk , m_i )
641+
642+ k = pgm .tdm_shared_load_k (iter_id % NUM_BUFFERS , wait_count = 2 )
643+
644+ # GLDS_V
645+ pgm .tdm_load_global_to_shared_v ([t_2 , 0 ], iter_id % NUM_BUFFERS )
646+ iter_id += 1
647+ """
648+ Epilogue:
649+ t = i+1 t = i+2 t = i+3
650+ [SM1, LR_V, PV], [QK, SM0], [LR_K, GLDS_V]
651+ [SM1, LR_V, PV], [QK, SM0]
652+ [SM1, LR_V, PV]
653+ """
654+ epilogue_offset = (iter_id - 1 ) * BLOCK_N
655+ t_2 = epilogue_offset + 2 * BLOCK_N
656+ t_3 = epilogue_offset + 3 * BLOCK_N
657+ # SM1_t1, LR_V_t1, PV_t1
658+ p , l_i , acc = pgm .softmax_part1 (p , l_i , acc , alpha )
659+
660+ v = pgm .tdm_shared_load_v (iter_id % NUM_BUFFERS , wait_count = 2 )
661+
662+ acc = pgm .compute_pv (p , v , acc )
663+
664+ # QK_t2, SM0_t2
665+ qk = pgm .compute_qk (k , t_2 )
666+ p , alpha , m_i = pgm .softmax_part0 (qk , m_i )
667+
668+ # LR_K_t3, GLDS_V_t3
669+ k = pgm .tdm_shared_load_k (iter_id % NUM_BUFFERS , wait_count = 1 )
670+
671+ pgm .tdm_load_global_to_shared_v ([t_3 , 0 ], iter_id % NUM_BUFFERS )
672+
673+ # QK_t3, SM1_t2, LR_V_t2
674+ qk = pgm .compute_qk (k , t_3 )
675+
676+ p , l_i , acc = pgm .softmax_part1 (p , l_i , acc , alpha )
677+
678+ v = pgm .tdm_shared_load_v ((iter_id + 1 ) % NUM_BUFFERS , wait_count = 1 )
679+
680+ # PV_t2, SM0_t3, SM1_t3, LR_V_t3
681+ acc = pgm .compute_pv (p , v , acc )
682+
683+ p , alpha , m_i = pgm .softmax_part0 (qk , m_i )
684+ p , l_i , acc = pgm .softmax_part1 (p , l_i , acc , alpha )
685+
686+ v = pgm .tdm_shared_load_v (iter_id % NUM_BUFFERS , wait_count = 0 )
687+
688+ # PV_t3
689+ acc = pgm .compute_pv (p , v , acc )
690+
691+ # Post-loop scaling and output
692+
693+ l_recip = 1 / l_i [:, None ]
694+ acc = acc * l_recip
695+ pgm .store_output (acc )
696+
697+
505698def generate_configs ():
506699 base_configs = [
507700 # Tests for pipelined attention fwd kernel
508701 pytest .param ({
509702 "BATCH" : 8 , "SEQLEN_Q" : 512 , "SEQLEN_K" : 512 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 128 , "BLOCK_M" :
510- 128 , "BLOCK_N" : 64 , "ATTN_FN" : attn_fwd_pipelined_kernel
703+ 128 , "BLOCK_N" : 64 , "ATTN_FN" : "pipeline"
511704 }),
512705 pytest .param ({
513706 "BATCH" : 8 , "SEQLEN_Q" : 1024 , "SEQLEN_K" : 1024 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 64 ,
514- "BLOCK_M" : 128 , "BLOCK_N" : 128 , "ATTN_FN" : attn_fwd_pipelined_kernel
707+ "BLOCK_M" : 128 , "BLOCK_N" : 128 , "ATTN_FN" : "pipeline"
515708 }),
516709 pytest .param ({
517710 "BATCH" : 4 , "SEQLEN_Q" : 2000 , "SEQLEN_K" : 2000 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 64 ,
518- "BLOCK_M" : 128 , "BLOCK_N" : 128 , "ATTN_FN" : attn_fwd_pipelined_kernel
711+ "BLOCK_M" : 128 , "BLOCK_N" : 128 , "ATTN_FN" : "pipeline"
519712 }),
520713 pytest .param ({
521714 "BATCH" : 1 , "SEQLEN_Q" : 3 , "SEQLEN_K" : 32 , "NUM_Q_HEADS" : 4 , "NUM_K_HEADS" : 4 , "HEAD_SZ" : 128 , "BLOCK_M" :
522- 128 , "BLOCK_N" : 32 , "ATTN_FN" : attn_fwd_pipelined_kernel
715+ 128 , "BLOCK_N" : 32 , "ATTN_FN" : "pipeline"
523716 }),
524717 pytest .param ({
525718 "BATCH" : 4 , "SEQLEN_Q" : 1 , "SEQLEN_K" : 100 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 32 , "BLOCK_M" :
526- 128 , "BLOCK_N" : 32 , "ATTN_FN" : attn_fwd_pipelined_kernel
719+ 128 , "BLOCK_N" : 32 , "ATTN_FN" : "pipeline"
527720 }),
528721 pytest .param ({
529722 "BATCH" : 1 , "SEQLEN_Q" : 1 , "SEQLEN_K" : 30 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 32 , "BLOCK_M" :
530- 128 , "BLOCK_N" : 32 , "ATTN_FN" : attn_fwd_pipelined_kernel
723+ 128 , "BLOCK_N" : 32 , "ATTN_FN" : "pipeline"
724+ }),
725+ # Tests for pingpong pipelined attention fwd kernel
726+ pytest .param ({
727+ "BATCH" : 8 , "SEQLEN_Q" : 1024 , "SEQLEN_K" : 1024 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 128 ,
728+ "BLOCK_M" : 256 , "BLOCK_N" : 64 , "ATTN_FN" : "pingpong"
729+ }),
730+ pytest .param ({
731+ "BATCH" : 1 , "SEQLEN_Q" : 300 , "SEQLEN_K" : 300 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 64 , "BLOCK_M" :
732+ 256 , "BLOCK_N" : 32 , "ATTN_FN" : "pingpong"
531733 }),
532734
533735 # Tests for non-pipelined attention fwd kernel
534736 pytest .param ({
535737 "BATCH" : 8 , "SEQLEN_Q" : 512 , "SEQLEN_K" : 512 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 128 , "BLOCK_M" :
536- 128 , "BLOCK_N" : 32 , "ATTN_FN" : attn_fwd_kernel
738+ 128 , "BLOCK_N" : 32 , "ATTN_FN" : "default"
537739 }),
538740 pytest .param ({
539741 "BATCH" : 1 , "SEQLEN_Q" : 1 , "SEQLEN_K" : 30 , "NUM_Q_HEADS" : 8 , "NUM_K_HEADS" : 8 , "HEAD_SZ" : 32 , "BLOCK_M" :
540- 128 , "BLOCK_N" : 32 , "ATTN_FN" : attn_fwd_kernel
742+ 128 , "BLOCK_N" : 32 , "ATTN_FN" : "default"
541743 }),
542744 ]
543745 return base_configs
544746
545747
546- def run_attention (config , check = True ):
748+ _KERNEL_NUM_WARPS = {attn_fwd_kernel : 4 , attn_fwd_pipelined_kernel : 4 , attn_fwd_pingpong_pipelined_kernel : 8 }
749+
750+ _ATTN_TYPE_TO_KERNEL_FN = {
751+ "default" : attn_fwd_kernel ,
752+ "pipeline" : attn_fwd_pipelined_kernel ,
753+ "pingpong" : attn_fwd_pingpong_pipelined_kernel ,
754+ }
755+
756+
757+ def run_prefill_attention (config , q , k , v , o , sm_scale ):
547758 BATCH = config ["BATCH" ]
548759 SEQLEN_Q = config ["SEQLEN_Q" ]
549760 SEQLEN_K = config ["SEQLEN_K" ]
550761 NUM_Q_HEADS = config ["NUM_Q_HEADS" ]
551- NUM_K_HEADS = config ["NUM_K_HEADS" ]
552762 HEAD_SZ = config ["HEAD_SZ" ]
553763 BLOCK_M = config ["BLOCK_M" ]
554764 BLOCK_N = config ["BLOCK_N" ]
555- attn_fn = config ["ATTN_FN" ]
765+ attn_fn = _ATTN_TYPE_TO_KERNEL_FN [config ["ATTN_FN" ]]
766+
767+ num_warps = _KERNEL_NUM_WARPS [attn_fn ]
768+
769+ grid = (
770+ BATCH ,
771+ NUM_Q_HEADS ,
772+ ((SEQLEN_Q + BLOCK_M - 1 ) // BLOCK_M ),
773+ )
774+ attn_kernel = attn_fn [grid ](
775+ q , k , v , o , #
776+ q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
777+ k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
778+ v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
779+ o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
780+ sm_scale , SEQLEN_Q , SEQLEN_K , #
781+ BLOCK_M , BLOCK_N , #
782+ HEAD_SZ , num_warps = num_warps , waves_per_eu = 1 )
783+ return (attn_kernel , )
784+
785+
786+ def run_attention (config , check = True ):
787+ BATCH = config ["BATCH" ]
788+ SEQLEN_Q = config ["SEQLEN_Q" ]
789+ SEQLEN_K = config ["SEQLEN_K" ]
790+ NUM_Q_HEADS = config ["NUM_Q_HEADS" ]
791+ NUM_K_HEADS = config ["NUM_K_HEADS" ]
792+ HEAD_SZ = config ["HEAD_SZ" ]
556793
557794 dtype = torch .bfloat16
558795 torch .random .manual_seed (0 )
@@ -570,21 +807,8 @@ def run_attention(config, check=True):
570807 v = v .cuda ()
571808 o = o .cuda ()
572809
573- grid = (
574- BATCH ,
575- NUM_Q_HEADS ,
576- ((SEQLEN_Q + BLOCK_M - 1 ) // BLOCK_M ),
577- )
810+ attn_kernel = run_prefill_attention (config , q , k , v , o , sm_scale )
578811
579- attn_kernel = attn_fn [grid ](
580- q , k , v , o , #
581- q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
582- k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
583- v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
584- o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
585- sm_scale , SEQLEN_Q , SEQLEN_K , #
586- BLOCK_M , BLOCK_N , #
587- HEAD_SZ , num_warps = 4 , waves_per_eu = 1 )
588812 torch .cuda .synchronize ()
589813 o = o .cpu ()
590814 rtol = 0.004
@@ -611,15 +835,21 @@ def test_attention(config):
611835 parser .add_argument ("--head-size" , type = int , default = 128 , help = 'Q/K/V head size' )
612836 parser .add_argument ("--block-m" , type = int , default = 128 , help = 'BLOCK_M size' )
613837 parser .add_argument ("--block-n" , type = int , default = 128 , help = 'BLOCK_N size' )
614- parser .add_argument ("--pipeline" , action = "store_true" , help = "Use pipelined variant" )
838+ parser .add_argument (
839+ "--attention-type" ,
840+ type = str ,
841+ choices = ["default" , "pipeline" , "pingpong" ],
842+ default = "default" ,
843+ help = "Attention Kernel Type" ,
844+ )
615845 args = parser .parse_args ()
616846 config = {
617847 "BATCH" : args .b , #
618848 "SEQLEN_Q" : args .seqlen_q , "SEQLEN_K" : args .seqlen_k , #
619849 "NUM_Q_HEADS" : args .num_heads_q , "NUM_K_HEADS" : args .num_heads_k , #
620850 "HEAD_SZ" : args .head_size , #
621851 "BLOCK_M" : args .block_m , "BLOCK_N" : args .block_n , #
622- "ATTN_FN" : attn_fwd_pipelined_kernel if args .pipeline else attn_fwd_kernel
852+ "ATTN_FN" : args .attention_type , #
623853 }
624854 print (config )
625855 run_attention (config )
0 commit comments