
Conversation

@jessegrabowski (Member) commented Jan 10, 2026

Description

I've always liked SMC as a gradient-free option for my big, silly models with few parameters, but it has always given me trouble because of the API break between it and pm.sample. This PR aims to harmonize the two by bringing a bunch of functionality over from pm.sample to pm.sample_smc.

This PR is intended to be reviewed commit by commit. I verified that the test suite runs in all intermediate forms. Here is a summary of each commit:

  • Use multiprocessing for SMC sampling: The multiprocessing library is now used to handle parallel SMC sampling. This commit was heavily Claude-assisted, so it should receive special scrutiny. The objective was to make SMC multiprocessing look exactly like MCMC multiprocessing. It also exposes an mp_ctx argument to pm.sample_smc, which allows compiling with e.g. JAX (using mp_ctx='forkserver'); see the usage sketch after this list.
  • Sample SMC sequentially when cores=1 adds separate logic for sequential sampling on one core. Again, this copies the relevant MCMC functions.
  • Initialize SMC Kernels on main process is a major performance change, intended to address e.g. BUG: sample_smc stalls on final stage #8030. PyTensor compilation is not thread-safe, so we shouldn't be doing it on the workers. In this PR, the kernel is compiled once on the main process, then serialized and sent to the workers. This matches what we do with step functions in MCMC. Importantly, this commit eliminates the need to serialize many auxiliary objects, including the pymc model itself, and removes some special logic for custom distributions. To do this, a couple of ancillary changes had to be made: for example, transformation of the chains from numpy to NDArray objects now happens on the main process, after all sampling is done.
  • Add blas_cores argument to sample_smc: again, this copies multiprocessing machinery over from pm.sample by adding a blas_cores argument to pm.sample_smc, for the same reasons it exists over there.
  • Add custom progress bar for SMC adds a progress bar style to sample_smc that matches that of pm.sample. The bars fill from 0 to 1 following the value of beta, and we provide an estimated time to completion by measuring the speed per stage. It looks like this:
 Progress                                   Stage   Beta     Stage Speed    Elapsed   Remaining 
────────────────────────────────────────────────────────────────────────────────────────────────
 ━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0620   4.66 s/stage   0:00:17   0:03:46   
 ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0634   5.42 s/stage   0:00:17   0:04:18   
 ━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   2       0.0269   6.15 s/stage   0:00:17   0:09:29   
 ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0646   5.38 s/stage   0:00:17   0:04:10   
 ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0658   5.16 s/stage   0:00:17   0:03:55   
 ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0658   5.55 s/stage   0:00:17   0:04:12   
 ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3       0.0639   4.83 s/stage   0:00:17   0:03:46   
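
For reference, a minimal usage sketch of the harmonized API. The argument names are taken from the PR description and comments (mp_ctx, compile_kwargs); the model here is an arbitrary stand-in:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    pm.Normal("y", mu=x, sigma=1, observed=[0.1, -0.3, 0.2])

    # 'forkserver' is the context the description suggests when compiling
    # to a non-default backend such as JAX
    idata = pm.sample_smc(
        chains=4,
        cores=4,
        mp_ctx="forkserver",
        compile_kwargs={"mode": "JAX"},
    )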

I observed big speed gains using sample_smc after this PR. I timed this simple hierarchical model:

import pymc as pm

with pm.Model() as m:
    # 100 observations randomly assigned to 10 groups
    idxs = pm.draw(pm.Categorical.dist(p=[1] * 10, shape=(100,)))
    effect_loc = pm.Normal('mu_loc', 0, 1)
    effect_scale = pm.HalfNormal('mu_scale', 1)
    effect = pm.Normal('mu', mu=effect_loc, sigma=effect_scale, shape=(10,))

    # 5 regression features
    X = pm.draw(pm.Normal.dist(0, 1, shape=(100, 5)))
    beta = pm.Normal('beta', 0, 1, shape=(5,))

    mu = effect[idxs] + X @ beta
    sigma = pm.Exponential('sigma', 1)

    obs = pm.Normal('obs', mu=mu, sigma=sigma)
    prior = pm.sample_prior_predictive()

# Condition the model on a single prior predictive draw
draw = prior.prior.obs.sel(chain=0, draw=123).values
m2 = pm.observe(m, {obs: draw})

%%time 
with m2:
    idata = pm.sample_smc()

Timings went from 6.1 s to 1.44 s using the C backend, and from 1.46 s to 1.09 s using Numba mode (with cache). Running test_smc.py locally goes from 64 s to 6.264 s.

I could run more formal benchmarks if someone asks, but I don't really want to.


@jessegrabowski (Member Author)

@tvwenger it would be nice if you could try the SMC models that have been giving you trouble on this PR branch and report back, since you've been the one doing the heavy lifting on SMC bug-hunting lately.

Copilot AI left a comment

Pull request overview

This PR refactors SMC (Sequential Monte Carlo) sampling to harmonize its API with pm.sample, bringing multiprocessing capabilities, progress bars, and performance improvements to pm.sample_smc. The refactor moves PyTensor compilation to the main process before distributing work to child processes, addresses thread safety concerns, and adds comprehensive progress reporting similar to MCMC sampling.

Changes:

  • Implemented multiprocessing support for SMC using a pattern similar to MCMC parallel sampling
  • Added custom SMC progress bars that track beta (inverse temperature) progression from 0 to 1
  • Moved kernel compilation to main process to avoid thread-safety issues with PyTensor compilation in worker processes
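
Schematically, that compile-once-then-ship pattern looks something like this (a hedged sketch; the actual class and function names in pymc/smc/parallel.py may differ):

import multiprocessing as mp
import cloudpickle  # PyMC uses cloudpickle for this kind of serialization


def run_chain(pickled_kernel: bytes, chain: int, conn) -> None:
    # Worker: only unpickles an already-compiled kernel. No PyTensor
    # compilation happens here.
    kernel = cloudpickle.loads(pickled_kernel)
    ...  # run SMC stages, reporting progress/results back through conn


def sample_parallel(kernel, chains: int, mp_ctx: str = "spawn"):
    # Main process: compile once, serialize once, ship to every worker.
    payload = cloudpickle.dumps(kernel)
    ctx = mp.get_context(mp_ctx)
    procs = []
    for chain in range(chains):
        parent_conn, child_conn = ctx.Pipe()
        proc = ctx.Process(target=run_chain, args=(payload, chain, child_conn))
        proc.start()
        procs.append((proc, parent_conn))
    return procs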

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

Summary per file:
  • pymc/smc/parallel.py: New file implementing parallel SMC sampling infrastructure with process management, message passing, and result collection
  • pymc/smc/sampling.py: Major refactor of the main SMC sampling function to support both parallel and sequential execution with shared kernel compilation
  • pymc/smc/kernels.py: Moved kernel compilation to __init__ and added progress bar configuration methods
  • pymc/progress_bar.py: Added SMCProgressBarManager class for beta-based progress tracking and modified table styling
  • tests/smc/test_smc.py: Added test for sequential sampling with cores=1


codecov bot commented Jan 10, 2026

Codecov Report

❌ Patch coverage is 82.78146% with 78 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.43%. Comparing base (2afaa49) to head (56cc72d).

Files with missing lines          Patch %   Lines missing
pymc/smc/parallel.py              66.66%    68 ⚠️
pymc/smc/sampling.py              95.23%     4 ⚠️
pymc/sampling/mcmc.py             90.47%     2 ⚠️
pymc/smc/kernels.py               97.50%     2 ⚠️
pymc/progress_bar/progress.py     98.00%     1 ⚠️
pymc/sampling/parallel.py         90.90%     1 ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #8047      +/-   ##
==========================================
+ Coverage   90.79%   91.43%   +0.64%     
==========================================
  Files         121      122       +1     
  Lines       19408    19695     +287     
==========================================
+ Hits        17621    18008     +387     
+ Misses       1787     1687     -100     
Files with missing lines          Coverage Δ
pymc/progress_bar/__init__.py     100.00% <ø> (ø)
pymc/progress_bar/utils.py        100.00% <100.00%> (ø)
pymc/progress_bar/progress.py     96.50% <98.00%> (+0.71%) ⬆️
pymc/sampling/parallel.py         63.54% <90.90%> (+0.59%) ⬆️
pymc/sampling/mcmc.py             91.23% <90.47%> (-0.04%) ⬇️
pymc/smc/kernels.py               97.27% <97.50%> (+50.17%) ⬆️
pymc/smc/sampling.py              96.59% <95.23%> (+19.51%) ⬆️
pymc/smc/parallel.py              66.66% <66.66%> (ø)

... and 2 files with indirect coverage changes


@ricardoV94 (Member)

Check #8044

@jessegrabowski (Member Author)

Check #8044

I addressed this in the Extract and share Parallel setup code between MCMC and SMC commit, but now the two PRs are overlapping.

@jessegrabowski force-pushed the smc-refactor branch 3 times, most recently from 15c4e5f to ae2d582, on January 11, 2026 at 01:01
@jessegrabowski requested a review from zaxtax on January 12, 2026 at 00:41
@ricardoV94 (Member)

So there's an edge case with the pickle-the-function-then-send-to-process approach. If the pickled functions contain random number generators, these need to be changed so as to have independent streams.

Usually this isn't a problem in MCMC because we never wanted functions with randomness in our step samplers, but this is not the case for SMC, and especially SMC-ABC with Simulator, which is definitely supposed to be random.

When you call model.logp() on a model with a Simulator, the logp has RandomGenerator variables in it. These were unique every time you called model.logp(), so compiling in each process didn't have the duplication issue (the entropy properties may not have been great, but that's a different matter). Here they'll be exactly the same.

The approach may require something like the set_rng code added for the conjugate samplers in this PR: https://github.com/ricardoV94/pymc-extras/blob/35530daf55a8ae7f5cbcf5ae00760da32abf560a/pymc_extras/sampling/optimizations/conjugate_sampler.py#L84-L98, making sure it's called at the beginning of sampling in each worker.
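
A minimal sketch of that stream-splitting idea using numpy's SeedSequence (the kernel attribute below is hypothetical; the linked set_rng code operates on PyTensor shared RNG variables):

import numpy as np


def spawn_worker_seeds(random_seed: int, n_workers: int) -> list[np.random.SeedSequence]:
    # SeedSequence.spawn yields statistically independent child streams,
    # one per worker, so identical pickled RNGs don't replay identical draws.
    return np.random.SeedSequence(random_seed).spawn(n_workers)


def set_rngs(kernel, seed_seq: np.random.SeedSequence) -> None:
    # Called once per worker, before sampling: reassign every shared RNG
    # in the unpickled kernel to a generator drawn from this worker's stream.
    for shared_rng in kernel.shared_rngs:  # hypothetical attribute
        shared_rng.set_value(np.random.default_rng(seed_seq.spawn(1)[0]))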

@jessegrabowski (Member Author)

The approach may require something like the set_rng code added for the conjugate samplers in this PR: https://github.com/ricardoV94/pymc-extras/blob/35530daf55a8ae7f5cbcf5ae00760da32abf560a/pymc_extras/sampling/optimizations/conjugate_sampler.py#L84-L98, making sure it's called at the beginning of sampling in each worker.

Could we just make the rng an explicit input to the function we pickle up and send out, to avoid the copy?

@ricardoV94 (Member)

Could we just make the rng an explicit input to the function we pickle up and send out, to avoid the copy?

Not without many more changes to the codebase.

@jessegrabowski (Member Author)

OK, but I only need to copy once per worker right? After the RNG is split and set we're good?

@ricardoV94 (Member)

OK, but I only need to copy once per worker right? After the RNG is split and set we're good?

I think so

@zaxtax (Contributor) left a comment

Overall looks good. Some of the code could still be simplified.

@tvwenger (Contributor)

@jessegrabowski ping me again when this is stable and I'll test it with my models and workflow.

@jessegrabowski (Member Author)

@jessegrabowski ping me again when this is stable and I'll test it with my models and workflow.

It should work as-is; what I'll do from here is mostly refactoring and cleanup. I'm mostly curious whether, if you have a chonky model you're currently fitting with SMC, you can use this branch and do pm.sample_smc(compile_kwargs={'mode': 'NUMBA'}) to get a free speedup (and verify that everything works).

Commits pushed:
  • Extract mp_ctx initialization function
  • Extract blas_core setup function
  • Don't use threadpool limits when mp_ctx is ForkContext

_SMCProcess(*args).run()


class SMCProcessAdapter:
Contributor

Multiprocess code is tricky. We should have unit test coverage for this.

Member Author

What do you suggest for MP specifically? All existing SMC tests pass through this code.

def joined_blas_limiter():
    return threadpool_limits(limits=blas_cores)

num_blas_cores_per_chain = blas_cores // cores
Contributor

Should cores be chains here?

Member Author

Maybe; this was pre-existing code. I don't want to introduce changes to MCMC in this PR.

Member

cores here is already <= chains, so I think this is right.

Member

The correct name would be num_blas_cores_per_worker or something similar (chains will wait while the pool is exhausted).
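
For concreteness, with hypothetical values blas_cores=8 and cores=4 worker processes, each worker's BLAS pool is capped at 8 // 4 = 2 threads:

from threadpoolctl import threadpool_limits

blas_cores, cores = 8, 4  # hypothetical values

# With all workers busy, total BLAS threads stay at blas_cores.
with threadpool_limits(limits=blas_cores // cores):
    pass  # each worker's linear algebra runs with at most 2 BLAS threads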



def compute_draw_speed(elapsed: float, draws: int) -> tuple[float, str]:
def compute_draw_speed(
Member

Inline this for the SMC progress bar and revert the changes,

or inline it for everyone... I hate this helper.
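
For reference, a minimal sketch of what a helper like this computes (the PR's actual implementation may differ): it flips between draws-per-second and seconds-per-draw depending on which reads better, mirroring the "s/stage" unit in the SMC progress bar above.

def compute_draw_speed(elapsed: float, draws: int) -> tuple[float, str]:
    # Report draws/s when sampling is fast, s/draw when it is slow.
    speed = draws / max(elapsed, 1e-6)
    if speed >= 1 or speed == 0:
        return speed, "draws/s"
    return 1 / speed, "s/draw"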

"log_marginal_likelihood": (float, []),
"beta": (float, []),
}
"""Maps stat names to (dtype, shape) tuples."""
Member

Suggested change (remove this line):
"""Maps stat names to (dtype, shape) tuples."""

names = [model.rvs_to_values[rv].name for rv in model.free_RVs]
dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)}
prior_values = draw(self._prior_expression, draws=self.draws, random_seed=self.rng)
dict_prior = {k: np.stack(v) for k, v in zip(self._prior_var_names, prior_values)}
Member

I think this np.stack is a (costly) no-op, and it will fail for scalars. Try to remove?
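
A quick illustration of the point (restacking a 2-d array along axis 0 reproduces it; a 0-d input fails):

import numpy as np

a = np.arange(6.0).reshape(2, 3)
assert np.array_equal(np.stack(a), a)  # costly no-op

try:
    np.stack(np.float64(1.0))
except TypeError as e:
    print(e)  # 0-d ("scalar") input is not iterable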

else threadpool_limits(limits=self._blas_cores)
):
try:
self._unpickle_objects()
Member

I think _unpickle_objects need not be inside this context.

