Commit 028e5da
[KERNELS] simplify mx shuffled weights defaults (#9986)
Simplify use of shuffled Blackwell MX value weights:

- Convert directly to BlackwellMX4ValueShuffledLayout; don't require first going through BlackwellMXValueLayout.
- Use the block sizes from BlackwellMX4ValueShuffledLayout as opt-flags constraints. This removes the complicated code that was needed to infer the block sizes before making the layout. It also picks a better default of block_n = 256, block_k = 128, which generally works well and matches the inferred sizes except when N < 256, and it makes the layout simpler to use: callers no longer need to override disable_mx4_block_swap = True when shuffled weights are used.
- Add more test coverage.

Same perf from running `torchrun --nproc-per-node=1 python/triton_kernels/bench/bench_mlp.py`.
1 parent d166045 commit 028e5da

8 files changed

Lines changed: 131 additions & 146 deletions
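Concretely, the simplification removes the dry-run block-size inference entirely: weights are converted straight to the shuffled layout, and matmul picks up the layout's block sizes as opt-flags constraints. A minimal sketch of the new pattern, assuming `w` is an MXFP4 weight already wrapped via wrap_torch_tensor(w, dtype=FP4) and `x`, `bias`, `pc` are the usual matmul arguments (the convert_layout import path is an assumption; the other names appear in the diffs below):

    from triton_kernels.tensor import convert_layout  # assumed import path
    from triton_kernels.tensor_details.layout import BlackwellMX4ValueShuffledLayout

    # Before: convert to BlackwellMXValueLayout, dry-run matmul to infer
    # block_k/block_n, then re-pack the weights by hand with those sizes and
    # override disable_mx4_block_swap=True.
    # After: convert directly; the layout's defaults (block_n=256, block_k=128)
    # act as opt-flags constraints, so no overrides are needed.
    w_shuffled = convert_layout(w, BlackwellMX4ValueShuffledLayout())
    y = matmul(x, w_shuffled, bias, precision_config=pc)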


python/triton_kernels/bench/bench_mlp.py

Lines changed: 13 additions & 105 deletions
@@ -7,14 +7,12 @@
 import triton_kernels.roofline as roofline
 from triton_kernels.swiglu import swiglu_fn
 from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.matmul_details.opt_flags import make_opt_flags, scoped_opt_flags_constraints
+from triton_kernels.matmul_details.opt_flags import scoped_opt_flags_constraints
 from triton_kernels.target_info import get_cdna_version
 from triton_kernels.tensor_details import layout
-from triton_kernels.tensor_details.layout import BlackwellMX4ValueShuffledLayout
 from triton_kernels.reduce import reduce
 from triton_kernels.topk import topk
 from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata  # ragged tensor
-from triton_kernels.tensor import is_tma_compliant, Tensor, torch_dtype_to_dtype
 from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_assignment, SymmetricMemoryPool
 from triton_kernels.distributed_details.mesh import Mesh
 # quantization
@@ -23,60 +21,6 @@
 from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp
 
 
-def _shuffle_mx4_weights(tensor, block_k, block_n):
-    """
-    Convert MX4 weights from BlackwellMXValueLayout to BlackwellMX4ValueShuffledLayout.
-
-    Works directly with the column-major storage data, bypassing the canonical format
-    round-trip which uses a different byte-level packing convention.
-    """
-    from triton_kernels.tensor import Storage, Tensor as TKTensor
-    storage = tensor.storage
-    # The stored data is column-major [E, K_packed_padded, N] with stride(-2)==1
-    data = storage.data
-    E = tensor.shape[0]
-    K_logical = tensor.shape[-2]
-    N = tensor.shape[-1]
-    K_packed = K_logical // 2  # 2 FP4 values per byte
-    # Trim any padding from the BlackwellMXValueLayout
-    data = data[:, :K_packed, :N].contiguous()
-    # Now apply the shuffled layout's tiling
-    shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
-    transformation = shuffled_layout.make_transformation([E, K_logical, N], True)
-    shuffled_data = transformation.swizzle_data(data)
-    return TKTensor(Storage(shuffled_data, shuffled_layout), shape=list(tensor.shape), dtype=tensor.dtype)
-
-
-def _infer_opt_flags(x, w, ragged_metadata, pc):
-    """
-    Infer opt_flags by calling make_opt_flags with the same parameters matmul would use.
-    This ensures the block shapes match what the kernel will actually select.
-    """
-    if not isinstance(w, Tensor):
-        raise TypeError("w must be a Tensor for block shape inference")
-    K = w.shape[-2]
-    N = w.shape[-1]
-    M = x.shape[-2]
-    batch_size = 1
-    if not isinstance(x, Tensor):
-        x = wrap_torch_tensor(x)
-    # Convert out_dtype from torch.dtype to triton dtype (make_opt_flags expects .bitwidth)
-    out_dtype = pc.out_dtype or x.dtype
-    out_dtype = torch_dtype_to_dtype(out_dtype)
-    x_transpose = x.stride(-1) != 1
-    b_scale = pc.b_mx_scale
-    can_use_tma = (x.numel() > 0 and is_tma_compliant(x) and w.numel() > 0 and is_tma_compliant(w)
-                   and (b_scale is None or is_tma_compliant(b_scale)))
-    # Respects any constraints set by the caller via scoped_opt_flags_constraints
-    opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, pc, batch_size, M, N, K, ragged_metadata, can_use_tma,
-                               False,  # can_use_split_k=False for MoE
-                               None,  # epilogue_effective_itemsize
-                               x_transpose, False,  # has_y_acc_in
-                               None,  # block_k
-                               )
-    return opt_flags
-
-
 def was_launched_with_torchrun():
     required = ["RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"]
     return all(k in os.environ for k in required)
@@ -102,8 +46,14 @@ def quantize_weight(w, dtype, **opt):
     assert dtype == FP4, f"{dtype=}"
     w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
     if opt:
-        w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"])
-        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"])
+        w = wrap_torch_tensor(w, dtype=FP4)
+        value_layout = opt.get("value_layout")
+        if value_layout is not None:
+            w = convert_layout(w, value_layout)
+        w_scale = wrap_torch_tensor(w_scale)
+        scale_layout = opt.get("scale_layout")
+        if scale_layout is not None:
+            w_scale = convert_layout(w_scale, scale_layout)
     return w, InFlexData(), w_scale
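The quantize_weight hunk above also relaxes its contract: each layout kwarg is now optional and applied independently instead of both being required. A short usage sketch (names as in bench_mlp.py):

    # Both layouts provided: values and scales are each converted.
    w_q, flex, w_scale = quantize_weight(w, FP4, value_layout=value_layout,
                                         scale_layout=scale_layout)
    # Only the value layout: the scales stay in their strided form.
    w_q, flex, w_scale = quantize_weight(w, FP4, value_layout=value_layout)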
@@ -186,7 +136,10 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
     opt2 = dict()
     if w_dtype == FP4:
         num_warps = 4 if batch <= 512 else 8
-        value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+        value_layout = layout.make_default_matmul_mxfp4_w_layout(
+            mx_axis=1,
+            allow_blackwell_value_shuffle=shuffle_mx4,
+        )
         scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
         opt1 = {
             "value_layout": value_layout,
@@ -223,58 +176,13 @@
     expt_dict = make_expt_dict_uniform(EP, n_expts_tot)
     expt_assignment = make_expt_assignment(EP, n_expts_tot, expt_dict, torch.device(dev))
 
-    # For MX4 shuffling: run one dry-run iteration to collect routing data, then infer block shapes
-    if shuffle_mx4 and w_dtype == FP4:
-        # Disable block swap: with shuffled weights, tile loads are already contiguous,
-        # so the swap's cacheline optimization is unnecessary. More importantly, disabling
-        # the swap gives block_k=128 (vs 256), halving per-stage smem footprint, which
-        # enables fitting 5 pipeline stages instead of 4 — a bigger win than the swap.
-        dry_run_constraints = {"disable_mx4_block_swap": True}
-        if epilogue_subtile_fc1 is not None:
-            dry_run_constraints["epilogue_subtile"] = epilogue_subtile_fc1
-        with scoped_opt_flags_constraints(dry_run_constraints):
-            # Dry-run routing to get ragged metadata and dispatched activations
-            l_dry = matmul(x_dp_local_bf16, wg_global, bg_global, precision_config=pcg)
-            l_active_dry = topk(l_dry, n_expts_act, apply_softmax=True, all_gather=True, symm_mem_pool=symm_mem_pool)
-            active_indx_dry = l_active_dry.indx
-            expt_sizes_dry = l_active_dry.mask_metadata.col_sum
-            dispatch_indx_dry = l_active_dry.mask_metadata.row_sorted_indx
-            x_global_meta_dry = make_ragged_tensor_metadata(expt_sizes_dry, dispatch_indx_dry.shape[0])
-            y_dry = convert_dp_to_ep(x_dp_local_fp8, expt_assignment, active_indx_dry, dispatch_indx_dry, symm_mem_pool)
-            y_meta_dry = remap_ragged_tensor_metadata(x_global_meta_dry, expt_assignment.expt_map[rank, :])
-
-            if y_dry.nelement() > 0:
-                # Infer block shapes for W1 (includes the block swap)
-                opt_flags_w1 = _infer_opt_flags(y_dry, w1_ep_local, y_meta_dry, pc1)
-                w1_block_k, w1_block_n = opt_flags_w1.block_k, opt_flags_w1.block_n
-                w1_ep_local = _shuffle_mx4_weights(w1_ep_local, w1_block_k, w1_block_n)
-
-                # Run W1 once to get intermediate for W2 block shape inference
-                y_fc1_dry = matmul(y_dry, w1_ep_local, b1_ep_local, a_ragged_metadata=y_meta_dry, precision_config=pc1,
-                                   fused_activation=act1)
-
-                # Infer block shapes for W2 (includes the block swap)
-                opt_flags_w2 = _infer_opt_flags(y_fc1_dry, w2_ep_local, y_meta_dry, pc2)
-                w2_block_k, w2_block_n = opt_flags_w2.block_k, opt_flags_w2.block_n
-                w2_ep_local = _shuffle_mx4_weights(w2_ep_local, w2_block_k, w2_block_n)
-
-                print(f"Shuffled layout: FC1 block_k={w1_block_k}, block_n={w1_block_n}, "
-                      f"stages={opt_flags_w1.num_stages}, subtile={opt_flags_w1.epilogue_subtile}; "
-                      f"FC2 block_k={w2_block_k}, block_n={w2_block_n}, "
-                      f"stages={opt_flags_w2.num_stages}, subtile={opt_flags_w2.epilogue_subtile}")
-        torch.cuda.synchronize()
-
     # Build per-kernel constraints
     fc1_constraints = {}
-    if shuffle_mx4:
-        fc1_constraints["disable_mx4_block_swap"] = True
     if num_stages_fc1 is not None:
         fc1_constraints["num_stages"] = num_stages_fc1
     if epilogue_subtile_fc1 is not None:
         fc1_constraints["epilogue_subtile"] = epilogue_subtile_fc1
     fc2_constraints = {}
-    if shuffle_mx4:
-        fc2_constraints["disable_mx4_block_swap"] = True
     if num_stages_fc2 is not None:
         fc2_constraints["num_stages"] = num_stages_fc2

python/triton_kernels/tests/test_matmul.py

Lines changed: 31 additions & 4 deletions
@@ -91,6 +91,7 @@ class Case:
     split_k: int = 1
     a_hbm_swizzling: bool = False
     b_hbm_swizzling: bool = False
+    shuffle_mxfp4_w_layout: bool = False
     epilogue_subtile: Union[int, None] = None
     a_transpose: bool = False
     b_transpose: bool = False
@@ -148,12 +149,16 @@ def _build_test_op_cases():
     # float8 x mxfloat
     test_cases.extend([
         Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True),
+        Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
         Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True),
+        Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
         Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"),
         Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9),
         Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True),
+        Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
         Case(300, 400, 416, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
         Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"),
+        Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
         Case(300, 400, 416, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
     ])
     # mxfloat x mxfloat
@@ -236,15 +241,15 @@ def _build_test_op_cases():
 @pytest.mark.parametrize("is_persistent", [False,True])
 @pytest.mark.parametrize("num_warps", [4, 8] if is_hopper() else [None])
 def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
-            mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+            mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             a_transpose, b_transpose, c_transpose,
             swiglu_opts, device, opt_flags_scope):
     # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
     # the frame that called pytest.skip, including all the tensors, leading to OOM.
     skip_message = None
     try:
         _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
-                 mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+                 mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
                  a_transpose, b_transpose, c_transpose,
                  swiglu_opts, device, opt_flags_scope)
     except pytest.skip.Exception as e:
@@ -254,7 +259,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, i
     pytest.skip(skip_message)
 
 def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
-             mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+             mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
              a_transpose, b_transpose, c_transpose,
             swiglu_opts, device, opt_flags_scope):
     act_uses_mx = act_dtype_str.startswith("mx") or act_dtype_str == "nvfp4_e2m1"
@@ -349,6 +354,20 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
 
     # set opt flags constraints
     constraints = make_constraints(block_m, split_k, is_persistent, epilogue_subtile, b_hbm_swizzling, weight_dtype_str, num_warps)
+    use_blackwell_shuffled_w_layout = shuffle_mxfp4_w_layout and b_hbm_swizzling
+    if shuffle_mxfp4_w_layout:
+        if not b_hbm_swizzling:
+            pytest.skip("Shuffled MXFP4 weight layout only applies with b_hbm_swizzling")
+        if is_hip() or torch.cuda.get_device_capability()[0] < 10:
+            pytest.skip("Shuffled MXFP4 weight layout requires Blackwell or newer")
+        if weight_dtype_str != "mxfloat4_e2m1":
+            pytest.skip("Shuffled MXFP4 weight layout only supports mxfloat4_e2m1 weights")
+        if not act_dtype_str.startswith("float8"):
+            pytest.skip("Shuffled MXFP4 weight layout is only tested with FP8 activations")
+        if not colmajor_mxfp_weight:
+            pytest.skip("Shuffled MXFP4 weight layout requires column-major MXFP weights")
+        if not is_persistent:
+            pytest.skip("Shuffled MXFP4 weight layout requires the persistent TMA kernel")
     opt_flags.update_opt_flags_constraints(constraints)
 
     a_dtype = DType(act_dtype_str)
@@ -359,6 +378,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
     do_bias = inner_expt_opt is None
     do_gather = do_gather and mode != "batched"
     do_scatter = do_scatter and mode != "batched"
+    b_value_hbm_swizzling = None
+    if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4:
+        b_value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(
+            mx_axis=-2,
+            allow_blackwell_value_shuffle=use_blackwell_shuffled_w_layout,
+        )
 
     # --- create inputs ---
     a, a_scales, a_ragged_metadata = make_random_tensor(
@@ -384,9 +409,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt,
         squeeze_batch_dim = mode == "plain",
         is_mx_rowmajor = not colmajor_mxfp_weight,
-        value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
+        value_hbm_swizzling = b_value_hbm_swizzling,
         scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
     )
+    if use_blackwell_shuffled_w_layout:
+        assert isinstance(b.storage.layout, layout.BlackwellMX4ValueShuffledLayout)
     gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device)
     scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device)
     bias = None if not do_bias else torch.randn(b.shape[:-2] + b.shape[-1:], dtype=torch.float32, device=device)
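The skip chain above doubles as documentation of where the shuffled layout applies. The same hardware gate, pulled out as a standalone predicate (the helper name is hypothetical; is_hip is the helper the test module already imports):

    import torch

    def _supports_blackwell_mx4_shuffle() -> bool:
        # NVIDIA only (not HIP), compute capability >= 10: Blackwell or newer.
        if is_hip():
            return False
        return torch.cuda.get_device_capability()[0] >= 10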

python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py

Lines changed: 9 additions & 16 deletions
@@ -37,11 +37,9 @@ def _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n):
     return weight_val, weight_scale
 
 
-def _shuffle_blackwell_mxfp4_weight(weight, block_k, block_n):
-    shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
-    transformation = shuffled_layout.make_transformation(weight.shape, is_fp4=True)
-    shuffled_data = transformation.swizzle_data(weight.storage.data)
-    return Tensor(Storage(shuffled_data, shuffled_layout), dtype=weight.dtype, shape=weight.shape)
+def _shuffle_blackwell_mxfp4_weight(weight):
+    shuffled_layout = BlackwellMX4ValueShuffledLayout()
+    return convert_layout(weight, shuffled_layout)
 
 
 @pytest.mark.parametrize("n, expected", [(64, 128), (200, 256)])
@@ -70,7 +68,7 @@ def test_matmul_blackwell_scale_small_n(device):
         out_dtype=a.dtype,
     )
     tri_y = matmul(a, b, None, precision_config=precision_config)
-    ref_y = matmul_torch(a, b, None, precision_config=precision_config)
+    ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config)
     assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None)
 
 
@@ -82,30 +80,25 @@ def test_matmul_blackwell_shuffled_mxfp4_weight(device):
 
     torch.manual_seed(0)
     batch_size, m, n, k = 2, 128, 128, 128
-    block_k, block_n = 128, 128
-    a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16)
+    a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16).to(torch.float8_e5m2)
     b, b_scale = _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n)
-    b_shuffled = _shuffle_blackwell_mxfp4_weight(b, block_k, block_n)
+    b_shuffled = _shuffle_blackwell_mxfp4_weight(b)
 
     # Sanity-check the host-side packing; this is the layout consumed by the
     # W_SHUFFLED TMA load path in _p_matmul.
-    transformation = b_shuffled.storage.layout.make_transformation(b.shape, is_fp4=True)
-    assert torch.equal(b.storage.data, transformation.unswizzle_data(b_shuffled.storage.data))
+    assert torch.equal(b.storage.data, convert_layout(b_shuffled, b.storage.layout).storage.data)
 
     precision_config = PrecisionConfig(
         b_mx_scale=b_scale,
         b_microblock_size=MXFP_BLOCK_SIZE.value,
-        out_dtype=a.dtype,
+        out_dtype=torch.bfloat16,
     )
     constraints = {
         "is_persistent": True,
         "block_m": 128,
-        "block_n": block_n,
-        "block_k": block_k,
-        "disable_mx4_block_swap": True,
     }
     with scoped_opt_flags_constraints(constraints):
         tri_y = matmul(a, b_shuffled, None, precision_config=precision_config)
 
-    ref_y = matmul_torch(a, b, None, precision_config=precision_config)
+    ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config)
     assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None)
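The rewritten sanity check relies on convert_layout round-tripping between the two Blackwell value layouts. Spelled out as a property (assuming b is a weight built by the test's helper):

    # Shuffling and converting back must reproduce the original packed bytes.
    shuffled = convert_layout(b, BlackwellMX4ValueShuffledLayout())
    restored = convert_layout(shuffled, b.storage.layout)
    assert torch.equal(b.storage.data, restored.storage.data)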

python/triton_kernels/triton_kernels/matmul.py

Lines changed: 4 additions & 2 deletions
@@ -401,6 +401,8 @@ def matmul(a, b, bias,
         block_k = block_k,
         mx_block_size = mx_block_size,
         x_uses_tma_when_persistent = a_uses_tma_when_persistent,
+        rhs_layout=b.storage.layout,
+        epilogue_reduction_n=fused_activation.specs.reduction_n,
     )
     if b_is_shuffled:
         if b.dtype.bitwidth != 4:
@@ -701,7 +703,7 @@ def apply(x, scale):
 
     if precision_config.a_mx_scale is not None:
         a_scale = precision_config.a_mx_scale
-        mx_axis = x_tri.storage.data.ndim -1
+        mx_axis = x_tri.ndim - 1
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         x_tri = convert_layout(x_tri, canonical_layout)
         x_tri_scale = convert_layout(a_scale, canonical_layout)
@@ -711,7 +713,7 @@ def apply(x, scale):
 
     if precision_config.b_mx_scale is not None:
         b_scale = precision_config.b_mx_scale
-        mx_axis = w_tri.storage.data.ndim - 2
+        mx_axis = w_tri.ndim - 2
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         w_tri = convert_layout(w_tri, canonical_layout)
         w_tri_scale = convert_layout(b_scale, canonical_layout)
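The mx_axis fix in the last two hunks switches from the storage tensor's rank to the logical tensor's rank; the two can disagree once a swizzling layout re-packs the backing data into a differently shaped tensor. A toy illustration (the storage shape here is hypothetical):

    # The logical view is 3-D, but a swizzled storage might hold a 4-D tiled
    # backing tensor; deriving the axis from storage.data.ndim would then
    # point past the logical axes.
    logical_shape = (2, 4096, 4096)       # x_tri.ndim == 3
    storage_shape = (2, 32, 128, 4096)    # hypothetical tiled backing data, ndim == 4
    mx_axis = len(logical_shape) - 1      # 2: last logical axis, as matmul now computes
    wrong = len(storage_shape) - 1        # 3: not a valid logical axis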
