12 changes: 6 additions & 6 deletions tritonbench/operators/gemm/operator.py
@@ -675,7 +675,7 @@ def tlx_matmul_2cta(self, a, b, bias) -> Callable:
else:
return lambda: _tlx_matmul_2cta(a_contig, b_contig).to(target_dtype)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return (
@@ -685,7 +685,7 @@ def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
else:
return lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=True)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return (
@@ -695,21 +695,21 @@ def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
else:
return lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=False)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_warpspec_tma_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return lambda: blackwell_matmul_tma(a, b, warp_specialize=True) + bias
else:
return lambda: blackwell_matmul_tma(a, b, warp_specialize=True)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_tma_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return lambda: blackwell_matmul_tma(a, b, warp_specialize=False) + bias
else:
return lambda: blackwell_matmul_tma(a, b, warp_specialize=False)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return (
@@ -723,7 +723,7 @@ def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
a, b, warp_specialize=True
)

@register_benchmark(enabled=IS_BLACKWELL)
@register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
def triton_blackwell_descriptor_matmul(self, a, b, bias) -> Callable:
if bias is not None:
return (
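Note on the `enabled=` gating above: `IS_HOPPER` is imported from `tritonbench.utils.env_utils` in the second file of this PR, and `IS_BLACKWELL` presumably comes from the same module; neither definition appears in this diff. As a rough illustration only (not the library's actual implementation), flags like these are typically derived from the CUDA compute capability:

# Illustrative sketch, not part of this PR: Hopper reports SM 9.x and
# Blackwell SM 10.x (12.x for some parts), so a capability check suffices.
import torch

def _device_major() -> int:
    # Returns 0 when no CUDA device is present, so both flags stay False.
    if not torch.cuda.is_available():
        return 0
    major, _minor = torch.cuda.get_device_capability()
    return major

IS_HOPPER = _device_major() == 9
IS_BLACKWELL = _device_major() >= 10

With the change above, the six `triton_blackwell_*` benchmark variants register on Hopper as well, rather than on Blackwell only.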
72 changes: 48 additions & 24 deletions tritonbench/operators/gemm/warp_spec_persistent_matmul.py
@@ -10,7 +10,7 @@
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from tritonbench.utils.env_utils import is_tile_enabled
from tritonbench.utils.env_utils import IS_HOPPER, is_tile_enabled

from .triton_matmul_configs import get_tileir_configs

@@ -63,7 +63,9 @@ def _matmul_launch_metadata(grid, kernel, args):
default_s_range = small_stage_range
tma_persistent_s_range = small_stage_range
else:
bm_range = [128, 256]
# Hopper has smaller WGMMA tiles and tighter register budgets than
# Blackwell, so BLOCK_M=256 is rarely useful; trim the M range there.
bm_range = [64, 128] if IS_HOPPER else [128, 256]
bn_range = [128, 256]
bk_range = [64, 128]
default_s_range = [3, 4]
@@ -120,6 +122,16 @@ def _use_meta_ws():
return False


def _pingpong_options():
"""``pingpongAutoWS`` is a triton.Config autotune kwarg in the beta
branch (see test_tutorial09_warp_specialization.py for canonical
usage). It schedules producer/consumer warps in a ping-pong pattern
and is only useful on Hopper with the meta WS pipeline."""
if IS_HOPPER and _use_meta_ws():
return [True, False]
return [None]
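A quick sketch of how this return value is consumed (it mirrors the `_make_config` helper added further down in this diff; it is not additional PR code). Returning `[None]` rather than an empty list keeps the config comprehension from collapsing to zero configs off the Hopper/meta-WS path, and `None` means the kwarg is never passed to `triton.Config` at all:

# Consumption pattern, mirroring _make_config below: only explicit True/False
# values turn into a pingpongAutoWS kwarg; None leaves the config untouched.
for pp in _pingpong_options():
    extras = {}
    if pp is not None:
        extras["pingpongAutoWS"] = pp
    # extras is then splatted into triton.Config(..., **extras)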


def _prune_warp_specialize_configs(configs, named_args, **kwargs):
"""When warp specialization is enabled with the Meta WS pipeline,
only num_warps=4 is supported."""
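The body of `_prune_warp_specialize_configs` is collapsed in this view. As a minimal sketch of the rule its docstring states (illustrative only, not the function's actual body), the filter would keep every config unless warp specialization plus the meta WS pipeline is active, in which case only `num_warps=4` survives:

# Illustrative restatement of the docstring's rule; not the real body.
def _keep(config, warp_specialize: bool) -> bool:
    if warp_specialize and _use_meta_ws():
        return config.num_warps == 4
    return True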
@@ -320,12 +332,16 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
)
return configs
else:
extra_kwargs = {}
base_extra_kwargs = {}
if _use_meta_ws():
extra_kwargs["early_tma_store_lowering"] = 1
extra_kwargs["maxRegAutoWS"] = 255
return [
triton.Config(
base_extra_kwargs["early_tma_store_lowering"] = 1
base_extra_kwargs["maxRegAutoWS"] = 255

def _make_config(BM, BN, BK, s, w, SUBTILE, FLATTEN, DP, pp):
extras = dict(base_extra_kwargs)
if pp is not None:
extras["pingpongAutoWS"] = pp
return triton.Config(
{
"BLOCK_SIZE_M": BM,
"BLOCK_SIZE_N": BN,
@@ -338,8 +354,11 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
num_stages=s,
num_warps=w,
pre_hook=pre_hook,
**extra_kwargs,
) #
**extras,
)

return [
_make_config(BM, BN, BK, s, w, SUBTILE, FLATTEN, DP, pp)
for BM in bm_range #
for BN in bn_range #
for BK in bk_range #
@@ -348,38 +367,43 @@
for SUBTILE in [1, 2, 4] #
for FLATTEN in [True, False] #
for DP in [1, 2] #
for pp in _pingpong_options() #
]


def _prune_tma_persistent_configs(configs, named_args, **kwargs):
"""Prune configs for the TMA persistent kernel based on FLATTEN/WS rules.

- WARP_SPECIALIZE=False: require FLATTEN=False
- WARP_SPECIALIZE=True + meta WS: require FLATTEN=False
- WARP_SPECIALIZE=True + no meta WS: require FLATTEN=True
"""Prune configs for the TMA persistent kernel.

FLATTEN rules. FLATTEN=True is only supported on Blackwell with the
OAI Triton WS path.
- Hopper: require FLATTEN=False (regardless of WS / meta WS)
- Blackwell + WARP_SPECIALIZE=False: require FLATTEN=False
- Blackwell + WARP_SPECIALIZE=True + meta WS: require FLATTEN=False
- Blackwell + WARP_SPECIALIZE=True + OAI Triton WS: require FLATTEN=True

DATA_PARTITION_FACTOR rules. DP=2 is only valid on the meta WS path,
and the partitioned M tile must match the arch's largest usable
BLOCK_M (128 on Hopper, 256 on Blackwell).
"""
ws = kwargs.get("WARP_SPECIALIZE", False)
flatten_required = ws and not _use_meta_ws() and not IS_HOPPER
kept = []
for c in configs:
subtile = c.kwargs.get("EPILOGUE_SUBTILE", 1)
if c.kwargs.get("BLOCK_SIZE_N", 1) % subtile != 0:
continue
flatten = c.kwargs.get("FLATTEN", False)
if not ws:
if flatten:
continue
elif _use_meta_ws():
if flatten:
continue
else:
if not flatten:
continue
if flatten != flatten_required:
continue
data_partition_factor = c.kwargs.get("DATA_PARTITION_FACTOR", 1)
if data_partition_factor != 1:
if not _use_meta_ws():
continue
block_m = c.kwargs.get("BLOCK_SIZE_M", 1)
if block_m != 256:
# Data partitioning needs the largest M tile that still fits in
# registers: BLOCK_M=128 on Hopper, BLOCK_M=256 on Blackwell.
required_block_m = 128 if IS_HOPPER else 256
if block_m != required_block_m:
continue
kept.append(c)
return _prune_warp_specialize_configs(kept, named_args, **kwargs)
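To make the pruning rules above concrete, here is a standalone restatement of the FLATTEN predicate (a sketch mirroring the `flatten_required` line, not additional PR code). The only combination that keeps `FLATTEN=True` is Blackwell with `WARP_SPECIALIZE=True` on the OAI Triton WS path:

# Standalone restatement of the FLATTEN rule, for illustration only.
def flatten_allowed(flatten: bool, warp_specialize: bool,
                    meta_ws: bool, is_hopper: bool) -> bool:
    flatten_required = warp_specialize and not meta_ws and not is_hopper
    return flatten == flatten_required

assert flatten_allowed(True, True, False, False)      # Blackwell, OAI WS
assert not flatten_allowed(True, True, True, False)   # Blackwell, meta WS
assert not flatten_allowed(True, True, False, True)   # Hopper, any WS path
assert flatten_allowed(False, False, False, True)     # Hopper, no WS

The DATA_PARTITION_FACTOR rule follows the same pattern: DP=2 is dropped unless the meta WS path is active and BLOCK_M matches the arch's largest usable tile (128 on Hopper, 256 on Blackwell).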