Commit 4bc7542

fix(tutorial): correct inverted error bars in grouped_gemm benchmark
The grouped_gemm tutorial plots absolute latency `runtime(ms)` on the y-axis, but its benchmark functions returned `(ms, max_ms, min_ms)`, which swapped the lower and upper bounds of the error bars in the generated plot. This commit changes the return order to `(ms, min_ms, max_ms)` in both `benchmark_square_matrices` and `benchmark_batches` so that the plotted bounds match the measured spread.
1 parent 088fbe1 commit 4bc7542

1 file changed

Lines changed: 2 additions & 2 deletions

python/tutorials/08-grouped-gemm.py

@@ -487,7 +487,7 @@ def benchmark_square_matrices(N, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, dtype=torch.
                                        float16), quantiles=quantiles)
-    return ms, max_ms, min_ms
+    return ms, min_ms, max_ms
 
 
 @triton.testing.perf_report(
@@ -558,7 +558,7 @@ def benchmark_batches(M, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_t_lds, group_size, dtype=torch.
                                        float16), quantiles=quantiles)
-    return ms, max_ms, min_ms
+    return ms, min_ms, max_ms
 
 
 benchmark_square_matrices.run(show_plots=True, print_data=True)
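
The effect of the swapped return order can be illustrated without a GPU. The sketch below is not part of the commit and does not call Triton. It assumes, as the commit message describes, that the plotting code treats the second and third returned values as the lower and upper bound around the first. The helpers `summarize` and `error_bar_lengths`, the sample latencies, and the choice of the 20% and 80% cut points are all hypothetical.

```python
import statistics


def summarize(samples_ms):
    """Return (median, lower, upper) for latency samples in milliseconds.

    Hypothetical stand-in for a benchmark helper: lower and upper are the
    20% and 80% cut points of the samples.
    """
    # n=5 yields four cut points, at 20%, 40%, 60% and 80%.
    cuts = statistics.quantiles(samples_ms, n=5)
    return statistics.median(samples_ms), cuts[0], cuts[-1]


def error_bar_lengths(y, y_min, y_max):
    """Lengths of the lower and upper error bars drawn around y."""
    return y - y_min, y_max - y


# Made-up latency samples, for illustration only.
samples = [1.00, 1.02, 1.05, 1.10, 1.20, 1.01, 1.03, 1.04, 1.08, 1.15]
ms, min_ms, max_ms = summarize(samples)

# Order used after the fix: both bar lengths are positive.
correct = error_bar_lengths(ms, min_ms, max_ms)

# Order used before the fix: both bar lengths come out negative,
# so the band is drawn with its bounds flipped.
swapped = error_bar_lengths(ms, max_ms, min_ms)

print("correct:", correct)
print("swapped:", swapped)
```

Under these assumptions, a check such as `min_ms <= ms <= max_ms` placed before the `return` would have caught the swapped order.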
