We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 880b4e4 · commit ba570e1 · Copy full SHA for ba570e1
1 file changed
python/triton_kernels/triton_kernels/matmul.py
@@ -125,7 +125,10 @@ class PrecisionConfig:
125
# TODO: merge in opt_flags
126
def get_swap_xw(precision_config, opt_flags):
127
if target_info.cuda_capability_geq(10, 0):
128
- return precision_config.b_mx_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
+ if precision_config.b_mx_scale is not None:
129
+ return opt_flags.block_m <= 64 and opt_flags.is_persistent
130
+ else:
131
+ return opt_flags.block_m < 64 and opt_flags.is_persistent
132
elif target_info.cuda_capability_geq(9, 0):
133
b_scale_layout = None if not isinstance(precision_config.b_mx_scale, Tensor) else precision_config.b_mx_scale.storage.layout
134
return isinstance(b_scale_layout, HopperMXScaleLayout)
0 commit comments