prateekbhustali/jax-imprl

jax-imprl 🚀

A JAX accelerated version of IMPRL (Inspection and Maintenance Planning with Reinforcement Learning), a library for applying reinforcement learning to inspection and maintenance planning of deteriorating engineering systems.

Installation 📦

1. Install uv

Why uv? uv is an extremely fast Python package and project manager written in Rust. It is much faster than pip and pip-tools and has a simple CLI for managing dependencies, virtual environments, and scripts. More info: https://docs.astral.sh/uv/

You can install uv using the following methods (see docs for OS-specific options):

# macOS (Homebrew)
brew install uv

# Or via script (Linux/macOS)
curl -LsSf https://astral.sh/uv/install.sh | sh

2. Create a virtual environment

(Recommended) Create a uv-managed virtualenv:

uv venv --python 3.9 # create virtual environment
source .venv/bin/activate  # activate virtual environment

Alternative: create a conda environment instead:

conda create --name jax_imprl_env -y python==3.9
conda activate jax_imprl_env

3. Install the dependencies (uv)

# CPU (default): install base dependencies, creating uv.lock
uv sync

# GPU (optional): add the GPU group to enable CUDA-backed JAX
uv sync --group gpu

# Dev tools (optional): formatter, tests, etc.
uv sync --group dev

GPU notes

  • The gpu group installs jax[cuda12_pip] (CUDA 12 via pip packages) on Linux with NVIDIA GPUs.
  • If you maintain your own local CUDA 12 install, you can switch to jax[cuda12_local] by editing the gpu group in pyproject.toml.
  • macOS uses the CPU build of JAX by default.
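
To confirm which backend JAX actually picked up after `uv sync`, a quick check along the following lines may help. This is a minimal sketch, not part of the library: it only uses `jax.default_backend()`, `jax.devices()`, and a small jitted function, and assumes it is run inside the activated virtual environment.

```python
import jax
import jax.numpy as jnp

# Report the backend JAX selected ("cpu" for the default install,
# "gpu" if the CUDA build from the gpu group was picked up).
backend = jax.default_backend()
print("backend:", backend)
print("devices:", jax.devices())

# Run a tiny jitted computation to confirm that compilation works.
x = jnp.arange(4.0)
y = jax.jit(lambda v: (v * v).sum())(x)
print("jit check:", float(y))  # 0 + 1 + 4 + 9 = 14.0
```

If the backend prints as `cpu` on a machine with an NVIDIA GPU, the gpu group was most likely not synced into the active environment.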

Installing additional packages

Add packages with uv add and optionally assign them to an extra or group.

For example, to add pandas allowing any 2.x release:

uv add "pandas>=2,<3"

To add a dev-only tool:

uv add --group dev ruff

If resolution fails, relax version ranges and retry.

4. (optional) Test the installation

You can run unit tests to verify that the installation was successful.

uv sync --group dev # ensure dev dependencies are installed
pytest -v tests

5. (optional) Set up wandb

For logging, the library relies on wandb. Log in with your private API key:

wandb login
# <enter wandb API key>
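
On machines where an interactive login is inconvenient, wandb also reads the `WANDB_API_KEY` and `WANDB_MODE` environment variables. The sketch below is only an illustration of how one might check them before launching a run. The helper `wandb_status` is hypothetical and not part of this library or of wandb, and it cannot see credentials stored by `wandb login`.

```python
import os


def wandb_status(env=None):
    """Describe the wandb setup visible through environment variables.

    `wandb_status` is a hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    mode = env.get("WANDB_MODE", "online")
    if mode in ("offline", "disabled"):
        return f"WANDB_MODE={mode}: no API key needed"
    if env.get("WANDB_API_KEY"):
        return "API key found in WANDB_API_KEY"
    # Credentials saved by `wandb login` are not visible here.
    return "no WANDB_API_KEY in environment: use `wandb login` or export the key"


print(wandb_status())
```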

JAX vs NumPy Performance Comparison ⚡

Environment Rollouts

We compare JAX against NumPy (with multiprocessing) for simulating rollouts of a k-out-of-n system with 5 components (agents), where each episode consists of 50 time steps. For 10,000 episodes, JAX (solid lines) reaches a ~14x speedup over NumPy with multiprocessing (dashed lines). For 1 to 10 episodes, the plain NumPy for loop remains the fastest option in the table.

[Figure: Runtime and speedup vs. NumPy]

| # Episodes | NumPy (for loop) [s] | NumPy (mp) [s] | JAX (scan) [s] | Speedup: NumPy (mp) vs. JAX (scan) |
|---|---|---|---|---|
| 1 | 0.01 | 1.18 | 0.27 | 4.39× |
| 10 | 0.05 | 1.2 | 0.24 | 5.01× |
| 100 | 0.51 | 1.3 | 0.24 | 5.41× |
| 1,000 | 5.09 | 2.09 | 0.28 | 7.4× |
| 10,000 | 51.94 | 10.02 | 0.72 | 13.88× |
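
The "JAX (scan)" column refers to rollouts written with `jax.lax.scan`. The following is a minimal sketch of that pattern on a toy k-out-of-n system, not the library's environment. The failure probability, the value of k, the reward, and the absence of inspection or maintenance actions are all simplifying assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

N_COMPONENTS = 5  # components (agents), as in the benchmark
HORIZON = 50      # time steps per episode, as in the benchmark
K_REQUIRED = 4    # hypothetical k: the system is up if at least k components work
P_FAIL = 0.05     # hypothetical per-step failure probability of a working component


def step(working, key):
    """Advance one time step: each working component may fail."""
    fails = jax.random.bernoulli(key, P_FAIL, shape=working.shape)
    working = jnp.logical_and(working, jnp.logical_not(fails))
    system_up = jnp.sum(working) >= K_REQUIRED
    reward = jnp.where(system_up, 0.0, -1.0)  # hypothetical downtime penalty
    return working, reward


def rollout(key):
    """Simulate one episode with lax.scan and return its total reward."""
    keys = jax.random.split(key, HORIZON)
    init = jnp.ones(N_COMPONENTS, dtype=bool)
    _, rewards = jax.lax.scan(step, init, keys)
    return rewards.sum()


# vmap batches the episode over many PRNG keys; jit compiles the batch once.
batched_rollout = jax.jit(jax.vmap(rollout))

keys = jax.random.split(jax.random.PRNGKey(0), 1_000)
returns = batched_rollout(keys)
print(returns.shape, float(returns.mean()))
```

Because the whole batch is compiled into a single program, the one-off compilation cost dominates for small batches. That is consistent with the nearly flat JAX timings between 1 and 1,000 episodes in the table.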

RL Training

We further benchmark RL training on variants of the k-out-of-n system. JAX trains roughly 5x to 12x faster than the equivalent PyTorch implementation running on 8 CPU cores.

| Environment | Agents | Episodes | Timesteps | MBP (JAX) [s] | MBP (PyTorch 8 CPUs) [s] | Speedup |
|---|---|---|---|---|---|---|
| k_n_infinite | 4 | 50,000 | 2.5 M | 403.9 s (0:06:44) | 4883.6 s (1:21:24) | 12× |
| k_n_50 | 5 | 100,000 | 5 M | 1781 s (0:29:41) | 9571.8 s (2:39:32) | 5.4× |

Hardware specs: Apple M1 Pro, 16GB RAM
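
The speedup and timestep figures in the RL training table can be reproduced from the raw timings. The short sketch below does that arithmetic with values copied from the table. The 50 steps per episode is an assumption inferred from the timestep and episode counts, and the derived steps-per-second number is not reported in the original results.

```python
# Timings and episode counts copied from the RL training table.
runs = {
    "k_n_infinite": {"episodes": 50_000, "jax_s": 403.9, "torch_s": 4883.6},
    "k_n_50": {"episodes": 100_000, "jax_s": 1781.0, "torch_s": 9571.8},
}

# Assumption: 50 steps per episode, inferred from timesteps / episodes.
STEPS_PER_EPISODE = 50


def summarize(run):
    """Derive total timesteps, speedup, and JAX throughput for one run."""
    timesteps = run["episodes"] * STEPS_PER_EPISODE
    return {
        "timesteps": timesteps,
        "speedup": run["torch_s"] / run["jax_s"],
        "jax_steps_per_s": timesteps / run["jax_s"],
    }


summary = {name: summarize(run) for name, run in runs.items()}
for name, stats in summary.items():
    print(name, stats["timesteps"], round(stats["speedup"], 1))
```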

Docker 🐳

To run the code in a containerized environment, pull the following Docker image and then follow the installation steps above inside the container.

docker pull nvidia/cuda

If you don't have access to NVIDIA GPUs, you can rent a cloud instance and load the above Docker image, for example on vast.ai at ~$0.30/hour:

https://cloud.vast.ai/?ref_id=113803&creator_id=113803&name=JAX%2BRL

Related Work 🔗

  • IMPRL: small-scale k-out-of-n environments with up to 5 components.

  • IMP-MARL: a platform for benchmarking the scalability of cooperative MARL methods in real-world engineering applications.

    • Environments: (Correlated and uncorrelated) k-out-of-n systems and offshore wind structural systems.
    • RL solvers: Provides wrappers for interfacing with several (MA)RL libraries such as EPyMARL, RLlib, MARLlib etc.

Acknowledgements 🙏

This repository is inspired by the following projects:

PureJAXRL by luchris429

CleanRL started by vwxyzjn
