This demo compares two implementations of the same stochastic SIR model using Poisson increments:
- Standard (NumPy): draws Poisson event counts directly with NumPy.
- THRML: draws the same event counts using THRML by representing each step’s infection and recovery counts as unary categorical nodes in a tiny graphical model and sampling them with THRML’s block Gibbs engine. Energies for those nodes are set to the negative log of the Poisson PMF, so the conditional distribution matches the desired Poisson law (truncated to feasible ranges). Deterministic SIR state updates follow from the sampled counts.
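The energy construction in the THRML bullet can be checked without THRML. The sketch below is plain Python and does not use THRML's API; the helper names `poisson_energies` and `categorical_from_energies`, and the values of `lam` and `k_max`, are hypothetical and chosen only for illustration. It builds energies equal to the negative log of the Poisson PMF, normalizes `exp(-E)` over `0..k_max`, and compares the result with the Poisson PMF renormalized over the same range.

```python
import math

def poisson_energies(lam, k_max):
    """Energies E(k) = -log Pois(k; lam) for k = 0..k_max (requires lam > 0)."""
    return [lam - k * math.log(lam) + math.lgamma(k + 1) for k in range(k_max + 1)]

def categorical_from_energies(energies):
    """Turn energies into probabilities proportional to exp(-E)."""
    e_min = min(energies)  # shift by the minimum for numerical stability
    weights = [math.exp(-(e - e_min)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]

lam, k_max = 2.5, 10  # illustrative values only
probs = categorical_from_energies(poisson_energies(lam, k_max))

# Reference: the Poisson PMF renormalized over the truncated range 0..k_max.
pmf = [math.exp(-lam) * lam**k / math.factorial(k) for k in range(k_max + 1)]
pmf_total = sum(pmf)
truncated = [p / pmf_total for p in pmf]

max_abs_err = max(abs(a - b) for a, b in zip(probs, truncated))
print(max_abs_err)  # expected to be at floating-point rounding level
```

Sampling a category from `probs` is then equivalent to drawing from the truncated Poisson law, which is the property the THRML path relies on.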
Both versions use the Poisson event formulation of SIR (conjugate with Gamma priors), with fixed parameters here:
- `beta = 3`
- `gamma = 1`
- `dt = 0.1`
This is a simulation demo. No parameter inference is performed.
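As a rough illustration of the Standard (NumPy) path, here is a minimal sketch of one trajectory. It is not the script's code: the function name `simulate_sir`, the initial value `I0=10`, the frequency-dependent infection rate `beta * S * I / N`, and the use of clipping to keep counts feasible are all assumptions made for this example.

```python
import numpy as np

def simulate_sir(N=1000, I0=10, R0=0, beta=3.0, gamma=1.0, dt=0.1,
                 steps=200, seed=0):
    """Hypothetical helper: one Poisson-increment SIR trajectory.

    Returns an integer array of shape (steps + 1, 3) with columns S, I, R.
    """
    rng = np.random.default_rng(seed)
    S, I, R = N - I0 - R0, I0, R0
    traj = np.empty((steps + 1, 3), dtype=np.int64)
    traj[0] = (S, I, R)
    for t in range(steps):
        # Assumed rates: frequency-dependent infection, per-capita recovery.
        # Clipping to S and I is a simplification of truncating the Poisson law.
        new_inf = min(int(rng.poisson(beta * S * I / N * dt)), S)
        new_rec = min(int(rng.poisson(gamma * I * dt)), I)
        # Deterministic state update from the sampled event counts.
        S, I, R = S - new_inf, I + new_inf - new_rec, R + new_rec
        traj[t + 1] = (S, I, R)
    return traj

traj = simulate_sir()
print(traj[-1])  # final (S, I, R)
```

The population is conserved at every step because each sampled event only moves individuals between compartments.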
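At step $t \rightarrow t+1$ with step size `dt`, the number of new infections and the number of new recoveries are each drawn from a Poisson distribution truncated to its feasible range, and the S, I, R counts are then updated deterministically from those draws.
In the THRML path, a unary categorical factor encodes energy $E(k) = -\log \mathrm{Poisson}(k; \lambda)$ for each candidate count $k$, where $\lambda$ is the Poisson mean for that event type.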
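For reference, one standard way to write such a step is sketched below. The frequency-dependent rate $\beta S_t I_t / N$, the truncation ranges, and the symbols $X_t$ (new infections), $Y_t$ (new recoveries), and $\lambda$ (Poisson mean) are assumptions for illustration and may differ from what the script uses. The last line expands the negative log of the Poisson PMF.

$$
\begin{aligned}
X_t &\sim \mathrm{Poisson}\!\left(\beta \frac{S_t I_t}{N}\,\Delta t\right) \text{ truncated to } \{0,\dots,S_t\}, \\
Y_t &\sim \mathrm{Poisson}\!\left(\gamma I_t\,\Delta t\right) \text{ truncated to } \{0,\dots,I_t\}, \\
S_{t+1} &= S_t - X_t, \qquad I_{t+1} = I_t + X_t - Y_t, \qquad R_{t+1} = R_t + Y_t, \\
E(k) &= -\log \mathrm{Poisson}(k;\lambda) = \lambda - k \log \lambda + \log k!
\end{aligned}
$$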
Install with:

    python -m pip install -r requirements.txt

- THRML requires Python ≥ 3.10 and JAX.
- Apple Silicon GPU (optional): install `jax-metal` and run with `--platform metal`. If the Metal backend is unavailable, JAX will fall back to CPU.
    python sir_thrml_demo.py --n-sims 20 --steps 200 --platform cpu
    # On Apple Silicon with jax-metal installed (although METAL is slower on my laptop...):
    python sir_thrml_demo.py --n-sims 20 --steps 200 --platform METAL

Other flags:
- `--N` population size (default 1000)
- `--I0`, `--R0` initial infectious and recovered
- `--beta`, `--gamma`, `--dt` model parameters
- `--steps` number of time steps
- `--n-sims` number of trajectories
- `--seed` RNG seed
- `--save-prefix` if set, saves png figures with this prefix
- `--platform` JAX backend selection: `auto`, `cpu`, `gpu`, or `metal`
The script prints the maximum differences between the ensemble means and variances of the two simulators, and produces these figures:
- Single‑trajectory S, I, R overlay (NumPy vs THRML).
- Ensemble mean and 95% band for $I(t)$ for both simulators.
- Histogram of final size $R(T)$ for both simulators.
Agreement up to Monte Carlo noise demonstrates stochastic equivalence of the two implementations under the same Poisson SIR model.
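A check of this kind can be sketched as follows. The helper `ensemble_gap` is hypothetical, and both inputs here are generated with NumPy as stand-ins for the two simulators' ensembles, so this only illustrates the comparison itself, not the script's actual output.

```python
import numpy as np

def ensemble_gap(a, b):
    """Max absolute difference in per-time-step mean and variance.

    a, b: arrays of shape (n_sims, n_times), one row per trajectory.
    """
    mean_gap = np.max(np.abs(a.mean(axis=0) - b.mean(axis=0)))
    var_gap = np.max(np.abs(a.var(axis=0) - b.var(axis=0)))
    return mean_gap, var_gap

# Stand-in data: two independent ensembles drawn from the same Poisson law.
rng_a = np.random.default_rng(1)
rng_b = np.random.default_rng(2)
lam = np.linspace(1.0, 20.0, 50)          # illustrative per-time-step means
a = rng_a.poisson(lam, size=(2000, 50))
b = rng_b.poisson(lam, size=(2000, 50))

mean_gap, var_gap = ensemble_gap(a, b)
print(mean_gap, var_gap)  # small but nonzero: Monte Carlo noise only
```

Because both ensembles follow the same distribution, the gaps shrink as the number of trajectories grows; a gap that does not shrink would point to a real difference between the samplers.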
- THRML docs and quick start.
- Discrete‑time Poisson SIR formulation (Gamma–Poisson conjugacy is standard).
The THRML implementation is fascinating conceptually but orders of magnitude slower than NumPy for this use case:
- NumPy: `rng.poisson(lam)` → direct native call (~microseconds)
- THRML: creates a full probabilistic graphical model, runs a block Gibbs sampler, and compiles a JAX computation graph → ~170 ms per sample
- NumPy: O(1) per Poisson sample
- THRML: O(k_max) to build categorical distribution + Gibbs overhead + compilation
- 67% of total time spent in compilation
- JAX recompiles for each unique `k_max` value, which changes with each time step
- Different `S_t` and `I_t` values → different truncation limits → different computation graphs
- No reuse of compiled functions, even with caching (tried and removed), because exactly repeated states are rare.
Each Poisson sample creates 8+ new THRML objects: `CategoricalNode()`, `Block()`, `BlockGibbsSpec()`, `CategoricalEBMFactor()`, `FactorSamplingProgram()`, etc.
- For 1 simulation × 200 steps × 2 samples per step = 400 calls
- Each call recreates the entire THRML sampling infrastructure
- In wall-clock time this is almost negligible, but it shows how awkward it is to reshape a model whose state is indexed by time into an energy-based model.
- NumPy: ~2ms total for full set of simulations
- THRML: ~69 seconds for same
- ~34,500x slower for equivalent mathematical operations
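The figures above can be cross-checked with a few lines of arithmetic. This sketch assumes the timing run used a single simulation, as the 400-call count earlier suggests; the variable names are illustrative.

```python
numpy_total_s = 2e-3        # ~2 ms total for NumPy
thrml_total_s = 69.0        # ~69 s total for THRML
per_sample_s = 0.170        # ~170 ms per THRML sample

# Ratio of total run times.
slowdown = thrml_total_s / numpy_total_s

# 1 simulation x 200 steps x 2 Poisson samples per step.
calls = 1 * 200 * 2

# Total THRML time implied by the per-sample cost.
implied_total_s = calls * per_sample_s

print(round(slowdown), calls, round(implied_total_s, 1))
```

The ratio matches the quoted ~34,500x, and 400 calls at ~170 ms each come to roughly 68 s, in line with the ~69 s measured.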
The THRML approach is a fascinating proof-of-concept showing one can represent Poisson sampling as categorical factors, but it's orders of magnitude slower than direct sampling for this use case.