
famulare/tsu-sir


Stochastic SIR: NumPy vs THRML

This demo compares two implementations of the same stochastic SIR model using Poisson increments:

  1. Standard (NumPy): draws Poisson event counts directly with NumPy.
  2. THRML: draws the same event counts by representing each step’s infection and recovery counts as unary categorical nodes in a tiny graphical model and sampling them with THRML’s block Gibbs engine. The energy of each node is set to the negative log of the Poisson PMF, so the conditional distribution matches the desired Poisson law (truncated to feasible ranges). The SIR state is then updated deterministically from the sampled counts.

Both versions use the Poisson event formulation of SIR (whose Poisson likelihood is conjugate to Gamma priors on beta and gamma), with the following fixed parameters:

  • beta = 3
  • gamma = 1
  • dt = 0.1
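As a sanity check on these values, the basic reproduction number is $\mathcal{R}_0=\beta/\gamma=3$ (not to be confused with the `--R0` flag, which sets the initially recovered count). In the deterministic, large-$N$ limit, and assuming nearly everyone starts susceptible, the final epidemic fraction $r$ of a major outbreak solves $r = 1 - e^{-\mathcal{R}_0 r}$. The sketch below solves this by fixed-point iteration; it is only an approximate reference for the $R(T)/N$ histogram, not something the demo computes, and finite $N$ and the $\Delta t = 0.1$ time step may shift the stochastic results.

```python
import math

def final_size_fraction(beta: float, gamma: float, tol: float = 1e-12, max_iter: int = 10_000) -> float:
    """Solve r = 1 - exp(-R0 * r) for the nontrivial root by fixed-point iteration."""
    r0 = beta / gamma      # basic reproduction number
    r = 1.0                # start at the upper end so iteration converges to the nontrivial root
    for _ in range(max_iter):
        r_next = 1.0 - math.exp(-r0 * r)
        if abs(r_next - r) < tol:
            return r_next
        r = r_next
    return r

beta, gamma = 3.0, 1.0
r = final_size_fraction(beta, gamma)
print(f"R0 = {beta / gamma:.1f}, final-size fraction ≈ {r:.4f}")  # R0 = 3.0, final-size fraction ≈ 0.9405
```

With $\mathcal{R}_0 = 3$, roughly 94% of the population is expected to be infected in a major outbreak; trajectories that die out early will sit near $R(T)\approx 0$ instead.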

This is a simulation demo. No parameter inference is performed.
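For background only, since the demo does no inference: the Gamma–Poisson conjugacy noted above works as follows. With shape-rate priors $\beta\sim\text{Gamma}(a_\beta,b_\beta)$ and $\gamma\sim\text{Gamma}(a_\gamma,b_\gamma)$ (the hyperparameter names are illustrative, not from the repo), and ignoring the truncation at $k_\text{max}$ (whose normalizing constant depends on the rates and would break exact conjugacy), the posteriors after observing $X_t, Y_t$ for $t=0,\dots,T-1$ are:

$$ \beta \mid X \sim \text{Gamma}\!\left(a_\beta + \sum_{t=0}^{T-1} X_t,\; b_\beta + \sum_{t=0}^{T-1} \frac{S_t I_t}{N}\,\Delta t\right),\qquad \gamma \mid Y \sim \text{Gamma}\!\left(a_\gamma + \sum_{t=0}^{T-1} Y_t,\; b_\gamma + \sum_{t=0}^{T-1} I_t\,\Delta t\right). $$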

Model

At step $t \rightarrow t+1$ with step size $\Delta t$:

$$ X_t \sim \text{Poisson}\!\left(\beta\,\frac{S_t I_t}{N}\,\Delta t\right),\qquad Y_t \sim \text{Poisson}\!\left(\gamma\, I_t\,\Delta t\right), $$

and

$$ S_{t+1}=S_t-X_t,\quad I_{t+1}=I_t+X_t-Y_t,\quad R_{t+1}=R_t+Y_t. $$
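The update above vectorizes naturally across trajectories. The following is a minimal NumPy sketch of the standard path, not the repository's code: the function name, the default `I0=10`, and clipping each draw to its feasible range (the README states truncation only for the THRML path) are assumptions. Clipping differs from truncation in that it piles the excess probability onto the boundary value, which matters only when $\lambda_t$ is comparable to $S_t$ or $I_t$.

```python
import numpy as np

def simulate_numpy(n_sims=20, steps=200, N=1000, I0=10, R0=0,
                   beta=3.0, gamma=1.0, dt=0.1, seed=0):
    """Hypothetical vectorized Poisson-increment SIR; returns shape (steps+1, 3, n_sims)."""
    rng = np.random.default_rng(seed)
    S = np.full(n_sims, N - I0 - R0, dtype=np.int64)
    I = np.full(n_sims, I0, dtype=np.int64)
    R = np.full(n_sims, R0, dtype=np.int64)
    out = np.empty((steps + 1, 3, n_sims), dtype=np.int64)
    out[0] = (S, I, R)
    for t in range(steps):
        # Poisson event counts for every trajectory at once
        X = rng.poisson(beta * S * I / N * dt)   # new infections
        Y = rng.poisson(gamma * I * dt)          # recoveries
        # Assumed feasibility rule: clip so S and I never go negative
        X = np.minimum(X, S)
        Y = np.minimum(Y, I)
        S, I, R = S - X, I + X - Y, R + Y
        out[t + 1] = (S, I, R)
    return out

traj = simulate_numpy()
I_t = traj[:, 1, :]   # infectious counts, shape (201, 20)
```

Each step is a handful of array operations over all trajectories, which is why this path runs in milliseconds.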

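In the THRML path, a unary categorical factor encodes the energy $E(k)=-\log p_\text{Pois}(k;\lambda_t)$ for counts $k\in\{0,\dots,k_\text{max}\}$, where $\lambda_t$ is the corresponding Poisson rate above, $k_\text{max}=S_t$ for infections and $k_\text{max}=I_t$ for recoveries. Setting $E(k)=+\infty$ for $k>k_\text{max}$ enforces feasibility ($X_t\le S_t,\ Y_t\le I_t$). THRML’s Gibbs sampler then draws from $\propto \exp(-E(k))$, which is the Poisson PMF renormalized over $\{0,\dots,k_\text{max}\}$. Because each node carries only a unary factor, its Gibbs conditional equals its marginal, so each update is an exact draw from this truncated law.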
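The energy construction can be checked without THRML. This NumPy sketch (helper names are hypothetical, and it does not call THRML's API) builds $E(k)$ over $\{0,\dots,k_\text{max}\}$, converts it to probabilities $\propto e^{-E(k)}$, and draws one count, which is what an update of an isolated categorical node amounts to.

```python
import numpy as np

def poisson_energies(lam: float, k_max: int) -> np.ndarray:
    """E(k) = -log Poisson(k; lam) for k = 0..k_max (hypothetical helper)."""
    k = np.arange(k_max + 1)
    if lam == 0.0:
        # Degenerate rate: all mass on k = 0
        return np.where(k == 0, 0.0, np.inf)
    # log k! via cumulative sums of log(1..k); log 0! = 0
    log_fact = np.concatenate(([0.0], np.cumsum(np.log(np.arange(1, k_max + 1)))))
    return -(k * np.log(lam) - lam - log_fact)

def probs_from_energies(E: np.ndarray) -> np.ndarray:
    """Boltzmann weights exp(-E), normalized; +inf energies get probability 0."""
    logits = -E
    w = np.exp(logits - logits.max())   # subtract max for numerical stability
    return w / w.sum()

def sample_count(lam: float, k_max: int, rng: np.random.Generator) -> int:
    """Draw one count from the truncated Poisson defined by the energies."""
    p = probs_from_energies(poisson_energies(lam, k_max))
    return int(rng.choice(len(p), p=p))

rng = np.random.default_rng(0)
S_t, I_t, N, beta, dt = 990, 10, 1000, 3.0, 0.1
x = sample_count(beta * S_t * I_t / N * dt, k_max=S_t, rng=rng)   # X_t <= S_t by construction
```

Building the energy vector costs $O(k_\text{max})$ per draw, in line with the complexity noted in the performance discussion.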

Requirements

Install with:

```bash
python -m pip install -r requirements.txt
```

  • THRML requires Python ≥ 3.10 and JAX.
  • Apple Silicon GPU (optional): install jax-metal and run with --platform metal. If the Metal backend is unavailable, JAX will fall back to CPU.

Run

```bash
python sir_thrml_demo.py --n-sims 20 --steps 200 --platform cpu
# On Apple Silicon with jax-metal installed (although Metal is slower on my laptop):
python sir_thrml_demo.py --n-sims 20 --steps 200 --platform metal
```

Other flags:

  • --N population size (default 1000)
  • --I0, --R0 initial infectious and recovered
  • --beta, --gamma, --dt model parameters
  • --steps number of time steps
  • --n-sims number of trajectories
  • --seed RNG seed
  • --save-prefix if set, saves png figures with this prefix
  • --platform JAX backend selection: auto, cpu, gpu, or metal
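A minimal `argparse` sketch of the command-line interface listed above, useful when adapting the script. Only the flag names, the `--platform` choices, and the `--N` default come from this README; the other defaults (the parameter values above, plus `I0=10`, `R0=0`, `steps=200`, `n-sims=20`, `seed=0`, `platform="auto"`) are assumptions, and the real script may differ.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction; defaults other than --N are assumptions.
    p = argparse.ArgumentParser(description="Stochastic SIR: NumPy vs THRML")
    p.add_argument("--N", type=int, default=1000, help="population size")
    p.add_argument("--I0", type=int, default=10, help="initial infectious")
    p.add_argument("--R0", type=int, default=0, help="initial recovered")
    p.add_argument("--beta", type=float, default=3.0, help="transmission rate")
    p.add_argument("--gamma", type=float, default=1.0, help="recovery rate")
    p.add_argument("--dt", type=float, default=0.1, help="time step")
    p.add_argument("--steps", type=int, default=200, help="number of time steps")
    p.add_argument("--n-sims", type=int, default=20, help="number of trajectories")
    p.add_argument("--seed", type=int, default=0, help="RNG seed")
    p.add_argument("--save-prefix", default=None, help="if set, save PNG figures with this prefix")
    p.add_argument("--platform", choices=["auto", "cpu", "gpu", "metal"], default="auto",
                   help="JAX backend selection")
    return p

args = build_parser().parse_args(["--n-sims", "20", "--steps", "200", "--platform", "cpu"])
print(args.n_sims, args.steps, args.platform)  # 20 200 cpu
```

With `choices` as written, `--platform METAL` would be rejected; adding `type=str.lower` to that argument would make it case-insensitive, since `argparse` applies `type` before checking `choices`.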

Outputs

The script prints the maximum differences between the ensemble means and variances of $I(t)$ from the two simulators, and shows three figures:

  1. Single‑trajectory S, I, R overlay (NumPy vs THRML).
  2. Ensemble mean and 95% band for $I(t)$ for both simulators.
  3. Histogram of final size $R(T)$ for both simulators.

Agreement within Monte Carlo noise is evidence of the stochastic equivalence of the two implementations under the same Poisson SIR model.
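The printed comparison can be sketched as follows. The exact statistics in the script are not specified here, so the percentile-based 95% band and the standard-error yardstick are assumptions, and the input arrays are placeholders rather than SIR output. Each input holds $I(t)$ with shape `(n_sims, steps + 1)`. As a rough guide, a gap between the two ensemble means at a given $t$ is unremarkable if it is within about two standard errors, $2\sqrt{s_a^2/n_a + s_b^2/n_b}$; the maximum over all time points will exceed this more often.

```python
import numpy as np

def compare_ensembles(I_a: np.ndarray, I_b: np.ndarray):
    """Max |mean diff| and max |variance diff| of I(t) over time, plus a noise yardstick."""
    mean_a, mean_b = I_a.mean(axis=0), I_b.mean(axis=0)
    var_a, var_b = I_a.var(axis=0, ddof=1), I_b.var(axis=0, ddof=1)
    max_mean_diff = np.max(np.abs(mean_a - mean_b))
    max_var_diff = np.max(np.abs(var_a - var_b))
    # Pointwise standard error of the difference in means
    se_diff = np.sqrt(var_a / I_a.shape[0] + var_b / I_b.shape[0])
    return max_mean_diff, max_var_diff, se_diff

def band95(I: np.ndarray):
    """Ensemble mean and central 95% percentile band at each time step."""
    lo, hi = np.percentile(I, [2.5, 97.5], axis=0)
    return I.mean(axis=0), lo, hi

rng = np.random.default_rng(1)
I_a = rng.poisson(50, size=(20, 201))   # placeholder ensembles, not SIR output
I_b = rng.poisson(50, size=(20, 201))
dm, dv, se = compare_ensembles(I_a, I_b)
mean, lo, hi = band95(I_a)
```

With only 20 trajectories the standard errors are large, so modest differences in the printed maxima are expected even when both simulators sample the same process.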

References

  • THRML docs and quick start.
  • Discrete‑time Poisson SIR formulation (Gamma–Poisson conjugacy is standard).

Performance Discussion

Why sir_thrml is Extremely Slow

The THRML implementation is fascinating conceptually but orders of magnitude slower than NumPy for this use case:

1. Overkill Architecture

  • NumPy: rng.poisson(lam) → direct native call (~microseconds)
  • THRML: builds a full probabilistic graphical model, runs the block Gibbs sampler, and compiles a JAX computation graph → ~170 ms per sample

2. Algorithmic Complexity

  • NumPy: O(1) per Poisson sample
  • THRML: O(k_max) to build categorical distribution + Gibbs overhead + compilation

3. JAX Compilation Hell

  • 67% of total time spent in compilation
  • JAX recompiles for each unique k_max value, which changes with each time step
  • Different S_t and I_t values → different truncation limits → different computation graphs
  • Caching compiled functions was tried and removed: it gave no reuse, because exactly repeated states are rare.
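The recompilation follows from array shapes: a categorical node with $k_\text{max}+1$ states gives a differently shaped computation whenever $S_t$ or $I_t$ changes, and JAX traces and compiles anew for each new input shape. One common remedy, not tried in this repository, is to give every node a fixed $N+1$ categories and mask infeasible counts with $E(k)=+\infty$, so every step shares one shape. A NumPy sketch of such padded energies (the function name is hypothetical):

```python
import numpy as np

def padded_poisson_energies(lam: float, k_max: int, n_categories: int) -> np.ndarray:
    """E(k) = -log Poisson(k; lam) over a FIXED range 0..n_categories-1,
    with E(k) = +inf for infeasible counts k > k_max."""
    k = np.arange(n_categories)
    E = np.full(n_categories, np.inf)
    if lam == 0.0:
        E[0] = 0.0                      # all mass on k = 0
        return E
    log_fact = np.concatenate(([0.0], np.cumsum(np.log(np.arange(1, n_categories)))))
    feasible = k <= k_max
    E[feasible] = -(k[feasible] * np.log(lam) - lam - log_fact[feasible])
    return E

N = 1000
# Different S_t values (hence different k_max) now all produce the same shape
shapes = {padded_poisson_energies(3.0 * S * 10 / N * 0.1, S, N + 1).shape
          for S in (990, 985, 970)}
print(shapes)   # {(1001,)}
```

The trade-off is $O(N)$ work per node instead of $O(k_\text{max})$, which is typically cheap compared with a recompilation; whether THRML's own compiled programs could then be reused across steps is untested here.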

4. Massive Object Creation Overhead

Each Poisson sample creates 8+ new THRML objects:

  • CategoricalNode(), Block(), BlockGibbsSpec(), CategoricalEBMFactor(), FactorSamplingProgram(), etc.
  • For 1 simulation × 200 steps × 2 samples per step = 400 calls (≈68 s at ~170 ms per call)
  • Each call recreates the entire THRML sampling infrastructure
  • In wall-clock time this object creation is almost negligible, but it shows how unnatural it is to reshape a model whose state space is indexed by time into an energy-based model.

Performance Ratio

  • NumPy: ~2ms total for full set of simulations
  • THRML: ~69 seconds for same
  • ~34,500x slower for equivalent mathematical operations

The THRML approach is a fascinating proof-of-concept showing one can represent Poisson sampling as categorical factors, but it's orders of magnitude slower than direct sampling for this use case.

About

Proof-of-principle implementation of an infectious disease SIR model using Extropic's THRML library
