Skip to content

Commit 19117eb

Browse files
maleadt and claude committed
Fix Python layernorm benchmark throughput reporting to match Julia.
Python's layernorm metric() was returning a single tuple applied to both fwd and bwd passes, while Julia correctly uses separate multipliers (4x for forward, 5x for backward). Update layernorm.py to return a per-impl dict and update benchmarks.py to handle dict metric returns.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 839bb94 commit 19117eb

2 files changed

Lines changed: 20 additions & 7 deletions

File tree

examples/benchmarks.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def run_benchmark(name: str):
9898
data = prepare_fn(benchmark=True)
9999

100100
# Get metric info if available
101+
# metric() returns either (total, unit) or dict{"impl": (total, unit)}
101102
metric_fn = getattr(mod, "metric", None)
102-
metric_total, metric_unit = (0, "") if not metric_fn else metric_fn(data)
103+
metric_result = metric_fn(data) if metric_fn else None
103104

104105
# Run cuTile
105106
result = run_fn(data, nruns=NRUNS, warmup=WARMUP)
@@ -121,7 +122,7 @@ def run_benchmark(name: str):
121122
others = run_others_fn(data, nruns=NRUNS, warmup=WARMUP)
122123
results.update(others)
123124

124-
return results, metric_total, metric_unit
125+
return results, metric_result
125126

126127

127128
#=============================================================================
@@ -147,14 +148,21 @@ def main():
147148
print(" (skipped - no prepare/run functions)")
148149
continue
149150

150-
results, metric_total, metric_unit = ret
151+
results, metric_result = ret
151152

152153
# Convert to BenchmarkResult for printing
153154
benchmark_results = []
154155
for impl_name, times in results.items():
155156
min_t = min(times)
156157
mean_t = sum(times) / len(times)
157-
tp = format_throughput(metric_total, metric_unit, min_t) if metric_unit else ""
158+
tp = ""
159+
if isinstance(metric_result, dict):
160+
if impl_name in metric_result:
161+
mt, mu = metric_result[impl_name]
162+
tp = format_throughput(mt, mu, min_t)
163+
elif isinstance(metric_result, tuple):
164+
mt, mu = metric_result
165+
tp = format_throughput(mt, mu, min_t) if mu else ""
158166
benchmark_results.append(BenchmarkResult(impl_name, min_t, mean_t, tp))
159167

160168
# Sort by min time

examples/layernorm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,14 @@ def verify(data, result):
255255
f"DB mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DB']) - expected_DB))}"
256256

257257
def metric(data):
258-
"""Return (total_bytes, unit) for throughput calculation."""
259-
# Forward: 3 reads of X + W + B reads + Y write + Mean/Rstd writes ≈ 4*M*N floats
260-
return 4 * data["M"] * data["N"] * 4, "GB/s"
258+
"""Return per-implementation (total_bytes, unit) for throughput calculation."""
259+
MN = data["M"] * data["N"] * 4 # sizeof(float32)
260+
return {
261+
# Forward: X read (3 passes: mean, var, normalize) + Y write ≈ 4*M*N floats
262+
"cuTile Fwd": (4 * MN, "GB/s"),
263+
# Backward: X read (2 passes) + DY read (2 passes) + DX write ≈ 5*M*N floats
264+
"cuTile Bwd": (5 * MN, "GB/s"),
265+
}
261266

262267

263268
# No run_others for layernorm - no simple reference implementation to compare against

0 commit comments

Comments
 (0)