JAX implementation of the Adaptive Integration Time (AIT) algorithm for Neural ODEs.
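The title refers to the integration time of a Neural ODE. The following is a minimal sketch, assuming that "integration time" means the end time T of the solve dz/dt = f(z) from t = 0 to t = T. It shows a generic Neural ODE forward pass in which T is an explicit, differentiable input. It is not the AIT algorithm from this repository. The names vector_field, rk4_solve, and loss, and the one-layer tanh vector field, are hypothetical. A fixed-step RK4 solver stands in for whatever solver the repository actually uses.

```python
import jax
import jax.numpy as jnp


def vector_field(params, z):
    """Hypothetical vector field f(z): a single tanh layer."""
    w, b = params
    return jnp.tanh(z @ w + b)


def rk4_solve(params, z0, T, num_steps=20):
    """Integrate dz/dt = f(z) from t = 0 to t = T with fixed-step RK4.

    Substituting t = T * s turns this into dz/ds = T * f(z) on s in [0, 1],
    so T enters as an ordinary input that jax.grad can differentiate.
    num_steps must be a Python int so fori_loop has a static trip count
    (required for reverse-mode differentiation).
    """
    h = 1.0 / num_steps

    def g(z):
        # Rescaled vector field on the unit interval.
        return T * vector_field(params, z)

    def step(_, z):
        # One classical fourth-order Runge-Kutta step of size h.
        k1 = g(z)
        k2 = g(z + 0.5 * h * k1)
        k3 = g(z + 0.5 * h * k2)
        k4 = g(z + h * k3)
        return z + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    return jax.lax.fori_loop(0, num_steps, step, z0)


# Hypothetical setup: random small weights and a toy loss on z(T).
dim = 4
key = jax.random.PRNGKey(0)
params = (0.1 * jax.random.normal(key, (dim, dim)), jnp.zeros(dim))
z0 = jnp.ones(dim)


def loss(T):
    zT = rk4_solve(params, z0, T)
    return jnp.sum(zT ** 2)


# Gradient of the loss with respect to the integration time T.
dloss_dT = jax.grad(loss)(1.0)
```

Rescaling time this way keeps the solver grid on [0, 1] regardless of T, so the integration time is just another argument to jax.grad.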
Requirements:
- Linux (x86_64)
- Python 3.11
- NVIDIA GPU with CUDA 13 (JAX is installed with the cuda13 extra)
- uv for dependency management
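To check that the CUDA requirement is met, you can ask JAX which devices it sees. The sketch below uses only jax.devices and jax.default_backend. The file name check_gpu.py and the helper has_gpu_backend are hypothetical and not part of the repository.

```python
# check_gpu.py (hypothetical helper script, not part of the repository)
import jax


def has_gpu_backend() -> bool:
    """Return True if JAX can see at least one GPU device."""
    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # Raised when no GPU backend is available, e.g. a CPU-only jaxlib install.
        return False


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    print("GPU visible:", has_gpu_backend())
```

Run it inside the project environment with uv run python check_gpu.py. A CPU-only result usually means the CUDA extra or the NVIDIA driver is not set up correctly.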
Installation with uv (recommended):

    uv sync

This creates a virtual environment in .venv/ and installs the locked dependencies from uv.lock. Run commands with uv run, for example:

    uv run python scripts/plot_experiments.py results/ait_mnist_0.001.csv results/node_mnist_0.csv

Alternatively, install into an existing environment with pip:
    pip install -e .

Running the experiments:

    # Make the scripts executable
    chmod +x experiments/run_ait.sh
    chmod +x experiments/run_node.sh

    # Run the experiments
    ./experiments/run_ait.sh g2
    ./experiments/run_node.sh g2