Skip to content

Commit fa4db31

Browse files
authored
Fix bench_mlp.py (#9919)
Fixes API issues in the bench_mlp.py script. `python/triton_kernels/bench/bench_mlp.py` no longer ran with current Triton code. Running: ``` torchrun --nproc-per-node=1 python/triton_kernels/bench/bench_mlp.py ``` fails with ``` [rank0]: Traceback (most recent call last): [rank0]: File "/workspace/triton-source/python/triton_kernels/bench/bench_mlp.py", line 230, in <module> [rank0]: roofline_mlp(batch_sizes, 5760, 5760, 128, 4, dtypes[0], dtypes[1], ep, name="mlp_moe") [rank0]: File "/workspace/triton-source/python/triton_kernels/bench/bench_mlp.py", line 194, in roofline_mlp [rank0]: csv_path = roofline.compute_roofline(dim1, dim2, n_expts_tot, n_expts_act, parse_dtype(x_dtype), [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/usr/local/lib/python3.12/dist-packages/triton_kernels/roofline.py", line 73, in compute_roofline [rank0]: perf = inject_proxy_and_call(val, args, kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/usr/local/lib/python3.12/dist-packages/triton_kernels/roofline.py", line 64, in inject_proxy_and_call [rank0]: return bench_fn(*args_list, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/triton-source/python/triton_kernels/bench/bench_mlp.py", line 100, in bench_mlp [rank0]: symm_mem_pool = SymmetricMemoryPool() [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: TypeError: SymmetricMemoryPool.__init__() missing 1 required positional argument: 'mesh' E0403 20:14:38.021000 2225 torch/distributed/elastic/multiprocessing/api.py:988] failed (exitcode: 1) local_rank: 0 (pid: 2258) of binary: /usr/bin/python3 Traceback (most recent call last): File "/usr/local/bin/torchrun", line 6, in <module> sys.exit(main()) ^^^^^^ File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 367, in wrapper return f(*args, **kwargs) ^^^^^^^^^^^^^^^^^^ File 
"/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 1016, in main run(args) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 1007, in run elastic_launch( File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 184, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 332, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError: ============================================================ python/triton_kernels/bench/bench_mlp.py FAILED ------------------------------------------------------------ Failures: <NO_OTHER_FAILURES> ------------------------------------------------------------ Root Cause (first observed failure): [0]: time : 2026-04-03_20:14:38 host : ab3ee0d0c408 rank : 0 (local_rank: 0) exitcode : 1 (pid: 2258) error_file: <N/A> traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html ```
1 parent eb5efe2 commit fa4db31

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
116116
n_expts_act=n_expts_act,
117117
n_expts_tot=n_expts_tot,
118118
dtype=x_dtype,
119-
device=torch.cuda.current_device(),
119+
device=torch.device(dev),
120120
)
121121

122122
# -- init prameters --
@@ -137,10 +137,10 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
137137
if w_dtype == FP4:
138138
num_warps = 4 if batch <= 512 else 8
139139
value_layout = layout.make_default_matmul_mxfp4_w_layout(
140-
mx_axis=1,
140+
mx_axis=-2,
141141
allow_blackwell_value_shuffle=shuffle_mx4,
142142
)
143-
scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
143+
scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps)
144144
opt1 = {
145145
"value_layout": value_layout,
146146
"scale_layout": scale_layout,
@@ -187,7 +187,8 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
187187
fc2_constraints["num_stages"] = num_stages_fc2
188188

189189
fpath = Path(f"profile_{rank}")
190-
# warmup
190+
# Compile and warm up outside the profiler so subsequent profiled launches
191+
# retain launch metadata needed by roofline.parse_profile.
191192
run_mlp(x_dp_local_bf16, x_dp_local_fp8, #
192193
wg_global, bg_global, pcg, #
193194
w1_ep_local, b1_ep_local, pc1, act1, #

python/triton_kernels/bench/bench_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def _quantize_weight(w, dtype, **opt):
2626
assert dtype == "mx4", f"{dtype=}"
2727
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
2828
if opt:
29-
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
30-
w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
29+
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"])
30+
w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"])
3131
return w, InFlexData(), w_scale
3232

3333

@@ -53,13 +53,11 @@ def _make_mx4_quantization_opts(batch: int, w_dtype: str) -> dict:
5353
if w_dtype != "mx4" or is_hip():
5454
return {}
5555
num_warps = 4 if batch <= 512 and cuda_capability_geq(10, 0) else 8
56-
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
57-
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
56+
value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2)
57+
scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps)
5858
return {
5959
"value_layout": value_layout,
60-
"value_layout_opts": value_layout_opts,
6160
"scale_layout": scale_layout,
62-
"scale_layout_opts": scale_layout_opts,
6361
}
6462

6563

0 commit comments

Comments
 (0)