Skip to content

Commit 9cda5f9

Browse files
jananisriramfacebook-github-bot
authored andcommitted
Add EPILOGUE_SUBTILE=8 and generalize TritonBench addmm kernel
Differential Revision: D106049038
1 parent 27c062c commit 9cda5f9

1 file changed

Lines changed: 22 additions & 59 deletions

File tree

generative_recommenders/ops/triton/triton_addmm.py

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import triton.language as tl
3030
from generative_recommenders.ops.utils import is_sm100_plus, maybe_register_custom_op
3131

32+
# @manual=//triton:triton
33+
from triton.language.extra.subtile_ops import _split_n_2D # pyre-ignore[16]
34+
3235
try:
3336
# @manual=//triton:triton
3437
import triton.language.extra.tlx as tlx # type: ignore
@@ -96,13 +99,20 @@ def _check_tma_alignment(
9699
def _prune_persistent_autows_configs(configs, named_args, **kwargs): # noqa
97100
if not _use_meta_ws():
98101
return configs
102+
BROADCAST_Y = kwargs.get("BROADCAST_Y", False)
99103
pruned = []
100104
for c in configs:
101105
BLOCK_M = c.kwargs.get("BLOCK_M", 0)
106+
BLOCK_N = c.kwargs.get("BLOCK_N", 0)
107+
EPILOGUE_SUBTILE = c.kwargs.get("EPILOGUE_SUBTILE", 1)
102108
DP = c.kwargs.get("DATA_PARTITION_FACTOR", 1)
103109
# DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
104110
if DP == 2 and BLOCK_M != 256:
105111
continue
112+
if (BLOCK_N // EPILOGUE_SUBTILE) < 32:
113+
continue
114+
if BROADCAST_Y and (BLOCK_N // EPILOGUE_SUBTILE) < 64:
115+
continue
106116
pruned.append(c)
107117
return pruned
108118

@@ -120,7 +130,6 @@ def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # n
120130
NUM_MMA_GROUPS = c.kwargs.get("NUM_MMA_GROUPS", 1)
121131
BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
122132
NUM_SMEM_BUFFERS = c.kwargs.get("NUM_SMEM_BUFFERS", 1)
123-
EPILOGUE_SUBTILE = c.kwargs.get("EPILOGUE_SUBTILE", 1)
124133

125134
# Hardware constraint: Always make MMA tile 128.
126135
if BLOCK_M_SPLIT != 128:
@@ -459,7 +468,7 @@ def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
459468
for block_n in [64, 128, 256]
460469
for block_k in [64, 128, 256]
461470
for num_stages in [2, 3, 4]
462-
for subtile in [1, 2, 4]
471+
for subtile in [1, 2, 4, 8]
463472
for DP in [1, 2]
464473
]
465474

@@ -619,65 +628,19 @@ def _addmm_persistent_tile_body(
619628

620629
# Epilogue subtiling breaks the store into multiple pieces to reduce
621630
# shared memory consumption and allow higher stage counts.
622-
if EPILOGUE_SUBTILE == 1:
623-
if BROADCAST_Y:
624-
y = y_desc.load([0, offs_wn])
625-
else:
626-
y = y_desc.load([offs_xm, offs_wn])
627-
z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
628-
z_desc.store([offs_xm, offs_wn], z)
629-
elif EPILOGUE_SUBTILE == 2:
630-
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
631-
acc = tl.permute(acc, (0, 2, 1))
632-
acc0, acc1 = tl.split(acc)
633-
if BROADCAST_Y:
634-
y0 = y_desc.load([0, offs_wn])
635-
else:
636-
y0 = y_desc.load([offs_xm, offs_wn])
637-
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
638-
z_desc.store([offs_xm, offs_wn], z0)
639-
if BROADCAST_Y:
640-
y1 = y_desc.load([0, offs_wn + BLOCK_N // 2])
641-
else:
642-
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 2])
643-
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
644-
z_desc.store([offs_xm, offs_wn + BLOCK_N // 2], z1)
645-
elif EPILOGUE_SUBTILE == 4:
646-
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
647-
acc = tl.permute(acc, (0, 2, 1))
648-
acc_lo, acc_hi = tl.split(acc)
649-
acc_lo = tl.reshape(acc_lo, (BLOCK_M, 2, BLOCK_N // 4))
650-
acc_lo = tl.permute(acc_lo, (0, 2, 1))
651-
acc0, acc1 = tl.split(acc_lo)
652-
acc_hi = tl.reshape(acc_hi, (BLOCK_M, 2, BLOCK_N // 4))
653-
acc_hi = tl.permute(acc_hi, (0, 2, 1))
654-
acc2, acc3 = tl.split(acc_hi)
655-
if BROADCAST_Y:
656-
y0 = y_desc.load([0, offs_wn])
657-
else:
658-
y0 = y_desc.load([offs_xm, offs_wn])
659-
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
660-
z_desc.store([offs_xm, offs_wn], z0)
661-
if BROADCAST_Y:
662-
y1 = y_desc.load([0, offs_wn + BLOCK_N // 4])
663-
else:
664-
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 4])
665-
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
666-
z_desc.store([offs_xm, offs_wn + BLOCK_N // 4], z1)
667-
if BROADCAST_Y:
668-
y2 = y_desc.load([0, offs_wn + 2 * BLOCK_N // 4])
669-
else:
670-
y2 = y_desc.load([offs_xm, offs_wn + 2 * BLOCK_N // 4])
671-
z2 = (acc2 + y2.to(tl.float32)).to(z_desc.dtype)
672-
z_desc.store([offs_xm, offs_wn + 2 * BLOCK_N // 4], z2)
631+
tl.static_assert(
632+
EPILOGUE_SUBTILE <= 8,
633+
"EPILOGUE_SUBTILE > 8 is not supported",
634+
)
635+
acc_subtiles = _split_n_2D(accumulator, EPILOGUE_SUBTILE) # pyre-ignore[16]
636+
slice_size: tl.constexpr = BLOCK_N // EPILOGUE_SUBTILE
637+
for i in tl.static_range(EPILOGUE_SUBTILE):
673638
if BROADCAST_Y:
674-
y3 = y_desc.load([0, offs_wn + 3 * BLOCK_N // 4])
639+
y_i = y_desc.load([0, offs_wn + i * slice_size])
675640
else:
676-
y3 = y_desc.load([offs_xm, offs_wn + 3 * BLOCK_N // 4])
677-
z3 = (acc3 + y3.to(tl.float32)).to(z_desc.dtype)
678-
z_desc.store([offs_xm, offs_wn + 3 * BLOCK_N // 4], z3)
679-
else:
680-
tl.static_assert(False, "Unsupported EPILOGUE_SUBTILE value")
641+
y_i = y_desc.load([offs_xm, offs_wn + i * slice_size])
642+
z_i = (acc_subtiles[i] + y_i.to(tl.float32)).to(z_desc.dtype)
643+
z_desc.store([offs_xm, offs_wn + i * slice_size], z_i)
681644

682645

683646
@triton_autotune(

0 commit comments

Comments
 (0)