High-throughput imitation learning for a musculoskeletal mouse forelimb model in JAX-accelerated MuJoCo-MJX.
This repository reproduces all figures and analyses from:
Leonardis, E., Nagamori, A., Yang, Y., Park, J., Saunders, H., Azim, E., Pereira, T. D. (2025). Massively Parallel Imitation Learning of Mouse Forelimb Musculoskeletal Reaching Dynamics. NeurIPS 2025 Workshop: Data on the Brain & Mind – Concrete Applications of AI to Neuroscience and Cognitive Science, San Diego, CA.
The workshop paper is available at https://openreview.net/pdf?id=jJS0ZT0F8x
The project depends on track-mjx, which provides the imitation-learning infrastructure and high-throughput GPU rollouts: https://github.com/talmolab/track-mjx
Install track-mjx and its requirements:
git clone https://github.com/talmolab/track-mjx
cd track-mjx
pip install -e .
Then install the additional dependencies for analysis:
pip install jupyter matplotlib seaborn pandas h5py pyedm tqdm scikit-learn
Make sure MuJoCo-MJX runs correctly with GPU-accelerated JAX.
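A quick way to check the last point is a smoke test that asks JAX which devices it sees and then compiles and runs one MJX physics step. This is a minimal sketch, not part of track-mjx: the helper name check_mjx_backend and the toy sphere model are hypothetical, and every failure is reported as False instead of being raised.

```python
def check_mjx_backend() -> dict:
    """Smoke-test JAX and MuJoCo-MJX (hypothetical helper, not part of track-mjx).

    Returns a dict of booleans instead of raising, so a missing package or a
    failed GPU setup shows up as False in the report.
    """
    report = {"jax_available": False, "gpu_available": False, "mjx_step_ok": False}
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return report
    report["jax_available"] = True
    try:
        # On CUDA machines JAX reports GPU devices with platform == "gpu".
        report["gpu_available"] = any(d.platform == "gpu" for d in jax.devices())
    except RuntimeError:
        pass
    try:
        import mujoco
        from mujoco import mjx

        # A toy free-floating sphere is enough to exercise the MJX pipeline.
        xml = """
        <mujoco>
          <worldbody>
            <body>
              <freejoint/>
              <geom type="sphere" size="0.1"/>
            </body>
          </worldbody>
        </mujoco>
        """
        model = mujoco.MjModel.from_xml_string(xml)
        mjx_model = mjx.put_model(model)
        mjx_data = mjx.put_data(model, mujoco.MjData(model))
        # Compile and run one physics step on JAX's default device.
        mjx_data = jax.jit(mjx.step)(mjx_model, mjx_data)
        report["mjx_step_ok"] = bool(jnp.all(jnp.isfinite(mjx_data.qpos)))
    except Exception:  # broad on purpose: any failure means "not working"
        pass
    return report
```

On a correctly configured GPU machine, check_mjx_backend() should return True for all three entries; gpu_available being False while mjx_step_ok is True usually means JAX fell back to the CPU.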
The fastest way to understand the pipeline is to run the batch rollout demo:
▶ demo_batch_rollout_PCA_figures.ipynb
This notebook demonstrates:
- How to load a trained checkpoint
- How to run batched imitation rollouts in parallel
- How to generate rollout .h5 files
- How to compute PCA embeddings of intention and decoder-layer activations
- How to visualize reach trajectories and neural representations
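The PCA step in the list above amounts to centring the activations and projecting them onto their leading principal axes. The sketch below does this with a NumPy SVD; it assumes activations arrive as an array whose last axis indexes units (for example clips x frames x units), and the function name pca_embed is hypothetical rather than taken from the notebook, which may use scikit-learn instead.

```python
import numpy as np


def pca_embed(activations: np.ndarray, n_components: int = 3):
    """Project activations onto their top principal components.

    activations: array of shape (..., n_units). Leading axes (e.g. clips x
    frames) are flattened for fitting and restored in the output.
    Returns (embedding, explained_variance_ratio).
    """
    lead_shape = activations.shape[:-1]
    X = activations.reshape(-1, activations.shape[-1]).astype(np.float64)
    X = X - X.mean(axis=0, keepdims=True)  # PCA operates on centred data
    # Rows of Vt are the principal axes, sorted by decreasing singular value.
    _, S, Vt = np.linalg.svd(X, full_matrices=False)
    components = Vt[:n_components]
    embedding = X @ components.T
    explained = S**2 / np.sum(S**2)
    return embedding.reshape(*lead_shape, n_components), explained[:n_components]
```

Keeping the leading axes lets each clip's trajectory through PC space be plotted directly, e.g. embedding[clip, :, 0] against embedding[clip, :, 1].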
This provides a complete end-to-end walkthrough of the training outputs used in the paper.
The demo notebook saves rollouts in the same format described in Data Table 1 of the paper (frames, joint angles, latent activations, decoder activations, muscle activations, etc.).
These rollout files are the input for the EMG and nonlinear forecasting analyses.
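Because the exact dataset names inside a rollout file are best read from the file itself, a small helper that walks the HDF5 tree and lists every array with its shape and dtype is a convenient first step. This is a sketch: list_datasets is a hypothetical helper, and it relies only on the dict-like .items() interface that h5py groups share with plain Python dicts.

```python
def list_datasets(group, prefix=""):
    """Recursively list (path, shape, dtype) for every array in an HDF5-like tree.

    Works on an open h5py.File or h5py.Group (which behave like read-only
    dicts) and on plain nested dicts of NumPy arrays.
    """
    entries = []
    for name, obj in group.items():
        path = f"{prefix}/{name}"
        if hasattr(obj, "shape"):  # leaf node: h5py.Dataset or np.ndarray
            entries.append((path, tuple(obj.shape), str(obj.dtype)))
        else:  # interior node: h5py.Group or dict, so descend into it
            entries.extend(list_datasets(obj, path))
    return entries
```

With h5py installed, pass an open file directly, e.g. list_datasets(h5py.File(path, "r")), to see which datasets hold joint angles, latent activations, and muscle activations before loading them.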
Use the notebook:
▶ emg_figures.ipynb
This notebook:
- Loads the rollout .h5 files generated in the demo
- Loads aligned biological EMG (biceps and triceps)
- Compares simulated muscle activations against observed EMG
- Plots trial-by-trial and trial-averaged activations
- Reproduces the EMG MAE comparisons and activation time-series figures from the paper
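The MAE comparison listed above can be sketched as: bring each simulated and recorded trace to a common length, rescale both to a common range, and average the absolute difference per trial and across trial-averaged profiles. The resampling and min-max normalization here are assumptions chosen for illustration; the notebook's actual preprocessing may differ, and the helper names are hypothetical. Inputs are one muscle at a time, shaped (n_trials, n_time), with EMG assumed to be an already rectified and smoothed envelope.

```python
import numpy as np


def minmax(x, axis=-1, eps=1e-12):
    """Scale each trace to [0, 1] along `axis` so units are comparable."""
    lo = x.min(axis=axis, keepdims=True)
    hi = x.max(axis=axis, keepdims=True)
    return (x - lo) / (hi - lo + eps)


def resample(x, n):
    """Linearly resample each row of a (trials, time) array to n samples."""
    old = np.linspace(0.0, 1.0, x.shape[-1])
    new = np.linspace(0.0, 1.0, n)
    return np.stack([np.interp(new, old, row) for row in x])


def emg_mae(sim, emg):
    """Per-trial and trial-averaged MAE between simulated activation and EMG."""
    sim = np.asarray(sim, dtype=float)
    emg = resample(np.asarray(emg, dtype=float), sim.shape[-1])
    s, e = minmax(sim), minmax(emg)
    per_trial = np.abs(s - e).mean(axis=-1)
    # MAE between the trial-averaged profiles (compares mean activation shape).
    avg = np.abs(s.mean(axis=0) - e.mean(axis=0)).mean()
    return per_trial, avg
```

Calling emg_mae separately for biceps and triceps keeps the per-muscle comparison explicit.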
This corresponds to the EMG panels in Figure 2 of the manuscript.
To reproduce the nonlinear dynamical forecasting results, use:
▶ pyedm_figures.ipynb
This notebook implements:
- Takens-delay embedding of joint angles and simulated actions
- Sugihara’s simplex projection method
- Forecasting of simulated muscle activations from joint kinematics
- Forecasting of real EMG from simulated actions + reference kinematics
- τ, embedding-dimension, and prediction-horizon sweeps
- The forecasting accuracy plots (Simplex ρ) in Figure 3
This reproduces the nonlinear forecasting analysis in the manuscript.
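The core of this analysis, Takens delay embedding followed by simplex projection, can be sketched in plain NumPy. This is an illustrative re-implementation, not the pyEDM code the notebook uses: it splits the series into a library half and a prediction half, finds the E+1 nearest library neighbours of each embedded point, and averages their target values Tp steps ahead with weights exp(-d/d_min). Passing a different series as the target turns it into a cross-mapping forecast, such as predicting muscle activation from joint kinematics. The function names and the 50/50 split are assumptions.

```python
import numpy as np


def delay_embed(x, E, tau):
    """Takens embedding: row for time t is [x_t, x_{t-tau}, ..., x_{t-(E-1)tau}]."""
    x = np.asarray(x, dtype=float)
    times = np.arange((E - 1) * tau, len(x))
    emb = np.stack([x[times - k * tau] for k in range(E)], axis=1)
    return emb, times


def simplex_forecast(library_x, target, E=3, tau=1, Tp=1, train_frac=0.5):
    """Simplex projection; returns (predictions, observations, Pearson rho)."""
    emb, times = delay_embed(library_x, E, tau)
    target = np.asarray(target, dtype=float)
    valid = times + Tp < len(target)  # keep points whose future is observed
    emb, times = emb[valid], times[valid]
    split = int(train_frac * len(times))
    lib_emb, lib_t = emb[:split], times[:split]
    pred_emb, pred_t = emb[split:], times[split:]

    preds = np.empty(len(pred_t))
    for i, point in enumerate(pred_emb):
        d = np.linalg.norm(lib_emb - point, axis=1)
        nn = np.argsort(d)[: E + 1]  # E+1 neighbours form a simplex in E dims
        d_min = max(d[nn[0]], 1e-12)  # guard against division by zero
        w = np.exp(-d[nn] / d_min)
        preds[i] = np.sum(w * target[lib_t[nn] + Tp]) / np.sum(w)
    obs = target[pred_t + Tp]
    rho = np.corrcoef(preds, obs)[0, 1]
    return preds, obs, rho
```

A τ or embedding-dimension sweep is then a loop, e.g. [simplex_forecast(x, y, E=E)[2] for E in range(1, 8)], with the ρ values plotted against E.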
You can download the datasets described in the paper here:
- Training clip data (registered mocap → reference trajectories): https://huggingface.co/datasets/talmolab/MIMIC-MJX/tree/main/data/mouse_arm
- Model checkpoints and rollout data: https://huggingface.co/talmolab/mouse-reach-mjx-neurips
- EMG, trial indexes, and parameter search results: https://huggingface.co/datasets/talmolab/mouse-reach-mjx-neurips
Run the following notebooks in order:
- demo_batch_rollout_PCA_figures.ipynb: Generates rollouts and PCA visualizations.
- emg_figures.ipynb: Recreates EMG comparison figures (MAE, trial-by-trial, average activation).
- pyedm_figures.ipynb: Reproduces nonlinear forecasting results (Simplex ρ, predicted vs. observed traces).
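Since later notebooks read outputs written by earlier ones, the sequence can also be executed headlessly with jupyter nbconvert, stopping at the first failure. This is a sketch: run_notebooks is a hypothetical helper, and it assumes the notebooks sit in the current working directory.

```python
import subprocess

NOTEBOOKS = [
    "demo_batch_rollout_PCA_figures.ipynb",
    "emg_figures.ipynb",
    "pyedm_figures.ipynb",
]


def run_notebooks(notebooks=NOTEBOOKS, dry_run=False):
    """Execute each notebook in place, in order; return the commands used."""
    commands = []
    for nb in notebooks:
        cmd = [
            "jupyter", "nbconvert", "--to", "notebook", "--execute", "--inplace",
            # Long GPU rollouts can exceed the default per-cell timeout.
            "--ExecutePreprocessor.timeout=-1",
            nb,
        ]
        commands.append(cmd)
        if not dry_run:
            # check=True raises CalledProcessError, so later notebooks never
            # run against missing upstream outputs.
            subprocess.run(cmd, check=True)
    return commands
```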
If you use this repository, please cite:
Leonardis, E., Nagamori, A., Yang, Y., Park, J., Saunders, H., Azim, E., Pereira, T. D. (2025). Massively Parallel Imitation Learning of Mouse Forelimb Musculoskeletal Reaching Dynamics. NeurIPS 2025 Workshop: Data on the Brain & Mind – Concrete Applications of AI to Neuroscience and Cognitive Science, San Diego, CA.
Click the image below to watch Supplementary Video 1:
