Context
The benchmarks/ directory was originally built to compare Keras 3 against tf.keras, which made sense at the time. But now that JAX, PyTorch, and TF backends are all properly supported, the more useful question is how does the same Keras code actually perform across backends?
The development roadmap (#19519) lists "Official performance benchmarks" as a goal, so I wanted to gauge what that could look like and offer to help.
The main gap is that the benchmarks only compare Keras 3 vs tf.keras. `LayerBenchmark` hardcodes both a `keras.layers.X` and a `tf.keras.layers.X` instance and runs them side by side. There is no way to benchmark the same layer across the JAX, Torch, and TF backends, which is probably what most users actually want to know.
A few other things I noticed while looking through it:
- Results just go to `print()`, with no JSON or CSV output, so there is no easy way to compare runs or track trends over time
- Benchmarks are not run in `actions.yml` or `nightly.yml`, so performance regressions only get caught when someone files a bug
- Only throughput is measured, with no memory tracking, which matters a lot for RNN layers and loss functions (I ran into this myself while working on "Reduce memory usage in sparse_categorical_crossentropy", #22169)
- The warmup is just skipping batch 0, which is not enough to account for JIT compilation noise on JAX or PyTorch
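To make the last point concrete, a measurement loop with an explicit warmup phase could look something like this. This is just a sketch; `benchmark_fn` and its defaults are hypothetical, not existing Keras benchmark code:

```python
import statistics
import time


def benchmark_fn(fn, num_warmup=10, num_runs=30):
    """Time `fn` with an explicit warmup to absorb JIT compilation cost.

    On JAX and PyTorch the first few calls can include tracing and
    compilation, so they are run and discarded before timing starts.
    Returns (mean, std) of per-call wall time in seconds.
    """
    for _ in range(num_warmup):
        fn()
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)
```

For async backends, `fn` would also need to block on the result (e.g. `jax.block_until_ready` or a CUDA sync) so that wall time actually reflects the computation.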
What I'd propose
Rather than one big PR, I'd tackle this incrementally:
- Multi-backend `LayerBenchmark` - refactor the base class to accept a `--backend jax,torch,tensorflow` flag and produce a comparison table
- Structured output - add an `--output_format json` flag so results include metadata (backend, hardware, shapes, timestamps)
- Memory profiling - track peak memory alongside throughput using `torch.cuda.max_memory_allocated`, JAX device memory stats, etc.
- Better measurement - configurable warmup, multiple runs, mean/std reporting
- CI integration - a nightly workflow that runs benchmarks and stores results as artifacts; stretch goal would be a perf diff comment on PRs
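As a rough idea of what the structured output could contain, here is a hypothetical result record; the field names are placeholders I made up to illustrate the metadata, not a settled schema:

```python
import json
import platform
import time


def make_result_record(layer_name, backend, throughput, peak_memory_bytes, input_shape):
    # Hypothetical schema: bundle the measurement with enough metadata
    # to compare runs across backends, hardware, and time.
    return {
        "layer": layer_name,
        "backend": backend,
        "input_shape": list(input_shape),
        "throughput_samples_per_sec": throughput,
        "peak_memory_bytes": peak_memory_bytes,
        "hardware": platform.machine(),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }


record = make_result_record("Conv2D", "jax", 1234.5, 2**20, (32, 224, 224, 3))
print(json.dumps(record, indent=2))
```

Records like this could be appended to a JSONL artifact per run, which makes trend tracking and perf diffs straightforward.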
Where I'd start
The first PR would just be Phase 1: refactor `LayerBenchmark` to support multi-backend runs, update `conv_benchmark.py` as a reference implementation, and keep the existing Keras-vs-tf.keras path working so nothing breaks.
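Since Keras reads `KERAS_BACKEND` once at import time, one way to drive the multi-backend runs is a fresh subprocess per backend. A sketch of what I have in mind (the helper name and CLI shape are my own, not existing benchmark code):

```python
import os
import subprocess
import sys


def run_for_backend(backend, args):
    """Run a benchmark script under `backend` in a clean subprocess.

    Keras picks its backend from KERAS_BACKEND at import time, so an
    in-process switch is not possible; each backend gets its own
    process. `args` is the argv passed to the Python interpreter.
    """
    env = dict(os.environ, KERAS_BACKEND=backend)
    return subprocess.run(
        [sys.executable, *args], env=env, capture_output=True, text=True
    )
```

The comparison table would then be built by collecting each subprocess's JSON output, e.g. `run_for_backend("jax", ["conv_benchmark.py", "--output_format", "json"])` for each backend in turn.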
I've been contributing for a while (#22257, #22115, #22169, #22013), and I'm reasonably familiar with how the backends are structured, so this feels like a natural next thing to work on. Happy to adjust based on what's actually useful here, and if anyone would like to collaborate on this I'd be more than happy to!