Skip to content

Commit 8194d5c

Browse files
njriasanmeta-codesync[bot]
authored andcommitted
Add Meta AutoWS support for triton_addmm
Reviewed By: manman-ren Differential Revision: D100376526
1 parent 5871c59 commit 8194d5c

1 file changed

Lines changed: 219 additions & 26 deletions

File tree

generative_recommenders/ops/triton/triton_addmm.py

Lines changed: 219 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@
5151
ENABLE_FULL_TURNING_SPACE = False
5252

5353

54+
def _use_meta_ws() -> bool:
55+
"""Check if Meta's warp specialization is available, enabled, and on SM100+."""
56+
return (
57+
is_sm100_plus()
58+
and hasattr(triton, "knobs")
59+
and hasattr(triton.knobs, "nvidia")
60+
and triton.knobs.nvidia.use_meta_ws
61+
)
62+
63+
5464
def _check_tma_alignment(
5565
x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, min_alignment: int = 16
5666
) -> bool:
@@ -81,6 +91,20 @@ def _check_tma_alignment(
8191
return (K % min_alignment == 0) and (N % min_alignment == 0)
8292

8393

94+
def _prune_persistent_autows_configs(configs, named_args, **kwargs): # noqa
95+
if not _use_meta_ws():
96+
return configs
97+
pruned = []
98+
for c in configs:
99+
BLOCK_M = c.kwargs.get("BLOCK_M", 0)
100+
DP = c.kwargs.get("DATA_PARTITION_FACTOR", 1)
101+
# DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
102+
if DP == 2 and BLOCK_M != 256:
103+
continue
104+
pruned.append(c)
105+
return pruned
106+
107+
84108
def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # noqa
85109
M = named_args.get("M", 0)
86110
N = named_args.get("N", 0)
@@ -405,6 +429,39 @@ def _get_addmm_tma_ws_persistent_configs(pre_hook=None) -> List[triton.Config]:
405429
return configs
406430

407431

432+
def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
433+
if not _use_meta_ws():
434+
configs = get_mm_configs(pre_hook=pre_hook)
435+
for c in configs:
436+
c.kwargs["DATA_PARTITION_FACTOR"] = 1
437+
c.kwargs["EPILOGUE_SUBTILE"] = 1
438+
return configs
439+
# TODO: Prune configs to best configs.
440+
return [
441+
triton.Config( # pyre-ignore[28]
442+
{
443+
"BLOCK_M": block_m,
444+
"BLOCK_N": block_n,
445+
"BLOCK_K": block_k,
446+
"GROUP_M": 8,
447+
"EPILOGUE_SUBTILE": subtile,
448+
"DATA_PARTITION_FACTOR": DP,
449+
},
450+
num_stages=num_stages,
451+
num_warps=4,
452+
pre_hook=pre_hook,
453+
early_tma_store_lowering=1,
454+
maxRegAutoWS=255,
455+
)
456+
for block_m in [64, 128, 256]
457+
for block_n in [64, 128, 256]
458+
for block_k in [64, 128, 256]
459+
for num_stages in [2, 3, 4]
460+
for subtile in [1, 2, 4]
461+
for DP in [1, 2]
462+
]
463+
464+
408465
@triton_cc(
409466
annotations={
410467
"M": "i32",
@@ -527,9 +584,104 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
527584
return pid_m, pid_n
528585

529586

587+
@triton.jit
588+
def _addmm_persistent_tile_body(
589+
x_desc,
590+
w_desc,
591+
y_desc,
592+
z_desc,
593+
tile_id,
594+
num_pid_in_group,
595+
num_pid_m,
596+
k_tiles,
597+
BLOCK_M: tl.constexpr,
598+
BLOCK_N: tl.constexpr,
599+
BLOCK_K: tl.constexpr,
600+
GROUP_M: tl.constexpr,
601+
ALLOW_TF32: tl.constexpr,
602+
BROADCAST_Y: tl.constexpr,
603+
NUM_SMS: tl.constexpr,
604+
EPILOGUE_SUBTILE: tl.constexpr,
605+
INNER_WARP_SPECIALIZE: tl.constexpr,
606+
):
607+
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS)
608+
offs_xm = pid_m * BLOCK_M
609+
offs_wn = pid_n * BLOCK_N
610+
611+
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
612+
for k in tl.range(0, k_tiles, warp_specialize=INNER_WARP_SPECIALIZE):
613+
offs_k = k * BLOCK_K
614+
x = x_desc.load([offs_xm, offs_k])
615+
w = w_desc.load([offs_k, offs_wn])
616+
accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32)
617+
618+
# Epilogue subtiling breaks the store into multiple pieces to reduce
619+
# shared memory consumption and allow higher stage counts.
620+
if EPILOGUE_SUBTILE == 1:
621+
if BROADCAST_Y:
622+
y = y_desc.load([0, offs_wn])
623+
else:
624+
y = y_desc.load([offs_xm, offs_wn])
625+
z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
626+
z_desc.store([offs_xm, offs_wn], z)
627+
elif EPILOGUE_SUBTILE == 2:
628+
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
629+
acc = tl.permute(acc, (0, 2, 1))
630+
acc0, acc1 = tl.split(acc)
631+
if BROADCAST_Y:
632+
y0 = y_desc.load([0, offs_wn])
633+
else:
634+
y0 = y_desc.load([offs_xm, offs_wn])
635+
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
636+
z_desc.store([offs_xm, offs_wn], z0)
637+
if BROADCAST_Y:
638+
y1 = y_desc.load([0, offs_wn + BLOCK_N // 2])
639+
else:
640+
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 2])
641+
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
642+
z_desc.store([offs_xm, offs_wn + BLOCK_N // 2], z1)
643+
elif EPILOGUE_SUBTILE == 4:
644+
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
645+
acc = tl.permute(acc, (0, 2, 1))
646+
acc_lo, acc_hi = tl.split(acc)
647+
acc_lo = tl.reshape(acc_lo, (BLOCK_M, 2, BLOCK_N // 4))
648+
acc_lo = tl.permute(acc_lo, (0, 2, 1))
649+
acc0, acc1 = tl.split(acc_lo)
650+
acc_hi = tl.reshape(acc_hi, (BLOCK_M, 2, BLOCK_N // 4))
651+
acc_hi = tl.permute(acc_hi, (0, 2, 1))
652+
acc2, acc3 = tl.split(acc_hi)
653+
if BROADCAST_Y:
654+
y0 = y_desc.load([0, offs_wn])
655+
else:
656+
y0 = y_desc.load([offs_xm, offs_wn])
657+
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
658+
z_desc.store([offs_xm, offs_wn], z0)
659+
if BROADCAST_Y:
660+
y1 = y_desc.load([0, offs_wn + BLOCK_N // 4])
661+
else:
662+
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 4])
663+
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
664+
z_desc.store([offs_xm, offs_wn + BLOCK_N // 4], z1)
665+
if BROADCAST_Y:
666+
y2 = y_desc.load([0, offs_wn + 2 * BLOCK_N // 4])
667+
else:
668+
y2 = y_desc.load([offs_xm, offs_wn + 2 * BLOCK_N // 4])
669+
z2 = (acc2 + y2.to(tl.float32)).to(z_desc.dtype)
670+
z_desc.store([offs_xm, offs_wn + 2 * BLOCK_N // 4], z2)
671+
if BROADCAST_Y:
672+
y3 = y_desc.load([0, offs_wn + 3 * BLOCK_N // 4])
673+
else:
674+
y3 = y_desc.load([offs_xm, offs_wn + 3 * BLOCK_N // 4])
675+
z3 = (acc3 + y3.to(tl.float32)).to(z_desc.dtype)
676+
z_desc.store([offs_xm, offs_wn + 3 * BLOCK_N // 4], z3)
677+
else:
678+
tl.static_assert(False, "Unsupported EPILOGUE_SUBTILE value")
679+
680+
530681
@triton_autotune(
531-
configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook),
532-
key=["N", "K", "WARP_SPECIALIZE"],
682+
configs=get_triton_persistent_configs(pre_hook=_addmm_tma_set_block_size_hook),
683+
key=["M", "N", "K", "WARP_SPECIALIZE"],
684+
prune_configs_by={"early_config_prune": _prune_persistent_autows_configs},
533685
)
534686
@triton.jit
535687
def _addmm_fwd_tma_persistent(
@@ -548,6 +700,9 @@ def _addmm_fwd_tma_persistent(
548700
BROADCAST_Y: tl.constexpr,
549701
WARP_SPECIALIZE: tl.constexpr,
550702
NUM_SMS: tl.constexpr,
703+
EPILOGUE_SUBTILE: tl.constexpr,
704+
DATA_PARTITION_FACTOR: tl.constexpr,
705+
USE_META_WS: tl.constexpr,
551706
):
552707
start_pid = tl.program_id(axis=0)
553708
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -557,27 +712,61 @@ def _addmm_fwd_tma_persistent(
557712

558713
num_pid_in_group = GROUP_M * num_pid_n
559714

560-
for tile_id in tl.range(
561-
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE
562-
):
563-
pid_m, pid_n = _compute_pid(
564-
tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS
565-
)
566-
offs_xm = pid_m * BLOCK_M
567-
offs_wn = pid_n * BLOCK_N
568-
569-
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
570-
for k in tl.range(0, k_tiles, warp_specialize=WARP_SPECIALIZE):
571-
offs_k = k * BLOCK_K
572-
x = x_desc.load([offs_xm, offs_k])
573-
w = w_desc.load([offs_k, offs_wn])
574-
accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32)
575-
if BROADCAST_Y:
576-
y = y_desc.load([0, offs_wn])
577-
else:
578-
y = y_desc.load([offs_xm, offs_wn])
579-
z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
580-
z_desc.store([offs_xm, offs_wn], z)
715+
if USE_META_WS:
716+
# Some arguments are only available in FBexperimental.
717+
# pyre-ignore[28]: smem_alloc_algo is FBexperimental
718+
for tile_id in tl.range(
719+
start_pid,
720+
num_tiles,
721+
NUM_SMS,
722+
flatten=False,
723+
warp_specialize=WARP_SPECIALIZE,
724+
data_partition_factor=DATA_PARTITION_FACTOR,
725+
smem_alloc_algo=1,
726+
):
727+
_addmm_persistent_tile_body(
728+
x_desc,
729+
w_desc,
730+
y_desc,
731+
z_desc,
732+
tile_id,
733+
num_pid_in_group,
734+
num_pid_m,
735+
k_tiles,
736+
BLOCK_M=BLOCK_M,
737+
BLOCK_N=BLOCK_N,
738+
BLOCK_K=BLOCK_K,
739+
GROUP_M=GROUP_M,
740+
ALLOW_TF32=ALLOW_TF32,
741+
BROADCAST_Y=BROADCAST_Y,
742+
NUM_SMS=NUM_SMS,
743+
EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
744+
INNER_WARP_SPECIALIZE=tl.constexpr(False),
745+
)
746+
else:
747+
# Pure OAI Triton version.
748+
for tile_id in tl.range(
749+
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE
750+
):
751+
_addmm_persistent_tile_body(
752+
x_desc,
753+
w_desc,
754+
y_desc,
755+
z_desc,
756+
tile_id,
757+
num_pid_in_group,
758+
num_pid_m,
759+
k_tiles,
760+
BLOCK_M=BLOCK_M,
761+
BLOCK_N=BLOCK_N,
762+
BLOCK_K=BLOCK_K,
763+
GROUP_M=GROUP_M,
764+
ALLOW_TF32=ALLOW_TF32,
765+
BROADCAST_Y=BROADCAST_Y,
766+
NUM_SMS=NUM_SMS,
767+
EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
768+
INNER_WARP_SPECIALIZE=WARP_SPECIALIZE,
769+
)
581770

582771

583772
@triton_autotune(
@@ -892,7 +1081,7 @@ def _addmm_fwd_tma_ws_persistent(
8921081
z = (result + y.to(tl.float32)).to(z_desc.dtype)
8931082
z_buf_view = tlx.local_view(z_buffers, z_idx)
8941083
# If Y and Z are not shared wait for Z to be empty.
895-
# If there are shared this already guarenteed.
1084+
# If there are shared this already guaranteed.
8961085
if not Y_Z_SHARED:
8971086
z_empty = tlx.local_view(z_empty_bars, z_idx)
8981087
tlx.barrier_wait(z_empty, z_load_phase ^ 1)
@@ -1226,8 +1415,12 @@ def triton_addmm_fwd_tma_persistent(
12261415
x: torch.Tensor,
12271416
w: torch.Tensor,
12281417
y: torch.Tensor,
1229-
warp_specialize: bool = False,
1418+
warp_specialize: bool | None = None,
12301419
) -> torch.Tensor:
1420+
_meta_ws = _use_meta_ws()
1421+
if warp_specialize is None:
1422+
warp_specialize = _meta_ws
1423+
12311424
M, K = x.shape
12321425
_, N = w.shape
12331426

@@ -1256,7 +1449,6 @@ def triton_addmm_fwd_tma_persistent(
12561449
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
12571450

12581451
def grid(meta):
1259-
nonlocal x_desc, w_desc, z_desc
12601452
BLOCK_M = meta["BLOCK_M"]
12611453
BLOCK_N = meta["BLOCK_N"]
12621454
return (
@@ -1278,6 +1470,7 @@ def grid(meta):
12781470
BROADCAST_Y=is_y_1d,
12791471
WARP_SPECIALIZE=warp_specialize,
12801472
NUM_SMS=NUM_SMS,
1473+
USE_META_WS=_meta_ws,
12811474
)
12821475
return z
12831476

0 commit comments

Comments
 (0)