Predictability Enables Parallelization of Nonlinear State Space Models
Xavier Gonzalez*, Leo Kozachkov*, David M. Zoltowski, Kenneth L. Clarkson, Scott W. Linderman
Paper: https://www.alphaxiv.org/abs/2508.16817
Talk: https://www.youtube.com/watch?v=C9AqgW51-B4
This repository contains code for the paper "Predictability Enables Parallelization of Nonlinear State Space Models," published at NeurIPS 2025.
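In this paper, we consider the problem of parallelizing nonlinear state space models (SSMs) of the form
$$s_t = f_t(s_{t-1}), \qquad t = 1, \dots, T,$$
where s_t is the state at time t and each f_t is a nonlinear transition function.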
In our paper, we link the "predictability" of nonlinear SSMs, as measured by their largest Lyapunov exponent (LLE), to the conditioning of the optimization problem solved by parallel Newton methods. As our banner figure below shows, our key conclusion is:
Predictable SSMs are parallelizable. Unpredictable SSMs are not.
These results have important implications for the use of parallel Newton methods and for the design of nonlinear SSMs. Read our paper to learn more!
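Since predictability is measured by the LLE, it may help to see how an LLE can be estimated in practice. Below is a minimal JAX sketch of a standard approach: push a unit tangent vector through the Jacobians of the transition along a trajectory, renormalize it at every step, and average the log growth factors. It assumes an input-driven transition s_t = f(s_{t-1}, u_t); estimate_lle is a hypothetical helper for illustration, not the interface of src/lle.py. A negative estimate indicates predictable (contracting) dynamics, while a positive one indicates chaos.

```python
import jax
import jax.numpy as jnp


def estimate_lle(f, s0, inputs, key):
    """Estimate the largest Lyapunov exponent along the trajectory from s0.

    Hypothetical helper: f(s_prev, u_t) -> s_t is one step of the SSM and
    inputs has shape (T, ...). Returns the average log growth rate per step.
    """
    v0 = jax.random.normal(key, s0.shape)
    v0 = v0 / jnp.linalg.norm(v0)  # random unit tangent vector

    def step(carry, u_t):
        s_prev, v_prev = carry
        # One forward-mode pass gives s_t = f(s_prev, u_t) and J_t @ v_prev
        s_t, jv = jax.jvp(lambda s: f(s, u_t), (s_prev,), (v_prev,))
        growth = jnp.linalg.norm(jv)
        # Renormalize so the tangent vector neither overflows nor underflows
        return (s_t, jv / growth), jnp.log(growth)

    _, log_growths = jax.lax.scan(step, (s0, v0), inputs)
    return jnp.mean(log_growths)


# Sanity check on a scalar linear SSM s_t = 0.5 * s_{t-1} + u_t, whose LLE is log(0.5)
inputs = jax.random.normal(jax.random.PRNGKey(0), (100, 1))
lle = estimate_lle(lambda s, u: 0.5 * s + u, jnp.zeros(1), inputs, jax.random.PRNGKey(1))
print(lle)  # approximately -0.693
```

For a linear system the Jacobian is constant, so the estimate is exact; for a nonlinear SSM the average converges to the LLE as the trajectory length grows.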
The experiments in this codebase exist primarily to support the theory developed in our paper.
However, some features of the codebase may be more broadly useful:
- In src/deer.py, we provide a lightweight implementation of the DEER algorithm. This implementation exploits the causal structure of DEER to efficiently obtain the number of steps to convergence for a large number of sequence lengths T.
- In src/lle.py, we provide many helper functions for estimating the largest Lyapunov exponent (LLE) of an SSM. This code will be useful to anyone who wants to evaluate the predictability of their SSM.
- In src/examples, we provide an API for instantiating different SSMs so that they can all be evaluated easily with DEER. See the README in that directory.
- In particular, in src/examples/chaotic_flows.py, we port part of the dysts codebase into JAX.
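To illustrate the causal idea behind src/deer.py, here is a minimal, self-contained sketch of DEER viewed as a parallel Newton method. It is not the code in src/deer.py: the names combine, rollout_sequential, deer and steps_to_convergence_per_length, the tanh RNN used below, and the change-between-iterates stopping rule are all assumptions for illustration. Each iteration linearizes s_t = f(s_{t-1}, u_t) around the current guess and solves the resulting linear recurrence with jax.lax.associative_scan. Because the update for s_t depends only on the current guesses for s_0, ..., s_{t-1} and on inputs up to time t, the iterates restricted to the first T' steps coincide with those of a run on a length-T' sequence from the same initial guess, so a single run yields an iteration count for every prefix length.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose affine maps x -> A x + b, applying `earlier` first."""
    A1, b1 = earlier
    A2, b2 = later
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2


def rollout_sequential(f, s0, inputs):
    """Reference: evaluate s_t = f(s_{t-1}, u_t) one step at a time."""
    def step(s_prev, u_t):
        s_t = f(s_prev, u_t)
        return s_t, s_t
    return jax.lax.scan(step, s0, inputs)[1]


def deer(f, s0, inputs, num_iters):
    """Run num_iters Newton iterations; returns iterates of shape (num_iters + 1, T, d)."""
    T, d = inputs.shape[0], s0.shape[0]
    jac_f = jax.jacfwd(f, argnums=0)  # Jacobian of f w.r.t. the previous state

    @jax.jit
    def newton_step(states):
        prev = jnp.concatenate([s0[None], states[:-1]], axis=0)  # guesses for s_0..s_{T-1}
        fs = jax.vmap(f)(prev, inputs)
        Js = jax.vmap(jac_f)(prev, inputs)  # (T, d, d)
        bs = fs - jnp.einsum("tij,tj->ti", Js, prev)
        # Linearized recurrence s_t = J_t s_{t-1} + b_t, solved by a parallel prefix scan
        A_cum, b_cum = jax.lax.associative_scan(combine, (Js, bs))
        return jnp.einsum("tij,j->ti", A_cum, s0) + b_cum

    states = jnp.zeros((T, d))  # initial guess
    history = [states]
    for _ in range(num_iters):
        states = newton_step(states)
        history.append(states)
    return jnp.stack(history)


def steps_to_convergence_per_length(iterates, tol):
    """Entry T'-1: first iteration at which all of s_1..s_{T'} changed by < tol (-1 if never)."""
    diffs = jnp.max(jnp.abs(iterates[1:] - iterates[:-1]), axis=-1)  # (num_iters, T)
    prefix_err = jax.lax.cummax(diffs, axis=1)  # worst change among the first T' states
    converged = prefix_err < tol
    first = jnp.argmax(converged.astype(jnp.int32), axis=0) + 1  # index of first True
    return jnp.where(converged.any(axis=0), first, -1)


# Hypothetical example: a small-gain (contracting) tanh RNN, s_t = tanh(W s_{t-1} + u_t)
key_w, key_u = jax.random.split(jax.random.PRNGKey(0))
W = 0.15 * jax.random.normal(key_w, (4, 4))
f = lambda s, u: jnp.tanh(W @ s + u)
inputs = jax.random.normal(key_u, (200, 4))
s0 = jnp.zeros(4)

iterates = deer(f, s0, inputs, num_iters=30)
steps = steps_to_convergence_per_length(iterates, tol=1e-4)
print(steps[jnp.array([0, 9, 99, 199])])  # iterations needed for T' = 1, 10, 100, 200
```

Comparing iterates[-1] against rollout_sequential(f, s0, inputs) is a quick correctness check. Swapping in a transition with a positive LLE is one way to probe the unpredictable regime analyzed in the paper.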
Instructions for installing JAX: https://docs.jax.dev/en/latest/installation.html
- Use Python 3.12.1
- Use JAX 0.5.3
- Use a virtual environment
For CPU only:
pip install "jax==0.5.3"
For NVIDIA GPUs with CUDA 12:
pip install "jax[cuda12]==0.5.3"
After installing JAX appropriately for your hardware, run the following from the root of this repository:
pip install -e .
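To confirm that the environment picked up the intended JAX version and hardware, a quick check such as the following can be run. This is a minimal sketch; the version to compare against is the one recommended in the list above.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)  # this README recommends 0.5.3
print("Default backend:", jax.default_backend())  # e.g. "cpu" or "gpu"
print("Devices:", jax.devices())

# Tiny jitted computation as a smoke test: 0*0 + 1*1 + 2*2 + 3*3
x = jnp.arange(4.0)
y = jax.jit(lambda v: v @ v)(x)
print(y)  # 14.0
```

If a GPU was expected but the default backend reports cpu, the CUDA-enabled JAX wheel is most likely not the one installed in the active virtual environment.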
Please star this repo if you find our code interesting or useful!
@inproceedings{gonzalez2025predictability,
title={{Predictability Enables Parallelization of Nonlinear State Space Models}},
author={Gonzalez, Xavier and Kozachkov, Leo and Zoltowski, David M. and Clarkson, Kenneth L. and Linderman, Scott W.},
booktitle={Neural Information Processing Systems (NeurIPS)},
year={2025},
}

