NeuroJAX is a differentiable, GPU-accelerated CLI tool for EMEG (combined EEG and MEG) processing and nonlinear dynamical systems analysis.
Its aim is to unify preprocessing and modelling in a single computational graph, enabling end-to-end gradient descent from sensor-level error back to biophysical parameters.
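To illustrate what "end-to-end" means here, the sketch below chains a toy source model and a linear forward model into one JAX function and differentiates the sensor-space error with respect to a single biophysical parameter (a time constant `tau`). It is a minimal sketch under stated assumptions, not NeuroJAX's pipeline: the leaky-integrator dynamics, the random leadfield, and the names `simulate_sources` and `sensor_loss` are all hypothetical.

```python
import jax
import jax.numpy as jnp

def simulate_sources(tau, drive, dt=1e-3):
    """Hypothetical toy source model: leaky integrator dx/dt = -x/tau + drive (forward Euler)."""
    def step(x, u_t):
        x_next = x + dt * (-x / tau + u_t)
        return x_next, x_next
    x0 = jnp.zeros(drive.shape[1])
    _, xs = jax.lax.scan(step, x0, drive)
    return xs                                             # (T, n_sources)

def sensor_loss(tau, leadfield, drive, observed):
    """Mean squared error in sensor space; leadfield maps sources to sensors."""
    predicted = simulate_sources(tau, drive) @ leadfield.T  # (T, n_sensors)
    return jnp.mean((predicted - observed) ** 2)

key = jax.random.PRNGKey(0)
k_drive, k_lf = jax.random.split(key)
T, n_sources, n_sensors = 500, 3, 8
drive = jax.random.normal(k_drive, (T, n_sources))
leadfield = jax.random.normal(k_lf, (n_sensors, n_sources))  # assumed random forward model

true_tau = 0.02
observed = simulate_sources(true_tau, drive) @ leadfield.T    # synthetic "recording"

# One compiled call returns the sensor error and its gradient w.r.t. the biophysical parameter.
loss_and_grad = jax.jit(jax.value_and_grad(sensor_loss))
loss, dloss_dtau = loss_and_grad(0.05, leadfield, drive, observed)
```

The gradient flows through the simulated dynamics and the forward model in a single graph, so `dloss_dtau` can be fed to any gradient-based optimizer (Optimistix, for instance). Replacing the Euler loop with a Diffrax solver or the toy leadfield with a real head model would keep the same pattern.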
- Core: JAX, Equinox
- Solvers: Lineax (GLM/beamforming), Optimistix
- Dynamics: Diffrax (Neural ODEs/DCM)
- Inverse: Scico (sparse/iterative solvers)
- GLM: Mass-univariate permutation testing on the GPU (`src/neurojax/glm.py`).
- Inverse: Differentiable beamformers and the CHAMPAGNE algorithm.
- Biophysics: Differentiable implementations of Wong-Wang and Canonical Microcircuit models (replacing TVB/DCM).
- Foundation: Mamba-based sequence modelling for whole-brain dynamics.
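As a rough picture of mass-univariate permutation testing on the GPU, the sketch below fits an OLS model to every feature at once, then uses `jax.vmap` to evaluate many permutations in parallel and applies max-statistic correction for family-wise error. It is an illustrative sketch, not the API of `GeneralLinearModel` or `run_permutation_test`: the helper names, the choice to shuffle design rows (valid here because the only non-constant regressor is the one being tested), and the synthetic data are assumptions.

```python
import jax
import jax.numpy as jnp

def glm_tstats(X, Y, c):
    """Hypothetical helper: OLS fit of each column of Y (n, v) on X (n, p); t-stats for contrast c (p,)."""
    n, p = X.shape
    pinv = jnp.linalg.pinv(X)                        # (p, n)
    beta = pinv @ Y                                  # (p, v)
    resid = Y - X @ beta                             # (n, v)
    sigma2 = jnp.sum(resid ** 2, axis=0) / (n - p)   # residual variance per feature
    var_c = c @ (pinv @ pinv.T) @ c                  # c^T (X^T X)^-1 c for full-rank X
    return (c @ beta) / jnp.sqrt(sigma2 * var_c)

def max_t_permutation_test(key, X, Y, c, n_perm=1000):
    """Hypothetical helper: FWER-corrected p-values via the max |t| null distribution."""
    t_obs = glm_tstats(X, Y, c)

    def null_max_t(k):
        perm = jax.random.permutation(k, X.shape[0])  # shuffle design rows
        return jnp.max(jnp.abs(glm_tstats(X[perm], Y, c)))

    # All permutations are evaluated in one batched computation.
    null_max = jax.vmap(null_max_t)(jax.random.split(key, n_perm))  # (n_perm,)
    exceed = jnp.sum(null_max[:, None] >= jnp.abs(t_obs)[None, :], axis=0)
    p_fwer = (exceed + 1) / (n_perm + 1)
    return t_obs, p_fwer

# Synthetic two-group example: 40 observations, 100 features (e.g. sensor x time points).
key = jax.random.PRNGKey(0)
k_data, k_perm = jax.random.split(key)
n, v = 40, 100
group = jnp.repeat(jnp.array([0.0, 1.0]), n // 2)
X = jnp.column_stack([jnp.ones(n), group])         # intercept + group regressor
Y = jax.random.normal(k_data, (n, v))
Y = Y.at[:, :5].add(2.0 * group[:, None])          # assumed effect in the first 5 features
c = jnp.array([0.0, 1.0])                          # contrast: group difference

t_obs, p_fwer = max_t_permutation_test(k_perm, X, Y, c, n_perm=1000)
```

Taking the maximum |t| over all features in each permutation is what controls the family-wise error rate; designs with nuisance covariates would need a scheme such as Freedman-Lane instead of plain row shuffling.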
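For the inverse side, a differentiable beamformer can be as small as a linear solve. The sketch below shows a standard unit-gain LCMV beamformer for fixed-orientation sources, with weights $w = C^{-1}l / (l^\top C^{-1} l)$; it does not implement CHAMPAGNE. The helper names, the regularisation rule (a fraction of the mean sensor variance added to the diagonal), and the single-active-source simulation are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def lcmv_weights(cov, leadfield, reg=0.05):
    """Hypothetical helper: unit-gain LCMV weights, (n_sources, n_sensors)."""
    n = cov.shape[0]
    cov_reg = cov + reg * jnp.trace(cov) / n * jnp.eye(n)   # diagonal loading
    cinv_l = jnp.linalg.solve(cov_reg, leadfield)           # C^-1 L, (n_sensors, n_sources)
    denom = jnp.sum(leadfield * cinv_l, axis=0)             # l_k^T C^-1 l_k for each source
    return (cinv_l / denom).T

def source_power(cov, leadfield, reg=0.05):
    """Estimated power per source: diag(W C W^T)."""
    W = lcmv_weights(cov, leadfield, reg)
    return jnp.einsum("ks,st,kt->k", W, cov, W)

key = jax.random.PRNGKey(1)
k_lf, k_sig, k_noise = jax.random.split(key, 3)
n_sensors, n_sources, T = 32, 10, 2000
L = jax.random.normal(k_lf, (n_sensors, n_sources))       # assumed random leadfield
# Simulated scenario: only source 3 is active, plus white sensor noise.
s = jnp.zeros((T, n_sources)).at[:, 3].set(3.0 * jax.random.normal(k_sig, (T,)))
data = s @ L.T + jax.random.normal(k_noise, (T, n_sensors))
cov = data.T @ data / T

W = lcmv_weights(cov, L)
power = source_power(cov, L)
# Because every step is a JAX op, hyperparameters such as reg are differentiable too.
dpower_dreg = jax.grad(lambda r: source_power(cov, L, r).sum())(0.05)
```

Using `jnp.linalg.solve` rather than an explicit inverse keeps the computation stable and differentiable; Lineax offers the same kind of solve with a choice of solvers.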
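To show what a differentiable biophysical model looks like, the sketch below integrates a single-node reduced Wong-Wang model, $\dot S = -S/\tau_S + (1 - S)\,\gamma\,H(x)$ with $x = w J_N S + I_0$ and $H(x) = (a x - b)/(1 - e^{-d(a x - b)})$, and differentiates mean synaptic gating with respect to the recurrent weight `w`. The parameter values are illustrative assumptions in the spirit of the reduced Wong-Wang literature, not NeuroJAX defaults; forward Euler via `jax.lax.scan` stands in for a Diffrax solver, and there is no noise or inter-regional coupling.

```python
import jax
import jax.numpy as jnp

# Illustrative (assumed) parameters in s, nA and Hz.
PARAMS = dict(a=270.0, b=108.0, d=0.154, gamma=0.641, tau_s=0.1, J_N=0.2609, I_0=0.3)

def transfer(x, a, b, d):
    """Population firing rate H(x) in Hz for input current x in nA."""
    y = a * x - b
    return y / (1.0 - jnp.exp(-d * y))

def simulate_rww(w, S0=0.1, dt=1e-3, n_steps=2000, p=PARAMS):
    """Hypothetical single-node reduced Wong-Wang simulation (forward Euler)."""
    def step(S, _):
        x = w * p["J_N"] * S + p["I_0"]
        H = transfer(x, p["a"], p["b"], p["d"])
        dS = -S / p["tau_s"] + (1.0 - S) * p["gamma"] * H
        S_next = jnp.clip(S + dt * dS, 0.0, 1.0)   # gating variable stays in [0, 1]
        return S_next, S_next
    _, traj = jax.lax.scan(step, jnp.asarray(S0, dtype=jnp.float32), None, length=n_steps)
    return traj                                    # (n_steps,)

traj = simulate_rww(0.9)
# Sensitivity of mean gating to the recurrent weight, obtained by autodiff through the solver.
dmean_dw = jax.grad(lambda w: simulate_rww(w).mean())(0.9)
```

With these values the node relaxes to a low-activity fixed point, and the gradient is positive because stronger recurrence raises the input current and hence the firing rate.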
We recommend using uv for dependency management.
```bash
# Install dependencies and project
uv sync
```

```python
from neurojax.glm import GeneralLinearModel, run_permutation_test
# See examples/demo_glm.py for full usage
```