Skip to content

Commit dc83461

Browse files
jananisriram and facebook-github-bot
authored and committed
Add EPILOGUE_SUBTILE=8 and generalize TritonBench addmm kernel
Reviewed By: njriasan, rafaykhurram
Differential Revision: D106049038
1 parent 9959d2c commit dc83461

1 file changed

Lines changed: 25 additions & 59 deletions

File tree

generative_recommenders/ops/triton/triton_addmm.py

Lines changed: 25 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,12 @@
2929
import triton.language as tl
3030
from generative_recommenders.ops.utils import is_sm100_plus, maybe_register_custom_op
3131

32+
try:
33+
# @manual=//triton:triton
34+
from triton.language.extra.subtile_ops import _split_n_2D
35+
except ImportError:
36+
_split_n_2D = None
37+
3238
try:
3339
# @manual=//triton:triton
3440
import triton.language.extra.tlx as tlx # type: ignore
@@ -96,13 +102,20 @@ def _check_tma_alignment(
96102
def _prune_persistent_autows_configs(configs, named_args, **kwargs):  # noqa
    """Filter autotune candidates for the persistent auto-WS kernel.

    When ``_use_meta_ws()`` is false the candidate list is returned untouched.
    Otherwise a config survives only if both of the following hold:

    * ``DATA_PARTITION_FACTOR == 2`` is paired with ``BLOCK_M == 256``;
    * the epilogue slice width, ``BLOCK_N // EPILOGUE_SUBTILE``, is at least
      32, or at least 64 when the ``BROADCAST_Y`` kwarg is truthy.

    Args:
        configs: Candidate configs; each exposes its meta-parameters through
            a ``kwargs`` mapping.
        named_args: Unused here; accepted to match the pruning-hook signature.
        **kwargs: Extra launch parameters; only ``BROADCAST_Y`` is consulted.

    Returns:
        ``configs`` itself when meta-WS is disabled, otherwise a new list
        holding the surviving configs in their original order.
    """
    if not _use_meta_ws():
        return configs

    # BROADCAST_Y raises the minimum slice width from 32 to 64.
    min_slice_width = 64 if kwargs.get("BROADCAST_Y", False) else 32

    def _is_supported(config):
        params = config.kwargs
        # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
        if (
            params.get("DATA_PARTITION_FACTOR", 1) == 2
            and params.get("BLOCK_M", 0) != 256
        ):
            return False
        slice_width = params.get("BLOCK_N", 0) // params.get("EPILOGUE_SUBTILE", 1)
        return slice_width >= min_slice_width

    return [config for config in configs if _is_supported(config)]
108121

@@ -120,7 +133,6 @@ def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # n
120133
NUM_MMA_GROUPS = c.kwargs.get("NUM_MMA_GROUPS", 1)
121134
BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
122135
NUM_SMEM_BUFFERS = c.kwargs.get("NUM_SMEM_BUFFERS", 1)
123-
EPILOGUE_SUBTILE = c.kwargs.get("EPILOGUE_SUBTILE", 1)
124136

125137
# Hardware constraint: Always make MMA tile 128.
126138
if BLOCK_M_SPLIT != 128:
@@ -459,7 +471,7 @@ def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
459471
for block_n in [64, 128, 256]
460472
for block_k in [64, 128, 256]
461473
for num_stages in [2, 3, 4]
462-
for subtile in [1, 2, 4]
474+
for subtile in [1, 2, 4, 8]
463475
for DP in [1, 2]
464476
]
465477

