diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index f16f1149b..8b39e1afe 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -675,7 +675,7 @@ def tlx_matmul_2cta(self, a, b, bias) -> Callable:
         else:
             return lambda: _tlx_matmul_2cta(a_contig, b_contig).to(target_dtype)
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return (
@@ -685,7 +685,7 @@ def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=True)
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return (
@@ -695,21 +695,21 @@ def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=False)
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_warpspec_tma_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return lambda: blackwell_matmul_tma(a, b, warp_specialize=True) + bias
         else:
             return lambda: blackwell_matmul_tma(a, b, warp_specialize=True)
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_tma_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return lambda: blackwell_matmul_tma(a, b, warp_specialize=False) + bias
         else:
             return lambda: blackwell_matmul_tma(a, b, warp_specialize=False)
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return (
@@ -723,7 +723,7 @@ def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
                 a, b, warp_specialize=True
             )
 
-    @register_benchmark(enabled=IS_BLACKWELL)
+    @register_benchmark(enabled=IS_BLACKWELL or IS_HOPPER)
     def triton_blackwell_descriptor_matmul(self, a, b, bias) -> Callable:
         if bias is not None:
             return (
diff --git a/tritonbench/operators/gemm/warp_spec_persistent_matmul.py b/tritonbench/operators/gemm/warp_spec_persistent_matmul.py
index d2df22615..5fcdca7f8 100644
--- a/tritonbench/operators/gemm/warp_spec_persistent_matmul.py
+++ b/tritonbench/operators/gemm/warp_spec_persistent_matmul.py
@@ -10,7 +10,7 @@
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 
-from tritonbench.utils.env_utils import is_tile_enabled
+from tritonbench.utils.env_utils import IS_HOPPER, is_tile_enabled
 
 from .triton_matmul_configs import get_tileir_configs
@@ -63,7 +63,9 @@ def _matmul_launch_metadata(grid, kernel, args):
     default_s_range = small_stage_range
     tma_persistent_s_range = small_stage_range
 else:
-    bm_range = [128, 256]
+    # Hopper has smaller WGMMA tiles and tighter register budgets than
+    # Blackwell, so BLOCK_M=256 is rarely useful; trim the M range there.
+    bm_range = [64, 128] if IS_HOPPER else [128, 256]
     bn_range = [128, 256]
     bk_range = [64, 128]
     default_s_range = [3, 4]
@@ -120,6 +122,16 @@ def _use_meta_ws():
     return False
 
 
+def _pingpong_options():
+    """``pingpongAutoWS`` is a triton.Config autotune kwarg in the beta
+    branch (see test_tutorial09_warp_specialization.py for canonical
+    usage). It schedules producer/consumer warps in a ping-pong pattern
+    and is only useful on Hopper with the meta WS pipeline."""
+    if IS_HOPPER and _use_meta_ws():
+        return [True, False]
+    return [None]
+
+
 def _prune_warp_specialize_configs(configs, named_args, **kwargs):
     """When warp specialization is enabled with the Meta WS pipeline,
     only num_warps=4 is supported."""
@@ -320,12 +332,16 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
         )
         return configs
     else:
-        extra_kwargs = {}
+        base_extra_kwargs = {}
         if _use_meta_ws():
-            extra_kwargs["early_tma_store_lowering"] = 1
-            extra_kwargs["maxRegAutoWS"] = 255
-        return [
-            triton.Config(
+            base_extra_kwargs["early_tma_store_lowering"] = 1
+            base_extra_kwargs["maxRegAutoWS"] = 255
+
+        def _make_config(BM, BN, BK, s, w, SUBTILE, FLATTEN, DP, pp):
+            extras = dict(base_extra_kwargs)
+            if pp is not None:
+                extras["pingpongAutoWS"] = pp
+            return triton.Config(
                 {
                     "BLOCK_SIZE_M": BM,
                     "BLOCK_SIZE_N": BN,
@@ -338,8 +354,11 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
                 num_stages=s,
                 num_warps=w,
                 pre_hook=pre_hook,
-                **extra_kwargs,
-            )  #
+                **extras,
+            )
+
+        return [
+            _make_config(BM, BN, BK, s, w, SUBTILE, FLATTEN, DP, pp)
             for BM in bm_range  #
             for BN in bn_range  #
             for BK in bk_range  #
@@ -348,38 +367,43 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
             for SUBTILE in [1, 2, 4]  #
             for FLATTEN in [True, False]  #
             for DP in [1, 2]  #
+            for pp in _pingpong_options()  #
         ]
 
 
 def _prune_tma_persistent_configs(configs, named_args, **kwargs):
-    """Prune configs for the TMA persistent kernel based on FLATTEN/WS rules.
-
-    - WARP_SPECIALIZE=False: require FLATTEN=False
-    - WARP_SPECIALIZE=True + meta WS: require FLATTEN=False
-    - WARP_SPECIALIZE=True + no meta WS: require FLATTEN=True
+    """Prune configs for the TMA persistent kernel.
+
+    FLATTEN rules. FLATTEN=True is only supported on Blackwell with the
+    OAI Triton WS path.
+    - Hopper: require FLATTEN=False (regardless of WS / meta WS)
+    - Blackwell + WARP_SPECIALIZE=False: require FLATTEN=False
+    - Blackwell + WARP_SPECIALIZE=True + meta WS: require FLATTEN=False
+    - Blackwell + WARP_SPECIALIZE=True + OAI Triton WS: require FLATTEN=True
+
+    DATA_PARTITION_FACTOR rules. DP=2 is only valid on the meta WS path,
+    and the partitioned M tile must match the arch's largest usable
+    BLOCK_M (128 on Hopper, 256 on Blackwell).
     """
     ws = kwargs.get("WARP_SPECIALIZE", False)
+    flatten_required = ws and not _use_meta_ws() and not IS_HOPPER
     kept = []
     for c in configs:
         subtile = c.kwargs.get("EPILOGUE_SUBTILE", 1)
         if c.kwargs.get("BLOCK_SIZE_N", 1) % subtile != 0:
             continue
         flatten = c.kwargs.get("FLATTEN", False)
-        if not ws:
-            if flatten:
-                continue
-        elif _use_meta_ws():
-            if flatten:
-                continue
-        else:
-            if not flatten:
-                continue
+        if flatten != flatten_required:
+            continue
         data_partition_factor = c.kwargs.get("DATA_PARTITION_FACTOR", 1)
         if data_partition_factor != 1:
             if not _use_meta_ws():
                 continue
             block_m = c.kwargs.get("BLOCK_SIZE_M", 1)
-            if block_m != 256:
+            # Data partitioning needs the largest M tile that still fits in
+            # registers: BLOCK_M=128 on Hopper, BLOCK_M=256 on Blackwell.
+            required_block_m = 128 if IS_HOPPER else 256
+            if block_m != required_block_m:
                 continue
         kept.append(c)
     return _prune_warp_specialize_configs(kept, named_args, **kwargs)
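
As a quick illustration (not part of the patch), the new FLATTEN rule collapses the old three-branch check into a single predicate; the sketch below enumerates every architecture / WARP_SPECIALIZE / meta-WS combination and prints which FLATTEN value survives pruning. Here `is_hopper`, `ws`, and `use_meta_ws` are stand-in booleans for `IS_HOPPER`, the `WARP_SPECIALIZE` kwarg, and `_use_meta_ws()` from the patch.

# Sketch only: mirrors the flatten_required predicate in _prune_tma_persistent_configs.
from itertools import product

for is_hopper, ws, use_meta_ws in product([False, True], repeat=3):
    # FLATTEN=True is kept only on Blackwell with the OAI Triton WS path.
    flatten_required = ws and not use_meta_ws and not is_hopper
    arch = "Hopper" if is_hopper else "Blackwell"
    print(f"{arch:9} WS={ws!s:5} metaWS={use_meta_ws!s:5} -> keep FLATTEN={flatten_required}")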