
Guidance on scaling to ~40-100 dimensional problems — JIT compilation time #240

@tomkimpson


Hi Joshua,

We're using jaxns for nested sampling in a pulsar timing array (PTA) continuous-wave gravitational-wave search. The likelihood is a Kalman filter implemented in JAX and evaluates in ~0.02 s on GPU after JIT compilation.
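
For concreteness, here is a minimal sketch of the general shape of a lax.scan-based Kalman filter log-likelihood. It assumes a hypothetical scalar random-walk state with process variance q, measurement variance r and initial state N(m0, p0); the name kalman_loglike and the model itself are illustrative, not the PTA model. The scan body is traced once regardless of the number of time steps, and the log-likelihood is accumulated through the prediction-error decomposition.

```python
import jax
import jax.numpy as jnp


@jax.jit
def kalman_loglike(ys, q, r, m0, p0):
    """Log-likelihood of ys under a hypothetical scalar random walk:
    x_t = x_{t-1} + w_t, w_t ~ N(0, q);  y_t = x_t + v_t, v_t ~ N(0, r);  x_0 ~ N(m0, p0).
    """
    dtype = ys.dtype
    q = jnp.asarray(q, dtype)
    r = jnp.asarray(r, dtype)

    def step(carry, y):
        m, p = carry
        # Predict: the random-walk transition keeps the mean and adds q to the variance.
        p_pred = p + q
        # Innovation (prediction error) and its variance.
        v = y - m
        s = p_pred + r
        ll_t = -0.5 * (jnp.log(2.0 * jnp.pi * s) + v**2 / s)
        # Update with the Kalman gain.
        k = p_pred / s
        return (m + k * v, (1.0 - k) * p_pred), ll_t

    init = (jnp.asarray(m0, dtype), jnp.asarray(p0, dtype))
    # The body is traced once; the loop over ys runs inside the compiled program.
    _, ll_terms = jax.lax.scan(step, init, ys)
    return jnp.sum(ll_terms)


print(kalman_loglike(jnp.array([0.3, -0.1, 0.5, 0.2, 0.9]), 0.2, 0.5, 0.1, 1.0))
```

Because this model is linear and Gaussian, the result can be cross-checked against the exact multivariate normal log-density of ys, with covariance p0 + q·min(s, t) + r·δ_st.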

What works well:

With 2 free parameters, jaxns runs end-to-end in ~30 minutes: roughly 25 min of JIT compilation and ~3 min of actual sampling (130k likelihood evaluations, 2500 samples, ESS = 314). Evidence computation works perfectly. We're using num_live_points=200, s=3, k=0, max_samples=5000.

Where we struggle:

Our full model has ~107 free parameters. With this dimensionality, jaxns never gets past JIT compilation — we've waited 7.5 hours on a P100 GPU with no progress beyond "Number of Markov-chains set to: 200".

We believe the bottleneck is tracing and XLA compilation of the full nested sampling loop, including our likelihood, rather than the sampling itself. The likelihood's computation graph is moderately complex (a Kalman filter built on lax.scan), and we suspect that combining it with 107-dimensional slice sampling produces a very large XLA program.
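
One way to check where the time actually goes is to time JAX's ahead-of-time stages separately: jax.jit(f).lower(...) covers Python tracing plus lowering, and .compile() covers XLA compilation. The helper below (time_compile_stages and toy_loglike are hypothetical names) is a sketch that assumes f is jittable with the given example arguments; the length of the lowered module text is only a rough proxy for program size.

```python
import time

import jax
import jax.numpy as jnp


def time_compile_stages(f, *example_args):
    """Time tracing+lowering, XLA compilation and one execution of f separately."""
    t0 = time.perf_counter()
    lowered = jax.jit(f).lower(*example_args)  # Python tracing + lowering to StableHLO
    t1 = time.perf_counter()
    compiled = lowered.compile()  # XLA optimisation and code generation
    t2 = time.perf_counter()
    jax.block_until_ready(compiled(*example_args))  # one run of the compiled executable
    t3 = time.perf_counter()
    return {
        "lower_s": t1 - t0,
        "compile_s": t2 - t1,
        "run_s": t3 - t2,
        # Size of the lowered module text: a rough proxy for program size.
        "hlo_chars": len(lowered.as_text()),
    }


def toy_loglike(theta):
    # Hypothetical stand-in for a likelihood over len(theta) parameters.
    return -0.5 * jnp.sum(theta**2)


for dim in (2, 107):
    print(dim, time_compile_stages(toy_loglike, jnp.ones(dim)))
```

Pointing the helper at the likelihood alone with 2 and with 107 parameters would help separate the cost of the likelihood graph from that of the surrounding sampler loop. Setting jax.config.update("jax_log_compiles", True) also logs each compilation as it happens, which shows whether anything is being recompiled repeatedly.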

Questions:

  1. What dimensionality range is jaxns designed to handle efficiently? Is ~100 dims realistic? 50 dims? Or is ~10-20 more typical?
  2. Is the compilation time expected to scale steeply with the number of parameters? Any strategies to reduce it (e.g., wrapping the likelihood to limit the trace depth)?
  3. Would adjusting s, k, or num_live_points help with compilation time, or is it purely a function of the model graph size?
  4. Are there any best practices for using jaxns with complex lax.scan-based likelihoods?
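
Related to questions 2 and 4, a cheap way to see whether a function's traced program grows with problem size is to count the equations in its jaxpr. The sketch below uses toy functions (all names are hypothetical) to contrast a Python for loop, which is unrolled into one copy of the body per iteration, with lax.scan, whose body is traced once; the same contrast holds for a Python loop over parameters versus a vectorised expression. It cannot tell which pattern, if either, is present in our likelihood or inside jaxns.

```python
import jax
import jax.numpy as jnp


def n_eqns(f, *args):
    """Number of top-level equations in the jaxpr JAX traces for f(*args)."""
    return len(jax.make_jaxpr(f)(*args).jaxpr.eqns)


def step(x, y):
    # Toy recursive update standing in for one filter step.
    x = 0.9 * x + 0.1 * y
    return x, x * x


def unrolled(ys):
    # Python loop: the body is traced again for every time step.
    x, total = jnp.zeros(()), jnp.zeros(())
    for i in range(ys.shape[0]):
        x, e = step(x, ys[i])
        total = total + e
    return total


def scanned(ys):
    # lax.scan: the body is traced once, whatever the number of steps.
    _, es = jax.lax.scan(step, jnp.zeros(()), ys)
    return jnp.sum(es)


def loop_over_params(theta):
    # Python loop over parameters: one copy of the ops per parameter.
    total = 0.0
    for j in range(theta.shape[0]):
        total = total + jnp.sin(theta[j]) ** 2
    return total


def vectorised(theta):
    # Same computation expressed on the whole array at once.
    return jnp.sum(jnp.sin(theta) ** 2)


for n in (10, 100):
    print("steps", n, "unrolled:", n_eqns(unrolled, jnp.ones(n)), "scan:", n_eqns(scanned, jnp.ones(n)))
for d in (2, 107):
    print("params", d, "loop:", n_eqns(loop_over_params, jnp.ones(d)), "vectorised:", n_eqns(vectorised, jnp.ones(d)))
```

Equation counts say nothing about how long XLA spends on each operation, but they quickly flag code whose traced program scales with the number of time steps or parameters.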

Thanks for a great library — the 2-dim case works beautifully.

Labels: bug (Something isn't working)
