SMC Multiprocessing and Progress Bar Refactor #8047
Conversation
@tvwenger it would be nice if you could try your SMC models that have been giving you trouble on this PR branch and report back, since you've been the one doing the heavy lifting bug-hunting on SMC lately.
Pull request overview
This PR refactors SMC (Sequential Monte Carlo) sampling to harmonize its API with pm.sample, bringing multiprocessing capabilities, progress bars, and performance improvements to pm.sample_smc. The refactor moves PyTensor compilation to the main process before distributing work to child processes, addresses thread safety concerns, and adds comprehensive progress reporting similar to MCMC sampling.
Changes:
- Implemented multiprocessing support for SMC using a pattern similar to MCMC parallel sampling
- Added custom SMC progress bars that track beta (inverse temperature) progression from 0 to 1
- Moved kernel compilation to the main process to avoid thread-safety issues with PyTensor compilation in worker processes (see the sketch below)
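To make the compile-once-then-distribute pattern concrete, here is a minimal sketch (not the PR's actual code; `_worker` and the toy function are invented for illustration): a PyTensor function is compiled on the main process, serialized with cloudpickle, and unpickled inside a spawned worker.

```python
# Minimal sketch, not the PR's code: compile on the main process,
# ship the pickled function to a spawned worker, unpickle it there.
import multiprocessing as mp

import cloudpickle
import pytensor
import pytensor.tensor as pt


def _worker(payload: bytes, conn) -> None:
    fn = cloudpickle.loads(payload)  # deserialize inside the child process
    conn.send(fn(3.0))
    conn.close()


if __name__ == "__main__":
    x = pt.dscalar("x")
    fn = pytensor.function([x], x**2)  # compiled once, in the main process
    payload = cloudpickle.dumps(fn)

    ctx = mp.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe()
    proc = ctx.Process(target=_worker, args=(payload, child_conn))
    proc.start()
    print(parent_conn.recv())  # 9.0
    proc.join()
```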
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| pymc/smc/parallel.py | New file implementing parallel SMC sampling infrastructure with process management, message passing, and result collection |
| pymc/smc/sampling.py | Major refactor of main SMC sampling function to support both parallel and sequential execution with shared kernel compilation |
| pymc/smc/kernels.py | Moved kernel compilation to init and added progress bar configuration methods |
| pymc/progress_bar.py | Added SMCProgressBarManager class for beta-based progress tracking and modified table styling |
| tests/smc/test_smc.py | Added test for sequential sampling with cores=1 |
Codecov Report ❌ Patch coverage is
Additional details and impacted files:
```
@@            Coverage Diff            @@
##             main   #8047      +/-   ##
=========================================
+ Coverage   90.79%  91.43%    +0.64%
=========================================
  Files         121     122        +1
  Lines       19408   19695      +287
=========================================
+ Hits        17621   18008      +387
+ Misses       1787    1687      -100
```
Force-pushed bf91046 to 7567f0e.
Check #8044
Force-pushed 7567f0e to 8f6f2ad.
I addressed this in the "Extract and share Parallel setup code between MCMC and SMC" commit, but now the two PRs are overlapping.
Force-pushed 15c4e5f to ae2d582.
So there's an edge case with the pickle-function-then-send-to-process approach. If the pickled functions have random number generators, these need to be changed so as to have independent streams. Usually this isn't a problem in MCMC, because we never wanted to use functions with randomness in our step samplers, but that is not the case for SMC, and especially SMC-ABC with Simulator, which is definitely supposed to be random. When you call … The approach may require something like the …
Could we just make the rng an explicit input to the function we pickle up and send out, to avoid the copy?
Not without many more changes in the codebase.
OK, but I only need to copy once per worker, right? After the RNG is split and set, we're good?
I think so
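A minimal sketch of that one-split-per-worker idea, assuming numpy Generators (names here are illustrative, not the PR's):

```python
# Derive one statistically independent RNG stream per worker from one seed.
import numpy as np

n_workers = 4
root = np.random.SeedSequence(1234)
worker_rngs = [np.random.default_rng(s) for s in root.spawn(n_workers)]

# Each worker receives exactly one Generator; after this single
# split-and-set, no further copying is needed.
draws = [rng.normal(size=3) for rng in worker_rngs]
```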
zaxtax left a comment:
Overall looks good. Some of the code could still be simplified.
@jessegrabowski ping me again when this is stable and I'll test it with my models and workflow. |
It should work as-is; what I'll do from here is mostly refactors and cleanup. I'm mostly curious whether, if you have a chonky model you're currently fitting with SMC, you can use this branch and do …
Force-pushed 8ec2475 to 8b752db.
Force-pushed 8b752db to dc33eeb.
- Extract mp_ctx initialization function
- Extract blas_core setup function
- Don't use threadpool limits when mp_ctx is ForkContext (see the sketch below)
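A rough sketch of what these commits describe; the function names are hypothetical, and the ForkContext rationale is presumably that forked children inherit the parent's BLAS settings:

```python
# Hypothetical helpers mirroring the commit messages above.
import multiprocessing as mp
from contextlib import nullcontext

from threadpoolctl import threadpool_limits


def resolve_mp_ctx(mp_ctx=None):
    # Normalize None or a method name ("fork", "spawn", ...) to a context.
    if mp_ctx is None or isinstance(mp_ctx, str):
        mp_ctx = mp.get_context(mp_ctx)
    return mp_ctx


def blas_limiter(blas_cores, mp_ctx):
    # Assumption: under fork, children inherit the parent's thread limits,
    # so re-limiting inside the workers can be skipped.
    if type(mp_ctx).__name__ == "ForkContext":
        return nullcontext()
    return threadpool_limits(limits=blas_cores)
```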
Force-pushed dc33eeb to 56cc72d.
Reviewed code:
```python
_SMCProcess(*args).run()
# ...
class SMCProcessAdapter:
```
Multiprocess code is tricky. We should have unit test coverage for this.
What do you suggest for MP specifically? All existing SMC tests pass through this code.
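For concreteness, a dedicated multiprocessing test might look something like this hypothetical sketch (not code from the PR):

```python
import pymc as pm


def test_smc_runs_with_multiprocessing():
    # Force the parallel path with cores > 1 and check basic output shape.
    with pm.Model():
        pm.Normal("x", 0.0, 1.0)
        idata = pm.sample_smc(draws=100, chains=2, cores=2)
    assert idata.posterior.sizes["chain"] == 2
```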
Reviewed code:
```python
def joined_blas_limiter():
    return threadpool_limits(limits=blas_cores)

num_blas_cores_per_chain = blas_cores // cores
```
Should cores be chains here?
Maybe; this was pre-existing code. I don't want to introduce changes to MCMC in this PR.
cores here is already <= chains, so I think this is right.
The correct name would be num_blas_cores_per_worker or something similar (chains will wait while the pool is exhausted); see the illustration below.
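To illustrate the naming point: the quantity being computed is the BLAS thread budget handed to each concurrently running worker process (an illustration, not the PR's code):

```python
from threadpoolctl import threadpool_limits

blas_cores = 8
cores = 4  # worker processes running at once, not total chains
num_blas_cores_per_worker = blas_cores // cores  # -> 2

# Each worker then caps its BLAS libraries at that share:
with threadpool_limits(limits=num_blas_cores_per_worker):
    pass  # run the compiled functions here
```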
Reviewed change:
```python
# before
def compute_draw_speed(elapsed: float, draws: int) -> tuple[float, str]:
# after (signature continues)
def compute_draw_speed(
```
Inline this for the SMC progress bar and revert the changes, or inline it for everyone... I hate this helper.
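For context, a plausible inlined version of the helper under discussion; this is a guess at its behavior, not the actual implementation:

```python
def compute_draw_speed(elapsed: float, draws: int) -> tuple[float, str]:
    # Report draws per second, or seconds per draw when sampling is slow.
    speed = draws / max(elapsed, 1e-9)
    if speed > 1:
        return speed, "draws/s"
    return 1.0 / speed, "s/draw"
```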
| "log_marginal_likelihood": (float, []), | ||
| "beta": (float, []), | ||
| } | ||
| """Maps stat names to (dtype, shape) tuples.""" |
Reviewed change:
```python
# before
names = [model.rvs_to_values[rv].name for rv in model.free_RVs]
dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)}
# after
prior_values = draw(self._prior_expression, draws=self.draws, random_seed=self.rng)
dict_prior = {k: np.stack(v) for k, v in zip(self._prior_var_names, prior_values)}
```
I think this np.stack is a (costly) no-op, and it will fail for scalars. Try to remove?
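A quick demonstration of both halves of that concern:

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
b = np.stack(a)                # same values...
print(np.array_equal(a, b))    # True: effectively a no-op
print(np.shares_memory(a, b))  # False: ...but the data was copied

try:
    np.stack(np.float64(1.5))  # a scalar draw
except TypeError as err:
    print(err)                 # scalars are not iterable, so stacking fails
```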
Reviewed code:
```python
    else threadpool_limits(limits=self._blas_cores)
):
    try:
        self._unpickle_objects()
```
I think _unpickle_objects need not be inside this context.
Description
I always like SMC as a gradient-free option for my big silly models with few parameters, but it has always given me trouble because of the API break between it and pm.sample. This PR aims to harmonize the two by bringing a bunch of functionality from pm.sample over to pm.sample_smc.
This PR is intended to be reviewed commit by commit. I verified that the test suite runs in all intermediate forms. Here is a summary of each commit:
- Use multiprocessing for SMC sampling: The multiprocessing library is now used to handle parallel SMC sampling. This commit was heavily Claude-assisted, so it should receive special scrutiny. The objective was to make SMC multiprocessing look exactly like MCMC multiprocessing. It also exposes an mp_ctx argument to pm.sample_smc, which allows compiling with e.g. JAX (using mp_ctx='forkserver').
- Sample SMC sequentially when cores=1: Adds separate logic for sequential sampling on one core. Again, this copies the relevant MCMC functions.
- Initialize SMC Kernels on main process: A major performance change, intended to address e.g. "BUG: sample_smc stalls on final stage" #8030. PyTensor compilation is not thread-safe, so we shouldn't be doing it on the workers. In this PR, the kernel is compiled once on the main process, then serialized and sent to the workers. This matches what we do with step functions in MCMC. Importantly, this commit eliminates the need to serialize many auxiliary objects, including the PyMC model itself, and some special logic for custom distributions. To do this, a couple of ancillary changes had to be made; for example, transformation of the chains from numpy to NDArray objects now happens on the main process, after all sampling is done.
- Add blas_cores argument to sample_smc: Again, this copies multiprocessing machinery from pm.sample over to pm.sample_smc by adding a blas_cores argument, for the same reasons it exists over there.
- Add custom progress bar for SMC: Adds a progress bar style to sample_smc that matches that of pm.sample. The bars fill from 0 to 1 following the value of beta, and we provide an estimated time to completion by measuring the speed per step.
I observed big speed gains using sample_smc after this PR. I timed a simple hierarchical model.
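As a stand-in for the benchmark model, which isn't reproduced here, something like the following, with all names and sizes invented:

```python
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
group_idx = rng.integers(0, 8, size=200)
y = rng.normal(0.1 * group_idx, 1.0)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    group_effect = pm.Normal("group_effect", mu, sigma, shape=8)
    pm.Normal("obs", group_effect[group_idx], 1.0, observed=y)

    idata = pm.sample_smc(draws=1000, chains=4)
```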
Timings went from 6.1 s to 1.44 s using the C backend, and from 1.46 s to 1.09 s using Numba mode (with cache). Running test_smc.py locally goes from 1 m 4 s to 6.264 s. I could run more formal benchmarks if someone asks, but I don't really want to.
Related Issue
- sample_smc stalls on final stage #8030
- sample_smc can lead to compilation halting #8022
Checklist
Type of change