Skip to content

Conversation

@marianophielipp
Copy link

Adds compute_norm_stats_jax.py, a drop-in, JAX-powered replacement for compute_norm_stats.py.
On the ALOHA Pen Uncap dataset (50k frames) it delivers a 3.4 × speed-up; on larger datasets the gains are even more pronounced, making feasible to train with large number of episodes.

- Add JAX implementation with order of mangnituted speedup, test on H100/200 GPUs
- Two modes: FastRunningStats (default) and JaxRunningStats (exact)
- Comprehensive test suite and documentation
- Drop-in replacement for original compute_norm_stats.py
- Update README with usage instructions

Hardware tested: H100/200 GPU
- Add compute_norm_stats_jax.py with GPU acceleration using JAX
- Support multi-GPU processing with automatic batch size optimization
- Provide FastRunningStats for speed and JaxRunningStats for accuracy
- Add comprehensive documentation in docs/jax_acceleration.md
- Update README.md with performance information
- Remove test file as requested
- Fix code quality issues with ruff formatting and linting

Performance: 3.4x speedup on ALOHA pen uncap dataset (50k frames)
Tested on H100 and current machine with real robotics datasets
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants