Skip to content

Commit f1d8930

Browse files
njriasan and meta-codesync[bot]
authored and committed
Add Meta AutoWS support for triton_addmm
Reviewed By: manman-ren, rafaykhurram Differential Revision: D100376526 fbshipit-source-id: c3046c4e19afd58e96e11b5953cbb60b095ba2d6
1 parent adaafec commit f1d8930

1 file changed

Lines changed: 219 additions & 26 deletions

File tree

generative_recommenders/ops/triton/triton_addmm.py

Lines changed: 219 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@
5353
ENABLE_FULL_TURNING_SPACE = False
5454

5555

56+
def _use_meta_ws() -> bool:
    """Check if Meta's warp specialization is available, enabled, and on SM100+.

    Returns:
        True only when ``is_sm100_plus()`` holds and the running Triton build
        exposes a truthy ``triton.knobs.nvidia.use_meta_ws`` knob. False in
        every other case, including Triton builds that lack any part of that
        attribute chain.
    """
    if not is_sm100_plus():
        return False
    # Walk the whole attribute chain defensively: a Triton build may provide
    # ``triton.knobs.nvidia`` without the ``use_meta_ws`` knob, in which case a
    # plain attribute access would raise AttributeError instead of reporting
    # "not available".
    knobs = getattr(triton, "knobs", None)
    nvidia_knobs = getattr(knobs, "nvidia", None)
    # Coerce so callers always receive a real bool, as the annotation promises.
    return bool(getattr(nvidia_knobs, "use_meta_ws", False))
