Commit 21a5f44

[AMD][Gluon] Implement 8 wave F16 FA PingPong kernel. (#9427)
Ping-pong is another way to implement a performant attention kernel: it allows co-execution between two waves on the same SIMD. In this PR we introduce an 8-wave ping-pong / warp-pipelined kernel and the related num_warps changes. We also cleaned up how the different attention kernel variants are invoked and selected.

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
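Since the selection cleanup replaces direct kernel references with string names, a caller now picks a variant through the config dict. A minimal sketch of driving the new 8-wave variant via the harness in this file (values copied from the ping-pong test case added below; assumes the functions from f16_fa_gfx1250.py are in scope):

config = {
    "BATCH": 8, "SEQLEN_Q": 1024, "SEQLEN_K": 1024,
    "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 128,
    "BLOCK_M": 256, "BLOCK_N": 64,
    # one of "default", "pipeline", "pingpong"; resolved to a kernel and num_warps at launch
    "ATTN_FN": "pingpong",
}
run_attention(config)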
1 parent 3324fe6 commit 21a5f44

1 file changed

Lines changed: 262 additions & 32 deletions

third_party/amd/python/examples/gluon/f16_fa_gfx1250.py

@@ -36,7 +36,7 @@ class AttentionConfig:
     p_layout: gl.constexpr

     @gluon.constexpr_function
-    def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS):
+    def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS):

         # constants
         self.SEQLEN_Q = gl.constexpr(SEQLEN_Q)
@@ -46,11 +46,17 @@ def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS):
         self.BLOCK_N = gl.constexpr(BLOCK_N)
         self.NUM_BUFFERS = gl.constexpr(NUM_BUFFERS)

+        assert NUM_WARPS == 4 or NUM_WARPS == 8
+        if NUM_WARPS == 4:
+            warp_bases = [[1, 0], [2, 0]]
+        else:
+            warp_bases = [[1, 0], [2, 0], [4, 0]]
+
         # operator layouts
         self.qk_layout = gl.constexpr(
-            gl.amd.AMDWMMALayout(3, transposed=True, warp_bases=[[1, 0], [2, 0]], instr_shape=[16, 16, 32]))
+            gl.amd.AMDWMMALayout(3, transposed=True, warp_bases=warp_bases, instr_shape=[16, 16, 32]))
         self.pv_layout = gl.constexpr(
-            gl.amd.AMDWMMALayout(3, transposed=True, warp_bases=[[1, 0], [2, 0]], instr_shape=[16, 16, 32]))
+            gl.amd.AMDWMMALayout(3, transposed=True, warp_bases=warp_bases, instr_shape=[16, 16, 32]))

         # tensor layouts
         self.k_smem_layout = gl.constexpr(
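The layout change above is that warp_bases now depends on NUM_WARPS: each basis vector corresponds to one bit of the warp id, so 8 warps need a third base. A plain-Python sketch of where each warp's offset lands under these bases (warp_origin is an illustrative helper, not part of the file; it assumes the usual linear-layout rule of combining bases by XOR, which for these power-of-two bases coincides with addition):

def warp_origin(warp_id, warp_bases):
    # Combine the basis vectors selected by the set bits of warp_id.
    origin = [0, 0]
    for bit, base in enumerate(warp_bases):
        if (warp_id >> bit) & 1:
            origin = [origin[0] ^ base[0], origin[1] ^ base[1]]
    return origin

print([warp_origin(w, [[1, 0], [2, 0]]) for w in range(4)])
# 4 warps -> offsets 0..3 along dim 0
print([warp_origin(w, [[1, 0], [2, 0], [4, 0]]) for w in range(8)])
# 8 warps -> offsets 0..7 along dim 0; the new base [4, 0] covers the extra warp-id bit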
@@ -258,7 +264,8 @@ def attn_fwd_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
                    ):

     NUM_BUFFERS: gl.constexpr = 1
-    cfg = AttentionConfig(SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS)
+    NUM_WARPS: gl.constexpr = 4
+    cfg = AttentionConfig(SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS)
     pgm = AttentionProgram.initialize( #
         cfg, q_ptr, k_ptr, v_ptr, out_ptr, #
         stride_qz, stride_qh, stride_qm, stride_qk, #
@@ -307,7 +314,8 @@ def attn_fwd_pipelined_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
                              HEAD_SZ: gl.constexpr, #
                              ):
     NUM_BUFFERS: gl.constexpr = 2
-    cfg = AttentionConfig(SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS)
+    NUM_WARPS: gl.constexpr = 4
+    cfg = AttentionConfig(SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS)
     pgm = AttentionProgram.initialize( #
         cfg, q_ptr, k_ptr, v_ptr, out_ptr, #
         stride_qz, stride_qh, stride_qm, stride_qk, #
@@ -502,57 +510,286 @@ def attn_fwd_pipelined_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
     pgm.store_output(acc)


+@gluon.jit
+def attn_fwd_pingpong_pipelined_kernel(q_ptr, k_ptr, v_ptr, out_ptr, #
+                                       stride_qz, stride_qh, stride_qm, stride_qk, #
+                                       stride_kz, stride_kh, stride_kn, stride_kk, #
+                                       stride_vz, stride_vh, stride_vn, stride_vk, #
+                                       stride_oz, stride_oh, stride_om, stride_on, #
+                                       SM_SCALE: gl.constexpr, #
+                                       SEQLEN_Q: gl.constexpr, #
+                                       SEQLEN_K: gl.constexpr, #
+                                       BLOCK_M: gl.constexpr, #
+                                       BLOCK_N: gl.constexpr, #
+                                       HEAD_SZ: gl.constexpr, #
+                                       ):
+    NUM_BUFFERS: gl.constexpr = 2
+    NUM_WARPS: gl.constexpr = 8
+    cfg = AttentionConfig(SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS)
+    pgm = AttentionProgram.initialize( #
+        cfg, q_ptr, k_ptr, v_ptr, out_ptr, #
+        stride_qz, stride_qh, stride_qm, stride_qk, #
+        stride_kz, stride_kh, stride_kn, stride_kk, #
+        stride_vz, stride_vh, stride_vn, stride_vk, #
+        stride_oz, stride_oh, stride_om, stride_on, #
+        SM_SCALE)
+
+    ITERS_IN_PROLOGUE_EPILOGUE: gl.constexpr = 3
+    n_blocks_n = max((SEQLEN_K + BLOCK_N - 1) // BLOCK_N - ITERS_IN_PROLOGUE_EPILOGUE, 1)
+
+    # Since QK from the final iteration is already peeled into the epilogue,
+    # we only need to handle case where SEQLEN_K < ITERS_IN_PROLOGUE_EPILOGUE * BLOCK_N.
+    has_remainder: gl.constexpr = SEQLEN_K < (ITERS_IN_PROLOGUE_EPILOGUE) * BLOCK_N
+    REMAINDER_PEELED_ITERS = 1
+    if has_remainder:
+        n_blocks_n = n_blocks_n - REMAINDER_PEELED_ITERS
+
+    m_i = gl.full([BLOCK_M], float("-inf"), dtype=gl.float32, layout=gl.SliceLayout(1, cfg.pv_layout))
+    l_i = gl.full([BLOCK_M], 1.0, dtype=gl.float32, layout=gl.SliceLayout(1, cfg.pv_layout))
+    acc = gl.zeros([BLOCK_M, HEAD_SZ], dtype=gl.float32, layout=cfg.pv_layout)
+
+    block_min = 0
+    block_max = n_blocks_n * BLOCK_N
+    """
+    Prologue:
+    t = i                t = i+1              t = i+2
+    [GLDS_K]
+    [LR_K, GLDS_V],      [GLDS_K]
+    [QK, SM0],           [LR_K, GLDS_V],      [GLDS_K]
+    """
+    # GLDS_K_t0, GLDS_K_t1, GLDS_V_t0
+    pgm.tdm_load_global_to_shared_k([0, 0], buffer_index=0)
+    pgm.tdm_load_global_to_shared_k([BLOCK_N, 0], buffer_index=1)
+    pgm.tdm_load_global_to_shared_v([0, 0], buffer_index=0)
+
+    # LR_K_t0
+    k = pgm.tdm_shared_load_k(0, wait_count=2)
+
+    # QK_t0
+    qk = pgm.compute_qk(k, 0)
+
+    # SM0_t0
+    p, alpha, m_i = pgm.softmax_part0(qk, m_i)
+
+    # GLDS_V_t1, GLDS_K_t2
+    pgm.tdm_load_global_to_shared_v([BLOCK_N, 0], buffer_index=1)
+    pgm.tdm_load_global_to_shared_k([2 * BLOCK_N, 0], buffer_index=0)
+
+    # LR_K_t1
+    k = pgm.tdm_shared_load_k(1, wait_count=3)
+    iter_id = 0
+    for block_id in range(block_min, block_max, BLOCK_N):
+        """
+        Steady State (Hot Loop - No Masking):
+        t = i                t = i+1         t = i+2            t = i+3
+        [SM1, LR_V, PV],     [QK, SM0],      [LR_K, GLDS_V]     [GLDS_K]
+
+        unroll_factor=2 to save computation wrt iter_id and arithmetic computation
+        for rotating registers.
+        """
+        """
+        1/2 of unrolled loop
+        """
+
+        # QK, SM1, LR_V (no mask needed - all blocks in hot loop are full)
+        with gl.amd.warp_pipeline_stage("stage0", priority=0):
+            t_1 = block_id + BLOCK_N
+            t_2 = block_id + 2 * BLOCK_N
+            t_3 = block_id + 3 * BLOCK_N
+            qk = pgm.compute_qk_no_mask(k)
+
+        gl.amd.gfx1250.tdm.async_wait(2)
+        with gl.amd.warp_pipeline_stage("stage1", priority=1):
+            # v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=2)
+            p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
+            v = pgm.v_buffer.index(iter_id % NUM_BUFFERS).load(layout=pgm.cfg.v_layout)
+            pgm.tdm_load_global_to_shared_k([t_3, 0], (iter_id + 1) % NUM_BUFFERS)
+
+        # PV, SM0, LR_K
+        with gl.amd.warp_pipeline_stage("stage2", priority=0):
+            acc = pgm.compute_pv(p, v, acc)
+
+        gl.amd.gfx1250.tdm.async_wait(2)
+        with gl.amd.warp_pipeline_stage("stage3", priority=1):
+            # k = pgm.tdm_shared_load_k(iter_id % NUM_BUFFERS, wait_count=2)
+            p, alpha, m_i = pgm.softmax_part0(qk, m_i)
+            k = pgm.k_buffer.index(iter_id % NUM_BUFFERS).permute([1, 0]).load(layout=pgm.cfg.k_layout)
+            pgm.tdm_load_global_to_shared_v([t_2, 0], iter_id % NUM_BUFFERS)
+        iter_id += 1
+    """
+    Final iteration of steady state that requires masking.(if masking is required)
+    """
+    if has_remainder:
+        t_1 = iter_id * BLOCK_N + BLOCK_N
+        t_2 = iter_id * BLOCK_N + 2 * BLOCK_N
+        t_3 = iter_id * BLOCK_N + 3 * BLOCK_N
+
+        # Process the remainder block with masking
+        qk = pgm.compute_qk(k, t_1)
+
+        p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
+
+        v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=2)
+
+        # GLDS_K
+        pgm.tdm_load_global_to_shared_k([t_3, 0], (iter_id + 1) % NUM_BUFFERS)
+
+        # PV, SM0, LR_K
+        acc = pgm.compute_pv(p, v, acc)
+
+        p, alpha, m_i = pgm.softmax_part0(qk, m_i)
+
+        k = pgm.tdm_shared_load_k(iter_id % NUM_BUFFERS, wait_count=2)
+
+        # GLDS_V
+        pgm.tdm_load_global_to_shared_v([t_2, 0], iter_id % NUM_BUFFERS)
+        iter_id += 1
+    """
+    Epilogue:
+    t = i+1              t = i+2              t = i+3
+    [SM1, LR_V, PV],     [QK, SM0],           [LR_K, GLDS_V]
+                         [SM1, LR_V, PV],     [QK, SM0]
+                                              [SM1, LR_V, PV]
+    """
+    epilogue_offset = (iter_id - 1) * BLOCK_N
+    t_2 = epilogue_offset + 2 * BLOCK_N
+    t_3 = epilogue_offset + 3 * BLOCK_N
+    # SM1_t1, LR_V_t1, PV_t1
+    p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
+
+    v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=2)
+
+    acc = pgm.compute_pv(p, v, acc)
+
+    # QK_t2, SM0_t2
+    qk = pgm.compute_qk(k, t_2)
+    p, alpha, m_i = pgm.softmax_part0(qk, m_i)
+
+    # LR_K_t3, GLDS_V_t3
+    k = pgm.tdm_shared_load_k(iter_id % NUM_BUFFERS, wait_count=1)
+
+    pgm.tdm_load_global_to_shared_v([t_3, 0], iter_id % NUM_BUFFERS)
+
+    # QK_t3, SM1_t2, LR_V_t2
+    qk = pgm.compute_qk(k, t_3)
+
+    p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
+
+    v = pgm.tdm_shared_load_v((iter_id + 1) % NUM_BUFFERS, wait_count=1)
+
+    # PV_t_2, SM0_t_3, SM1_t_3, LR_V_t3
+    acc = pgm.compute_pv(p, v, acc)
+
+    p, alpha, m_i = pgm.softmax_part0(qk, m_i)
+    p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
+
+    v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=0)
+
+    # PV_t_3
+    acc = pgm.compute_pv(p, v, acc)
+
+    # Post loop scaling and output
+
+    l_recip = 1 / l_i[:, None]
+    acc = acc * l_recip
+    pgm.store_output(acc)
+
+
 def generate_configs():
     base_configs = [
         # Tests for pipelined attention fwd kernel
         pytest.param({
             "BATCH": 8, "SEQLEN_Q": 512, "SEQLEN_K": 512, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 128, "BLOCK_M":
-            128, "BLOCK_N": 64, "ATTN_FN": attn_fwd_pipelined_kernel
+            128, "BLOCK_N": 64, "ATTN_FN": "pipeline"
         }),
         pytest.param({
             "BATCH": 8, "SEQLEN_Q": 1024, "SEQLEN_K": 1024, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 64,
-            "BLOCK_M": 128, "BLOCK_N": 128, "ATTN_FN": attn_fwd_pipelined_kernel
+            "BLOCK_M": 128, "BLOCK_N": 128, "ATTN_FN": "pipeline"
         }),
         pytest.param({
             "BATCH": 4, "SEQLEN_Q": 2000, "SEQLEN_K": 2000, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 64,
-            "BLOCK_M": 128, "BLOCK_N": 128, "ATTN_FN": attn_fwd_pipelined_kernel
+            "BLOCK_M": 128, "BLOCK_N": 128, "ATTN_FN": "pipeline"
         }),
         pytest.param({
             "BATCH": 1, "SEQLEN_Q": 3, "SEQLEN_K": 32, "NUM_Q_HEADS": 4, "NUM_K_HEADS": 4, "HEAD_SZ": 128, "BLOCK_M":
-            128, "BLOCK_N": 32, "ATTN_FN": attn_fwd_pipelined_kernel
+            128, "BLOCK_N": 32, "ATTN_FN": "pipeline"
         }),
         pytest.param({
             "BATCH": 4, "SEQLEN_Q": 1, "SEQLEN_K": 100, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 32, "BLOCK_M":
-            128, "BLOCK_N": 32, "ATTN_FN": attn_fwd_pipelined_kernel
+            128, "BLOCK_N": 32, "ATTN_FN": "pipeline"
         }),
         pytest.param({
             "BATCH": 1, "SEQLEN_Q": 1, "SEQLEN_K": 30, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 32, "BLOCK_M":
-            128, "BLOCK_N": 32, "ATTN_FN": attn_fwd_pipelined_kernel
+            128, "BLOCK_N": 32, "ATTN_FN": "pipeline"
+        }),
+        # Tests for pingpong pipelined attention fwd kernel
+        pytest.param({
+            "BATCH": 8, "SEQLEN_Q": 1024, "SEQLEN_K": 1024, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 128,
+            "BLOCK_M": 256, "BLOCK_N": 64, "ATTN_FN": "pingpong"
+        }),
+        pytest.param({
+            "BATCH": 1, "SEQLEN_Q": 300, "SEQLEN_K": 300, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 64, "BLOCK_M":
+            256, "BLOCK_N": 32, "ATTN_FN": "pingpong"
         }),

         # Tests for non-pipelined attention fwd kernel
         pytest.param({
             "BATCH": 8, "SEQLEN_Q": 512, "SEQLEN_K": 512, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 128, "BLOCK_M":
-            128, "BLOCK_N": 32, "ATTN_FN": attn_fwd_kernel
+            128, "BLOCK_N": 32, "ATTN_FN": "default"
         }),
         pytest.param({
             "BATCH": 1, "SEQLEN_Q": 1, "SEQLEN_K": 30, "NUM_Q_HEADS": 8, "NUM_K_HEADS": 8, "HEAD_SZ": 32, "BLOCK_M":
-            128, "BLOCK_N": 32, "ATTN_FN": attn_fwd_kernel
+            128, "BLOCK_N": 32, "ATTN_FN": "default"
         }),
     ]
     return base_configs


-def run_attention(config, check=True):
+_KERNEL_NUM_WARPS = {attn_fwd_kernel: 4, attn_fwd_pipelined_kernel: 4, attn_fwd_pingpong_pipelined_kernel: 8}
+
+_ATTN_TYPE_TO_KERNEL_FN = {
+    "default": attn_fwd_kernel,
+    "pipeline": attn_fwd_pipelined_kernel,
+    "pingpong": attn_fwd_pingpong_pipelined_kernel,
+}
+
+
+def run_prefill_attention(config, q, k, v, o, sm_scale):
     BATCH = config["BATCH"]
     SEQLEN_Q = config["SEQLEN_Q"]
     SEQLEN_K = config["SEQLEN_K"]
     NUM_Q_HEADS = config["NUM_Q_HEADS"]
-    NUM_K_HEADS = config["NUM_K_HEADS"]
     HEAD_SZ = config["HEAD_SZ"]
     BLOCK_M = config["BLOCK_M"]
     BLOCK_N = config["BLOCK_N"]
-    attn_fn = config["ATTN_FN"]
+    attn_fn = _ATTN_TYPE_TO_KERNEL_FN[config["ATTN_FN"]]
+
+    num_warps = _KERNEL_NUM_WARPS[attn_fn]
+
+    grid = (
+        BATCH,
+        NUM_Q_HEADS,
+        ((SEQLEN_Q + BLOCK_M - 1) // BLOCK_M),
+    )
+    attn_kernel = attn_fn[grid](
+        q, k, v, o, #
+        q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+        k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+        v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+        o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+        sm_scale, SEQLEN_Q, SEQLEN_K, #
+        BLOCK_M, BLOCK_N, #
+        HEAD_SZ, num_warps=num_warps, waves_per_eu=1)
+    return (attn_kernel, )
+
+
+def run_attention(config, check=True):
+    BATCH = config["BATCH"]
+    SEQLEN_Q = config["SEQLEN_Q"]
+    SEQLEN_K = config["SEQLEN_K"]
+    NUM_Q_HEADS = config["NUM_Q_HEADS"]
+    NUM_K_HEADS = config["NUM_K_HEADS"]
+    HEAD_SZ = config["HEAD_SZ"]

     dtype = torch.bfloat16
     torch.random.manual_seed(0)
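For reference, the peeling arithmetic in the new kernel: three K-tiles' worth of work is peeled into the prologue and epilogue (ITERS_IN_PROLOGUE_EPILOGUE = 3), and one extra masked iteration is peeled off the hot loop when SEQLEN_K spans fewer than three full BLOCK_N tiles. A worked example in plain Python that mirrors those lines (shapes borrowed from the test configs; illustrative only):

def hot_loop_iters(SEQLEN_K, BLOCK_N, ITERS_IN_PROLOGUE_EPILOGUE=3, REMAINDER_PEELED_ITERS=1):
    # Same computation as the top of attn_fwd_pingpong_pipelined_kernel.
    n_blocks_n = max((SEQLEN_K + BLOCK_N - 1) // BLOCK_N - ITERS_IN_PROLOGUE_EPILOGUE, 1)
    has_remainder = SEQLEN_K < ITERS_IN_PROLOGUE_EPILOGUE * BLOCK_N
    if has_remainder:
        n_blocks_n -= REMAINDER_PEELED_ITERS
    return n_blocks_n, has_remainder

print(hot_loop_iters(1024, 64))  # (13, False): 16 K-tiles, 3 of them peeled into prologue/epilogue
print(hot_loop_iters(300, 32))   # (7, False): ceil(300/32) = 10 tiles, 7 stay in the hot loop
print(hot_loop_iters(64, 32))    # (0, True): short sequence, hot loop is skipped, masked remainder path runs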
@@ -570,21 +807,8 @@ def run_attention(config, check=True):
     v = v.cuda()
     o = o.cuda()

-    grid = (
-        BATCH,
-        NUM_Q_HEADS,
-        ((SEQLEN_Q + BLOCK_M - 1) // BLOCK_M),
-    )
+    attn_kernel = run_prefill_attention(config, q, k, v, o, sm_scale)

-    attn_kernel = attn_fn[grid](
-        q, k, v, o, #
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-        sm_scale, SEQLEN_Q, SEQLEN_K, #
-        BLOCK_M, BLOCK_N, #
-        HEAD_SZ, num_warps=4, waves_per_eu=1)
     torch.cuda.synchronize()
     o = o.cpu()
     rtol = 0.004
@@ -611,15 +835,21 @@ def test_attention(config):
     parser.add_argument("--head-size", type=int, default=128, help='Q/K/V head size')
     parser.add_argument("--block-m", type=int, default=128, help='BLOCK_M size')
     parser.add_argument("--block-n", type=int, default=128, help='BLOCK_N size')
-    parser.add_argument("--pipeline", action="store_true", help="Use pipelined variant")
+    parser.add_argument(
+        "--attention-type",
+        type=str,
+        choices=["default", "pipeline", "pingpong"],
+        default="default",
+        help="Attention Kernel Type",
+    )
     args = parser.parse_args()
     config = {
         "BATCH": args.b, #
         "SEQLEN_Q": args.seqlen_q, "SEQLEN_K": args.seqlen_k, #
         "NUM_Q_HEADS": args.num_heads_q, "NUM_K_HEADS": args.num_heads_k, #
         "HEAD_SZ": args.head_size, #
         "BLOCK_M": args.block_m, "BLOCK_N": args.block_n, #
-        "ATTN_FN": attn_fwd_pipelined_kernel if args.pipeline else attn_fwd_kernel
+        "ATTN_FN": args.attention_type, #
     }
     print(config)
     run_attention(config)
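End to end, the selection path is now: the --attention-type string becomes config["ATTN_FN"], run_prefill_attention maps it to a kernel through _ATTN_TYPE_TO_KERNEL_FN, and _KERNEL_NUM_WARPS supplies the matching num_warps at launch. A quick sanity check of that mapping (assumes the module's globals are in scope):

attn_fn = _ATTN_TYPE_TO_KERNEL_FN["pingpong"]  # attn_fwd_pingpong_pipelined_kernel
assert _KERNEL_NUM_WARPS[attn_fn] == 8         # the ping-pong variant launches 8 warps
assert _KERNEL_NUM_WARPS[_ATTN_TYPE_TO_KERNEL_FN["pipeline"]] == 4
assert _KERNEL_NUM_WARPS[_ATTN_TYPE_TO_KERNEL_FN["default"]] == 4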
