Commit ba570e1

[KERNELS] enable swap_xw on blackwell for non-mx matmuls (#9390)
Helps significantly for ragged matmuls where the slice size is small; otherwise, Blackwell compiles to mma.sync.
1 parent 880b4e4 commit ba570e1

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

  • python/triton_kernels/triton_kernels

python/triton_kernels/triton_kernels/matmul.py

```diff
@@ -125,7 +125,10 @@ class PrecisionConfig:
 # TODO: merge in opt_flags
 def get_swap_xw(precision_config, opt_flags):
     if target_info.cuda_capability_geq(10, 0):
-        return precision_config.b_mx_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
+        if precision_config.b_mx_scale is not None:
+            return opt_flags.block_m <= 64 and opt_flags.is_persistent
+        else:
+            return opt_flags.block_m < 64 and opt_flags.is_persistent
     elif target_info.cuda_capability_geq(9, 0):
         b_scale_layout = None if not isinstance(precision_config.b_mx_scale, Tensor) else precision_config.b_mx_scale.storage.layout
         return isinstance(b_scale_layout, HopperMXScaleLayout)
```
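The Blackwell branch of the change can be summarized as a standalone sketch. This is illustrative only: the function name, the boolean `has_b_mx_scale` flag, and passing the CUDA capability as a tuple are assumptions for the sketch, not Triton's actual `target_info`/`opt_flags` API.

```python
def get_swap_xw_sketch(cuda_capability, has_b_mx_scale, block_m, is_persistent):
    """Sketch of the post-commit swap_xw decision (hypothetical signature).

    cuda_capability: (major, minor) tuple, e.g. (10, 0) for Blackwell.
    has_b_mx_scale:  whether the B operand carries an MX scale.
    """
    if cuda_capability >= (10, 0):
        if has_b_mx_scale:
            # MX matmuls: unchanged, swap for small-ish persistent tiles.
            return block_m <= 64 and is_persistent
        # Non-MX matmuls: newly enabled, but with a strict block_m < 64 bound.
        return block_m < 64 and is_persistent
    # Pre-Blackwell logic (Hopper layout check) omitted from this sketch.
    return False
```

Note the asymmetry the diff introduces: MX matmuls swap at `block_m <= 64`, while non-MX matmuls swap only at `block_m < 64`.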