64+
65+
5666
def _check_tma_alignment(
5767
x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, min_alignment: int = 16
5868
) -> bool:
@@ -83,6 +93,20 @@ def _check_tma_alignment(
8393
return (K % min_alignment == 0) and (N % min_alignment == 0)
8494

8595

96+
def _prune_persistent_autows_configs(configs, named_args, **kwargs):  # noqa
    """Early-prune hook that drops configs with an unsupported data partition.

    When Meta warp specialization is not in use, the incoming ``configs``
    object is returned untouched. Otherwise a new list is returned that omits
    every config combining DATA_PARTITION_FACTOR=2 with a BLOCK_M other than
    256. ``named_args`` and ``kwargs`` are accepted for the autotuner hook
    signature but are not read.
    """
    if _use_meta_ws():
        # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
        return [
            cfg
            for cfg in configs
            if not (
                cfg.kwargs.get("DATA_PARTITION_FACTOR", 1) == 2
                and cfg.kwargs.get("BLOCK_M", 0) != 256
            )
        ]
    return configs
108+
109+
86110
def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # noqa
87111
M = named_args.get("M", 0)
88112
N = named_args.get("N", 0)
@@ -407,6 +431,39 @@ def _get_addmm_tma_ws_persistent_configs(pre_hook=None) -> List[triton.Config]:
407431
return configs
408432

409433

434+
def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
    """Return the autotune configs for the persistent TMA addmm kernel.

    With Meta warp specialization in use, the result is the full cartesian
    product of the block sizes, stage counts, epilogue subtile factors and
    data partition factors listed below, each with GROUP_M=8 and num_warps=4.

    Otherwise the configs come from ``get_mm_configs(pre_hook=pre_hook)`` and
    each one has DATA_PARTITION_FACTOR and EPILOGUE_SUBTILE set to 1 in its
    ``kwargs`` (the config objects are modified in place).
    """
    if _use_meta_ws():
        block_sizes = (64, 128, 256)
        stage_counts = (2, 3, 4)
        subtile_factors = (1, 2, 4)
        partition_factors = (1, 2)
        # TODO: Prune configs to best configs.
        return [
            triton.Config(  # pyre-ignore[28]
                {
                    "BLOCK_M": bm,
                    "BLOCK_N": bn,
                    "BLOCK_K": bk,
                    "GROUP_M": 8,
                    "EPILOGUE_SUBTILE": epilogue_subtile,
                    "DATA_PARTITION_FACTOR": dp_factor,
                },
                num_stages=stages,
                num_warps=4,
                pre_hook=pre_hook,
                early_tma_store_lowering=1,
                maxRegAutoWS=255,
            )
            for bm in block_sizes
            for bn in block_sizes
            for bk in block_sizes
            for stages in stage_counts
            for epilogue_subtile in subtile_factors
            for dp_factor in partition_factors
        ]

    # The kernel signature always takes these two constexprs, so fill in
    # neutral values of 1 on the stock configs.
    base_configs = get_mm_configs(pre_hook=pre_hook)
    for cfg in base_configs:
        cfg.kwargs.update(DATA_PARTITION_FACTOR=1, EPILOGUE_SUBTILE=1)
    return base_configs
465+
466+
410467
@triton_cc(
411468
annotations={
412469
"M": "i32",
@@ -529,9 +586,104 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
529586
return pid_m, pid_n
530587

531588

589+
@triton.jit
def _addmm_persistent_tile_body(
    x_desc,
    w_desc,
    y_desc,
    z_desc,
    tile_id,
    num_pid_in_group,
    num_pid_m,
    k_tiles,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BROADCAST_Y: tl.constexpr,
    NUM_SMS: tl.constexpr,
    EPILOGUE_SUBTILE: tl.constexpr,
    INNER_WARP_SPECIALIZE: tl.constexpr,
):
    """Compute one output tile of ``z = x @ w + y`` and store it through ``z_desc``.

    The tile is selected by mapping ``tile_id`` to ``(pid_m, pid_n)`` with
    ``_compute_pid``. The product is accumulated in float32 over ``k_tiles``
    steps of width ``BLOCK_K``; ``y`` is then added in float32 and the result
    is cast to ``z_desc.dtype`` before the store.

    Args:
        x_desc, w_desc, y_desc, z_desc: descriptors used via ``.load`` /
            ``.store`` for the two operands, the addend and the output.
        tile_id: linear index of the tile handled by this call.
        num_pid_in_group, num_pid_m: forwarded unchanged to ``_compute_pid``.
        k_tiles: number of iterations of the K loop.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes.
        GROUP_M, NUM_SMS: forwarded unchanged to ``_compute_pid``.
        ALLOW_TF32: passed to ``tl.dot`` as ``allow_tf32``.
        BROADCAST_Y: when true, ``y`` is always loaded from row 0 rather than
            from the tile's own row offset.
        EPILOGUE_SUBTILE: 1, 2 or 4 -- number of column slices the epilogue
            is split into; any other value fails a static assert.
        INNER_WARP_SPECIALIZE: passed to ``tl.range`` as ``warp_specialize``
            for the K loop.

    NOTE(review): with EPILOGUE_SUBTILE > 1 the ``y_desc`` loads and ``z_desc``
    stores operate on slices narrower than BLOCK_N; this presumably relies on
    the descriptors' block shapes being set accordingly outside this function
    -- confirm against the autotune pre-hook.
    """
    pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS)
    # Element offsets of this tile's top-left corner in the output.
    offs_xm = pid_m * BLOCK_M
    offs_wn = pid_n * BLOCK_N

    # Accumulate in float32 regardless of the input/output dtypes.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, k_tiles, warp_specialize=INNER_WARP_SPECIALIZE):
        offs_k = k * BLOCK_K
        x = x_desc.load([offs_xm, offs_k])
        w = w_desc.load([offs_k, offs_wn])
        accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32)

    # Epilogue subtiling breaks the store into multiple pieces to reduce
    # shared memory consumption and allow higher stage counts.
    if EPILOGUE_SUBTILE == 1:
        # Single-piece epilogue: add y to the whole tile and store it at once.
        if BROADCAST_Y:
            y = y_desc.load([0, offs_wn])
        else:
            y = y_desc.load([offs_xm, offs_wn])
        z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z)
    elif EPILOGUE_SUBTILE == 2:
        # reshape + permute + split divides the accumulator along N:
        # acc0 holds columns [0, BLOCK_N // 2), acc1 the remaining half.
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc0, acc1 = tl.split(acc)
        # First half: columns starting at offs_wn.
        if BROADCAST_Y:
            y0 = y_desc.load([0, offs_wn])
        else:
            y0 = y_desc.load([offs_xm, offs_wn])
        z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z0)
        # Second half: columns starting at offs_wn + BLOCK_N // 2.
        if BROADCAST_Y:
            y1 = y_desc.load([0, offs_wn + BLOCK_N // 2])
        else:
            y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 2])
        z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + BLOCK_N // 2], z1)
    elif EPILOGUE_SUBTILE == 4:
        # Split into halves first, then split each half again, giving four
        # column slices acc0..acc3 of width BLOCK_N // 4 in left-to-right order.
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc_lo, acc_hi = tl.split(acc)
        acc_lo = tl.reshape(acc_lo, (BLOCK_M, 2, BLOCK_N // 4))
        acc_lo = tl.permute(acc_lo, (0, 2, 1))
        acc0, acc1 = tl.split(acc_lo)
        acc_hi = tl.reshape(acc_hi, (BLOCK_M, 2, BLOCK_N // 4))
        acc_hi = tl.permute(acc_hi, (0, 2, 1))
        acc2, acc3 = tl.split(acc_hi)
        # Quarter 0: columns starting at offs_wn.
        if BROADCAST_Y:
            y0 = y_desc.load([0, offs_wn])
        else:
            y0 = y_desc.load([offs_xm, offs_wn])
        z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z0)
        # Quarter 1: columns starting at offs_wn + BLOCK_N // 4.
        if BROADCAST_Y:
            y1 = y_desc.load([0, offs_wn + BLOCK_N // 4])
        else:
            y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 4])
        z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + BLOCK_N // 4], z1)
        # Quarter 2: columns starting at offs_wn + 2 * BLOCK_N // 4.
        if BROADCAST_Y:
            y2 = y_desc.load([0, offs_wn + 2 * BLOCK_N // 4])
        else:
            y2 = y_desc.load([offs_xm, offs_wn + 2 * BLOCK_N // 4])
        z2 = (acc2 + y2.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + 2 * BLOCK_N // 4], z2)
        # Quarter 3: columns starting at offs_wn + 3 * BLOCK_N // 4.
        if BROADCAST_Y:
            y3 = y_desc.load([0, offs_wn + 3 * BLOCK_N // 4])
        else:
            y3 = y_desc.load([offs_xm, offs_wn + 3 * BLOCK_N // 4])
        z3 = (acc3 + y3.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + 3 * BLOCK_N // 4], z3)
    else:
        # Reject unsupported subtile factors at compile time.
        tl.static_assert(False, "Unsupported EPILOGUE_SUBTILE value")
681+
682+
532683
@triton_autotune(
533-
configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook),
534-
key=["N", "K", "WARP_SPECIALIZE"],
684+
configs=get_triton_persistent_configs(pre_hook=_addmm_tma_set_block_size_hook),
685+
key=["M", "N", "K", "WARP_SPECIALIZE"],
686+
prune_configs_by={"early_config_prune": _prune_persistent_autows_configs},
535687
)
536688
@triton.jit
537689
def _addmm_fwd_tma_persistent(
@@ -550,6 +702,9 @@ def _addmm_fwd_tma_persistent(
550702
BROADCAST_Y: tl.constexpr,
551703
WARP_SPECIALIZE: tl.constexpr,
552704
NUM_SMS: tl.constexpr,
705+
EPILOGUE_SUBTILE: tl.constexpr,
706+
DATA_PARTITION_FACTOR: tl.constexpr,
707+
USE_META_WS: tl.constexpr,
553708
):
554709
start_pid = tl.program_id(axis=0)
555710
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -559,27 +714,61 @@ def _addmm_fwd_tma_persistent(
559714

560715
num_pid_in_group = GROUP_M * num_pid_n
561716

562-
for tile_id in tl.range(
563-
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE
564-
):
565-
pid_m, pid_n = _compute_pid(
566-
tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS
567-
)
568-
offs_xm = pid_m * BLOCK_M
569-
offs_wn = pid_n * BLOCK_N
570-
571-
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
572-
for k in tl.range(0, k_tiles, warp_specialize=WARP_SPECIALIZE):
573-
offs_k = k * BLOCK_K
574-
x = x_desc.load([offs_xm, offs_k])
575-
w = w_desc.load([offs_k, offs_wn])
576-
accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32)
577-
if BROADCAST_Y:
578-
y = y_desc.load([0, offs_wn])
579-
else:
580-
y = y_desc.load([offs_xm, offs_wn])
581-
z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
582-
z_desc.store([offs_xm, offs_wn], z)
717+
if USE_META_WS:
718+
# Some arguments are only available in FBexperimental.
719+
# pyre-ignore[28]: smem_alloc_algo is FBexperimental
720+
for tile_id in tl.range(
721+
start_pid,
722+
num_tiles,
723+
NUM_SMS,
724+
flatten=False,
725+
warp_specialize=WARP_SPECIALIZE,
726+
data_partition_factor=DATA_PARTITION_FACTOR,
727+
smem_alloc_algo=1,
728+
):
729+
_addmm_persistent_tile_body(
730+
x_desc,
731+
w_desc,
732+
y_desc,
733+
z_desc,
734+
tile_id,
735+
num_pid_in_group,
736+
num_pid_m,
737+
k_tiles,
738+
BLOCK_M=BLOCK_M,
739+
BLOCK_N=BLOCK_N,
740+
BLOCK_K=BLOCK_K,
741+
GROUP_M=GROUP_M,
742+
ALLOW_TF32=ALLOW_TF32,
743+
BROADCAST_Y=BROADCAST_Y,
744+
NUM_SMS=NUM_SMS,
745+
EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
746+
INNER_WARP_SPECIALIZE=tl.constexpr(False),
747+
)
748+
else:
749+
# Pure OAI Triton version.
750+
for tile_id in tl.range(
751+
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE
752+
):
753+
_addmm_persistent_tile_body(
754+
x_desc,
755+
w_desc,
756+
y_desc,
757+
z_desc,
758+
tile_id,
759+
num_pid_in_group,
760+
num_pid_m,
761+
k_tiles,
762+
BLOCK_M=BLOCK_M,
763+
BLOCK_N=BLOCK_N,
764+
BLOCK_K=BLOCK_K,
765+
GROUP_M=GROUP_M,
766+
ALLOW_TF32=ALLOW_TF32,
767+
BROADCAST_Y=BROADCAST_Y,
768+
NUM_SMS=NUM_SMS,
769+
EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
770+
INNER_WARP_SPECIALIZE=WARP_SPECIALIZE,
771+
)
583772

584773

585774
@triton_autotune(
@@ -894,7 +1083,7 @@ def _addmm_fwd_tma_ws_persistent(
8941083
z = (result + y.to(tl.float32)).to(z_desc.dtype)
8951084
z_buf_view = tlx.local_view(z_buffers, z_idx)
8961085
# If Y and Z are not shared wait for Z to be empty.
897-
# If there are shared this already guarenteed.
1086+
# If they are shared this is already guaranteed.
8981087
if not Y_Z_SHARED:
8991088
z_empty = tlx.local_view(z_empty_bars, z_idx)
9001089
tlx.barrier_wait(z_empty, z_load_phase ^ 1)
@@ -1228,8 +1417,12 @@ def triton_addmm_fwd_tma_persistent(
12281417
x: torch.Tensor,
12291418
w: torch.Tensor,
12301419
y: torch.Tensor,
1231-
warp_specialize: bool = False,
1420+
warp_specialize: bool | None = None,
12321421
) -> torch.Tensor:
1422+
_meta_ws = _use_meta_ws()
1423+
if warp_specialize is None:
1424+
warp_specialize = _meta_ws
1425+
12331426
M, K = x.shape
12341427
_, N = w.shape
12351428

@@ -1258,7 +1451,6 @@ def triton_addmm_fwd_tma_persistent(
12581451
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
12591452

12601453
def grid(meta):
1261-
nonlocal x_desc, w_desc, z_desc
12621454
BLOCK_M = meta["BLOCK_M"]
12631455
BLOCK_N = meta["BLOCK_N"]
12641456
return (
@@ -1280,6 +1472,7 @@ def grid(meta):
12801472
BROADCAST_Y=is_y_1d,
12811473
WARP_SPECIALIZE=warp_specialize,
12821474
NUM_SMS=NUM_SMS,
1475+
USE_META_WS=_meta_ws,
12831476
)
12841477
return z
12851478

0 commit comments

Comments
 (0)