Speed up Pathfinder compilation and wall-time #681
ffa6d84 to e3b239d
Fold the blackjax variant into pathfinder.py, move the test tree under tests/inference/pathfinder to mirror the package, and fix the CI ignore path.
…ttern Replace the ProcessPoolExecutor + Manager-pipe + listener-thread setup with raw worker processes over one bidirectional pipe each, multiplexed with connection.wait(), capped at `cores` live workers. A worker that errors or dies yields a failed result instead of aborting the whole fit.
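The worker pattern this commit describes can be sketched with the standard library alone. This is a minimal illustration, not the PR's implementation: `run_paths`, `_worker`, `demo_path`, and the `(index, status, payload)` tuple shape are hypothetical names chosen for the example, and the `start_method` parameter is an assumption about how a caller might pick the multiprocessing context.

```python
import multiprocessing as mp
import os
from multiprocessing.connection import wait


def _worker(conn, fn, seed):
    """Run one path and send ("ok", value) or ("error", message) to the parent."""
    try:
        conn.send(("ok", fn(seed)))
    except Exception as exc:  # report the failure instead of losing it
        conn.send(("error", repr(exc)))
    finally:
        conn.close()


def run_paths(fn, seeds, cores, start_method=None):
    """Yield (index, status, payload) as workers finish, with at most `cores` alive."""
    ctx = mp.get_context(start_method)
    pending = list(enumerate(seeds))[::-1]  # pop() hands out the lowest index first
    live = {}  # parent end of the pipe -> (index, process)

    while pending or live:
        # Top up to `cores` live workers.
        while pending and len(live) < cores:
            index, seed = pending.pop()
            parent_conn, child_conn = ctx.Pipe(duplex=True)
            proc = ctx.Process(target=_worker, args=(child_conn, fn, seed))
            proc.start()
            child_conn.close()  # keep only the parent end so a dead worker shows as EOF
            live[parent_conn] = (index, proc)

        # Block until at least one worker has sent a message or closed its pipe.
        for conn in wait(list(live)):
            index, proc = live.pop(conn)
            try:
                status, payload = conn.recv()
            except (EOFError, OSError):  # worker died before sending anything
                status, payload = "error", "worker exited without a result"
            conn.close()
            proc.join()
            yield index, status, payload


def demo_path(seed):
    """Hypothetical stand-in for a single path: one seed raises, one crashes hard."""
    if seed == 2:
        raise ValueError("bad seed")
    if seed == 3:
        os._exit(1)  # simulate a worker that dies without cleanup
    return seed * 10
```

Iterating `run_paths(demo_path, [0, 1, 2, 3], cores=2)` yields one tuple per seed in completion order; the raising and crashing seeds come back with status `"error"` while the other results are still delivered. With the `spawn` or `forkserver` start methods the call must sit under an `if __name__ == "__main__":` guard and `fn` must be picklable.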
SinglePathfinderFn becomes a Protocol matching the real (random_seed, progress_callback) signature; progress task ids are typed TaskID.
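As a rough illustration of what a callable `Protocol` with a `(random_seed, progress_callback)` signature looks like, here is a sketch. Only the two parameter names come from the commit message; the annotations, the `ProgressCallback` shape, the `dict` return type, and the helpers `run_one` and `toy_path` are assumptions made for the example.

```python
from typing import Optional, Protocol


class ProgressCallback(Protocol):
    """Hypothetical progress hook: called with the current iteration and ELBO."""

    def __call__(self, iteration: int, elbo: float) -> None: ...


class SinglePathfinderFn(Protocol):
    """Structural type for any callable that runs one path."""

    def __call__(
        self, random_seed: int, progress_callback: Optional[ProgressCallback] = None
    ) -> dict: ...


def run_one(fn: SinglePathfinderFn, seed: int) -> dict:
    """Call a path function and record the progress events it emits."""
    events = []
    result = fn(seed, progress_callback=lambda iteration, elbo: events.append((iteration, elbo)))
    return {"result": result, "events": events}


def toy_path(random_seed, progress_callback=None):
    """Stand-in path function; it satisfies the protocol by signature alone."""
    if progress_callback is not None:
        progress_callback(1, -12.5)
    return {"seed": random_seed}
```

Because the protocol is structural, `toy_path` needs no base class: a type checker accepts `run_one(toy_path, 7)` as long as the call signature matches.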
minimize_streaming calls better_optimize.minimize (fused objective, flat LBFGSConfig kwargs); the streaming callback takes an OptimizeResult and uses res.jac. Progress shows bar + iter + step pace + best ELBO.
c2f1cb9 to dd0880f
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #681 +/- ##
===========================================
+ Coverage 51.60% 90.93% +39.33%
===========================================
Files 73 78 +5
Lines 8003 7956 -47
===========================================
+ Hits 4130 7235 +3105
+ Misses 3873 721 -3152
logp_dlogp_kwargs = {"jacobian": jacobian_correction, **compile_kwargs}
neg_logp_dlogp_func = get_neg_logp_dlogp_of_ravel_inputs(model, **logp_dlogp_kwargs)
I believe this line is the main source of speedup. Here is the pre-refactor version. The .copy triggers a re-compilation on each thread, which ends up as a dominating cost for smaller models.
copy does not trigger re-compilation on each thread (checking again...), or it should not. I assume you are on macOS with forkserver by default, and you may be seeing this instead: pymc-devs/pytensor#2108
ok it does call the dispatch stuff again, we should fix that in pytensor. #2108 may also be needed. Not saying we can't use this approach here, but there's no fundamental reason why .copy should be avoided
The only thing is the share_memory concern, if there are ever RNGs in the compiled function we don't want those reused across the processes. It can happen naturally with deterministics, or minibatch. PyMC reseeds the functions after forking I think?
Also, are there symbolic minimize ops or better_optimize pieces in this graph? If so, do they use the LRU cache, and could that matter or fail to deduplicate correctly? This is a pre-existing concern, not specific to your refactor.
OK, the RNG part may not be new. share_memory is not what would fix it; the swap={old_rng: new_rng} argument is.
maybe just raise if there are RandomType shared variables in the function? Then I'm not worried about memory (only the better_optimize thing)
logger = logging.getLogger(__name__)
def multipath_pathfinder(
This was rewritten to closely copy what we're doing in pymc with the mcmc/smc parallel code. I am working under the assumption that those represent best practices for MP with pytensor.
neg_logp_dlogp_func = get_neg_logp_dlogp_of_ravel_inputs(model, **logp_dlogp_kwargs)
# initial point
# TODO: remove make_initial_points function when feature request is implemented:
this todo is outdated or unclear
Very customly prompted review from the bot. As per usual it's likely that half is crap:

Introduced by this PR
1. [High] JAX + fork can deadlock
2. [Med] Ctrl-C discards all completed paths
3. [Med]
4. [Med-Low] Per-path ordering is nondeterministic in parallel
5. [Low-Med]

Pre-existing (carried into touched files; flag to author, not introduced here)
6. [High if intended-raw] In-place mutation of
7. [Med]
8. [Med] Uncaught
9. [Low] Wrong nonzero count in the fallback warning

Cleanup (minor, optional)
fork + JAX can deadlock; pass the compile mode so the context resolver switches off fork, mirroring pymc.sample.
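The underlying hazard is that `fork` copies only the calling thread, so locks held by JAX's other threads stay locked forever in the child. A context resolver that steers away from `fork` might look like the sketch below. The helper name `resolve_start_method`, its parameters, and the choice of `spawn` as the fallback are assumptions for illustration, not the code in this PR or in pymc.

```python
import multiprocessing as mp


def resolve_start_method(compile_mode, requested=None):
    """Pick a multiprocessing start method that avoids fork when JAX is in use.

    Hypothetical helper: an explicit `requested` method always wins, a JAX
    compile mode maps to "spawn", and anything else keeps the platform default.
    """
    if requested is not None:
        return requested
    if str(compile_mode).upper() == "JAX":
        return "spawn"  # fresh interpreter, no inherited JAX threads or locks
    return None  # let multiprocessing use the platform default


# The resolved name plugs straight into get_context().
ctx = mp.get_context(resolve_start_method("JAX"))
```

Returning `None` for the non-JAX case defers to whatever default the platform or the caller has already configured, so the helper only intervenes when it has to.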
setup_cores_blas_cores' parent-side limiter was discarded; obtain it and wrap path execution so the parent caps BLAS threads, matching pymc.
Workers finish out of order, so a fixed seed gave permuted chains and per-path diagnostics; yield (chain, result) pairs and sort by chain before aggregating.
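The reordering step from this commit message is small enough to show in full. The function name `order_by_chain` and the string payloads are placeholders; only the `(chain, result)` pair shape and the sort-before-aggregate idea come from the message above.

```python
def order_by_chain(pairs):
    """Return results sorted by chain index, given (chain, result) pairs in any order."""
    return [result for _, result in sorted(pairs, key=lambda pair: pair[0])]


# Completion order from three workers; chain 2 happened to finish first.
completed = [(2, "path-c"), (0, "path-a"), (1, "path-b")]
ordered = order_by_chain(completed)
```

Sorting on the chain index alone (rather than on the whole tuple) avoids comparing result objects, which may not define an ordering.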
The failure is fixed by pymc-devs/pytensor#2191
cut a release?
wip
Timings
[Figure: timings across the benchmark model scripts, models/api_quickstart_logistic.py through models/CFA_SEM_indirect.py]