@@ -619,65 +631,19 @@ def _addmm_persistent_tile_body(
619631

620632
# Epilogue subtiling breaks the store into multiple pieces to reduce
621633
# shared memory consumption and allow higher stage counts.
622-
if EPILOGUE_SUBTILE == 1:
623-
if BROADCAST_Y:
624-
y = y_desc.load([0, offs_wn])
625-
else:
626-
y = y_desc.load([offs_xm, offs_wn])
627-
z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
628-
z_desc.store([offs_xm, offs_wn], z)
629-
elif EPILOGUE_SUBTILE == 2:
630-
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
631-
acc = tl.permute(acc, (0, 2, 1))
632-
acc0, acc1 = tl.split(acc)
633-
if BROADCAST_Y:
634-
y0 = y_desc.load([0, offs_wn])
635-
else:
636-
y0 = y_desc.load([offs_xm, offs_wn])
637-
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
638-
z_desc.store([offs_xm, offs_wn], z0)
639-
if BROADCAST_Y:
640-
y1 = y_desc.load([0, offs_wn + BLOCK_N // 2])
641-
else:
642-
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 2])
643-
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
644-
z_desc.store([offs_xm, offs_wn + BLOCK_N // 2], z1)
645-
elif EPILOGUE_SUBTILE == 4:
646-
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
647-
acc = tl.permute(acc, (0, 2, 1))
648-
acc_lo, acc_hi = tl.split(acc)
649-
acc_lo = tl.reshape(acc_lo, (BLOCK_M, 2, BLOCK_N // 4))
650-
acc_lo = tl.permute(acc_lo, (0, 2, 1))
651-
acc0, acc1 = tl.split(acc_lo)
652-
acc_hi = tl.reshape(acc_hi, (BLOCK_M, 2, BLOCK_N // 4))
653-
acc_hi = tl.permute(acc_hi, (0, 2, 1))
654-
acc2, acc3 = tl.split(acc_hi)
655-
if BROADCAST_Y:
656-
y0 = y_desc.load([0, offs_wn])
657-
else:
658-
y0 = y_desc.load([offs_xm, offs_wn])
659-
z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
660-
z_desc.store([offs_xm, offs_wn], z0)
661-
if BROADCAST_Y:
662-
y1 = y_desc.load([0, offs_wn + BLOCK_N // 4])
663-
else:
664-
y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 4])
665-
z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
666-
z_desc.store([offs_xm, offs_wn + BLOCK_N // 4], z1)
667-
if BROADCAST_Y:
668-
y2 = y_desc.load([0, offs_wn + 2 * BLOCK_N // 4])
669-
else:
670-
y2 = y_desc.load([offs_xm, offs_wn + 2 * BLOCK_N // 4])
671-
z2 = (acc2 + y2.to(tl.float32)).to(z_desc.dtype)
672-
z_desc.store([offs_xm, offs_wn + 2 * BLOCK_N // 4], z2)
634+
tl.static_assert(
635+
EPILOGUE_SUBTILE <= 8,
636+
"EPILOGUE_SUBTILE > 8 is not supported",
637+
)
638+
acc_subtiles = _split_n_2D(accumulator, EPILOGUE_SUBTILE) # pyre-ignore[16]
639+
slice_size: tl.constexpr = BLOCK_N // EPILOGUE_SUBTILE
640+
for i in tl.static_range(EPILOGUE_SUBTILE):
673641
if BROADCAST_Y:
674-
y3 = y_desc.load([0, offs_wn + 3 * BLOCK_N // 4])
642+
y_i = y_desc.load([0, offs_wn + i * slice_size])
675643
else:
676-
y3 = y_desc.load([offs_xm, offs_wn + 3 * BLOCK_N // 4])
677-
z3 = (acc3 + y3.to(tl.float32)).to(z_desc.dtype)
678-
z_desc.store([offs_xm, offs_wn + 3 * BLOCK_N // 4], z3)
679-
else:
680-
tl.static_assert(False, "Unsupported EPILOGUE_SUBTILE value")
644+
y_i = y_desc.load([offs_xm, offs_wn + i * slice_size])
645+
z_i = (acc_subtiles[i] + y_i.to(tl.float32)).to(z_desc.dtype)
646+
z_desc.store([offs_xm, offs_wn + i * slice_size], z_i)
681647

682648

683649
@triton_autotune(

0 commit comments

Comments
 (0)