Skip to content

Speedup Pathfinder compilation and wall-time#681

Open
jessegrabowski wants to merge 36 commits into
pymc-devs:mainfrom
jessegrabowski:pathfinder-memory
Open

Speedup Pathfinder compilation and wall-time#681
jessegrabowski wants to merge 36 commits into
pymc-devs:mainfrom
jessegrabowski:pathfinder-memory

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented May 8, 2026

wip

Timings

model main (scan) s branch (scan) s branch (vectorize) s
models/api_quickstart_logistic.py 31.2 2.3 (13.3×) 0.7 (44.6×)
models/bayes_factor_uniform_prior.py 28.4 1.9 (14.6×) 0.8 (37.9×)
models/bayes_factor_concentrated_prior.py 26.0 2.0 (12.9×) 0.8 (34.7×)
models/ar1.py 29.8 2.5 (12.2×) 0.9 (32.4×)
models/data_container_linear.py 22.9 2.0 (11.3×) 0.7 (31.4×)
models/ar2.py 29.2 2.3 (12.5×) 0.9 (31.1×)
models/data_container_logistic.py 23.4 2.1 (11.1×) 0.8 (30.7×)
models/bayes_param_survival_loglogistic.py 27.8 2.5 (11.2×) 0.9 (30.2×)
models/longitudinal_unconditional_mean.py 29.4 3.1 (9.6×) 1.1 (27.0×)
models/GLM_model_selection_ols.py 26.6 2.3 (11.7×) 1.0 (26.8×)
models/demetropolisz_efficiency_mvnormal.py 18.7 2.4 (7.8×) 0.7 (26.4×)
models/GLM_binomial_regression.py 21.9 2.5 (8.7×) 0.8 (26.4×)
models/GLM_out_of_sample_logistic.py 22.1 2.3 (9.4×) 0.8 (26.0×)
models/sampler_stats_normal.py 16.1 1.8 (8.9×) 0.6 (25.6×)
models/demetropolisz_tune_drop_fraction_mvnormal.py 18.1 1.7 (10.8×) 0.7 (25.1×)
models/bayes_param_survival_weibull.py 22.5 2.4 (9.5×) 0.9 (24.0×)
models/model_averaging_abdomen.py 18.3 2.0 (8.9×) 0.8 (23.8×)
models/moderation_analysis.py 21.4 2.3 (9.4×) 0.9 (23.6×)
models/GLM_robust_ols.py 20.1 2.0 (10.1×) 0.9 (23.1×)
models/ar2_individual_priors.py 28.6 2.5 (11.3×) 1.2 (23.1×)
models/data_container_temperatures.py 19.9 2.3 (8.6×) 0.9 (22.9×)
models/GLM_robust_normal.py 21.4 2.1 (9.9×) 1.0 (21.2×)
models/multilevel_pooled.py 16.7 2.1 (7.9×) 0.8 (21.2×)
models/model_averaging_multivariable.py 17.4 2.2 (7.8×) 0.8 (21.0×)
models/GLM_missing_values_baseline.py 24.7 2.5 (9.7×) 1.2 (20.8×)
models/sr05_marriage_divorce.py 18.5 2.1 (8.7×) 0.9 (20.5×)
models/simpsons_paradox_pooled.py 17.9 2.2 (8.0×) 0.9 (20.3×)
models/air_passengers_prophet.py 24.4 3.4 (7.3×) 1.2 (20.0×)
models/sr07_studentt_robust.py 18.1 2.3 (7.9×) 0.9 (19.9×)
models/lasso_block_update.py 19.0 2.5 (7.6×) 1.0 (19.8×)
models/GLM_robust_studentt.py 21.1 2.3 (9.3×) 1.1 (19.7×)
models/longitudinal_unconditional_growth.py 28.5 3.5 (8.3×) 1.4 (19.6×)
models/GLM_poisson_regression.py 20.2 2.3 (8.8×) 1.0 (19.6×)
models/eight_schools_noncentered.py 16.7 2.1 (8.0×) 0.9 (19.4×)
models/hypothesis_testing_normal_mean.py 14.2 1.9 (7.3×) 0.7 (19.1×)
models/weibull_aft_param1.py 16.4 2.3 (7.1×) 0.9 (18.3×)
models/sr03_linear_regression.py 18.9 2.5 (7.6×) 1.1 (17.8×)
models/multilevel_partial_pooling.py 17.6 3.1 (5.7×) 1.0 (17.2×)
models/interrupted_time_series.py 13.9 2.2 (6.3×) 0.8 (17.1×)
models/simpsons_paradox_unpooled.py 17.7 2.5 (7.1×) 1.0 (17.0×)
models/GLM_robust_studentt_outlier.py 25.3 2.8 (9.0×) 1.5 (16.9×)
models/multilevel_unpooled.py 15.3 2.8 (5.4×) 0.9 (16.8×)
models/sr10_innovation_loss.py 16.1 2.5 (6.5×) 1.0 (16.4×)
models/copula_marginal.py 13.8 2.3 (6.1×) 0.8 (16.3×)
models/difference_in_differences.py 15.9 2.4 (6.8×) 1.0 (16.3×)
models/eight_schools_centered.py 13.8 2.2 (6.3×) 0.9 (16.0×)
models/GLM_discrete_choice_basic.py 27.4 3.8 (7.3×) 1.8 (15.5×)
models/spline_cherry_blossoms.py 15.3 2.8 (5.5×) 1.0 (15.4×)
models/GLM_negative_binomial_regression.py 30.7 3.7 (8.2×) 2.0 (15.4×)
models/smc2_gaussians_n4.py 16.1 2.6 (6.1×) 1.1 (15.4×)
models/multilevel_contextual_effect.py 20.4 3.5 (5.9×) 1.3 (15.2×)
models/spline_cherry_blossoms_data_container.py 14.9 2.9 (5.1×) 1.0 (15.1×)
models/smc2_gaussians_n80.py 16.3 2.9 (5.6×) 1.1 (15.0×)
models/mediation_analysis.py 17.0 2.6 (6.6×) 1.1 (14.8×)
models/sr19_cylinder_body.py 13.7 2.4 (5.7×) 0.9 (14.7×)
models/model_builder_linear.py 11.7 2.1 (5.6×) 0.8 (14.7×)
models/weibull_aft_param2.py 13.2 2.2 (5.9×) 0.9 (14.5×)
models/counterfactuals_do_operator.py 13.3 2.4 (5.6×) 0.9 (14.2×)
models/gp_marginal_matern52.py 20.1 3.9 (5.2×) 1.4 (13.9×)
models/simpsons_paradox_partial_pooling.py 18.7 2.9 (6.4×) 1.4 (13.6×)
models/GLM_missing_values_auto_impute.py 27.3 4.3 (6.4×) 2.0 (13.6×)
models/longitudinal_external_minimal.py 16.6 3.4 (4.9×) 1.2 (13.5×)
models/multilevel_varying_intercept.py 14.8 3.1 (4.7×) 1.1 (13.2×)
models/GLM_ordinal_regression_constrained.py 28.4 4.3 (6.5×) 2.2 (12.9×)
models/forecasting_structural_ar.py 15.6 3.5 (4.5×) 1.2 (12.9×)
models/sr08_wine_judges.py 15.8 3.2 (4.9×) 1.2 (12.9×)
models/GLM_hierarchical_binomial_rat_tumor.py 22.9 4.3 (5.3×) 1.8 (12.5×)
models/variational_quickstart_iris.py 17.7 3.1 (5.8×) 1.4 (12.4×)
models/multilevel_hierarchical_intercept.py 16.6 3.4 (4.9×) 1.4 (12.2×)
models/sr12_tadpole_predator.py 14.7 3.1 (4.7×) 1.2 (12.1×)
models/frailty_loglogistic_aft.py 14.5 2.9 (5.0×) 1.2 (12.1×)
models/sr04_full_luxury_bayes.py 20.6 3.4 (6.1×) 1.7 (11.9×)
models/variational_quickstart_gamma.py 10.8 2.3 (4.6×) 0.9 (11.7×)
models/forecasting_structural_ar_trend.py 15.4 3.6 (4.2×) 1.3 (11.6×)
models/sr09_cat_adoption_censored.py 17.4 2.8 (6.2×) 1.5 (11.6×)
models/reliability_bearing_cage_weibull_informative.py 12.2 2.8 (4.4×) 1.1 (11.2×)
models/bayesian_neural_network_advi.py 15.7 3.4 (4.6×) 1.4 (10.8×)
models/multilevel_varying_intercept_slope_noncentered.py 14.0 3.3 (4.2×) 1.3 (10.7×)
models/longitudinal_coa_growth.py 17.8 3.7 (4.8×) 1.7 (10.7×)
models/updating_priors_linear.py 9.7 2.2 (4.4×) 0.9 (10.6×)
models/rugby_analytics.py 16.1 3.4 (4.7×) 1.5 (10.5×)
models/GLM_discrete_choice_intercepts.py 23.1 4.2 (5.5×) 2.2 (10.3×)
models/GLM_ordinal_regression.py 27.9 4.6 (6.0×) 2.8 (10.1×)
models/sr13_district_urban_contraception.py 13.7 3.3 (4.2×) 1.4 (10.0×)
models/BEST.py 20.5 3.7 (5.6×) 2.1 (9.8×)
models/longitudinal_external_polynomial_gender.py 29.9 5.3 (5.6×) 3.1 (9.8×)
models/ab_testing_revenue.py 33.1 2.9 (11.3×) 3.4 (9.7×)
models/survival_analysis_coxph.py 10.0 3.3 (3.0×) 1.0 (9.7×)
models/longitudinal_coa_peer_growth.py 18.1 4.0 (4.5×) 1.9 (9.6×)
models/sr12_latent_mundlak.py 14.8 3.6 (4.1×) 1.6 (9.4×)
models/regression_discontinuity.py 7.4 2.2 (3.3×) 0.8 (9.2×)
models/GLM_truncated_regression.py 28.6 2.7 (10.7×) 3.1 (9.1×)
models/euler_maruyama_linear_sde.py 14.4 4.1 (3.5×) 1.6 (9.1×)
models/reliability_bearing_cage_weibull_uninformative.py 11.3 2.8 (4.0×) 1.3 (8.9×)
models/forecasting_structural_ar_trend_seasonal.py 13.6 3.8 (3.6×) 1.6 (8.7×)
models/dirichlet_mixture_of_multinomials.py 16.4 3.5 (4.7×) 1.9 (8.6×)
models/gp_smoothing_grw.py 11.1 3.6 (3.1×) 1.3 (8.5×)
models/sr16_gp_oceanic_tools.py 15.1 3.4 (4.4×) 1.8 (8.5×)
models/weibull_aft_param3.py 7.0 2.2 (3.2×) 0.8 (8.4×)
models/malaria_hsgp.py 23.9 6.1 (4.0×) 2.9 (8.3×)
models/hierarchical_partial_pooling.py 19.2 4.7 (4.1×) 2.3 (8.3×)
models/missing_data_mvnormal_potential.py 18.5 6.5 (2.8×) 2.3 (8.1×)
models/frailty_coxph.py 16.4 4.1 (4.0×) 2.0 (8.0×)
models/glm_hierarchical_advi_minibatch.py 13.8 3.9 (3.5×) 1.7 (8.0×)
models/frailty_weibull_aft.py 11.1 3.1 (3.6×) 1.4 (7.9×)
models/fast_sampling_ppca.py 17.5 5.2 (3.4×) 2.3 (7.6×)
models/ode_lotka_volterra_pytensor_scan.py 19.4 4.2 (4.6×) 2.6 (7.5×)
models/sr17_measurement_error_divorce.py 12.6 3.7 (3.4×) 1.7 (7.5×)
models/lkj_cholesky_cov_mvnormal.py 20.0 4.8 (4.2×) 2.8 (7.2×)
models/item_response_nba.py 20.9 4.4 (4.7×) 2.9 (7.2×)
models/reinforcement_learning_potential.py 15.2 4.7 (3.2×) 2.2 (6.8×)
models/gp_numpy_kernel.py 9.2 3.5 (2.6×) 1.4 (6.7×)
models/missing_data_chained_normal.py 19.0 5.2 (3.6×) 3.0 (6.3×)
models/missing_data_hierarchical_team.py 20.2 5.6 (3.6×) 3.2 (6.3×)
models/GLM_discrete_choice_correlated.py 23.0 7.3 (3.2×) 3.7 (6.2×)
models/GLM_censored_regression.py 13.8 2.5 (5.5×) 2.3 (6.1×)
models/copula_gaussian.py 12.3 4.2 (3.0×) 2.0 (6.0×)
models/survival_analysis_time_varying.py 8.5 4.0 (2.1×) 1.5 (5.8×)
models/factor_analysis_ppca.py 17.2 3.9 (4.5×) 3.0 (5.8×)
models/gp_latent_studentt.py 22.9 6.3 (3.7×) 4.0 (5.8×)
models/mogp_lcm.py 27.4 7.0 (3.9×) 5.1 (5.4×)
models/censored_data_unimputed.py 11.6 2.3 (4.9×) 2.2 (5.4×)
models/sr14_correlated_varying_effects.py 10.5 4.2 (2.5×) 2.0 (5.3×)
models/gp_kron_marginal.py 10.7 3.6 (3.0×) 2.0 (5.3×)
models/reinforcement_learning_bernoulli.py 12.6 6.0 (2.1×) 2.4 (5.2×)
models/dp_mix_sunspot.py 14.7 5.2 (2.8×) 2.9 (5.1×)
models/missing_data_chained_uniform.py 14.5 5.1 (2.9×) 2.9 (5.1×)
models/frailty_coxph_shared.py 18.5 6.2 (3.0×) 3.8 (4.8×)
models/blackbox_external_likelihood_potential.py 11.2 1.9 (5.8×) 2.3 (4.8×)
models/sr16_phylogenetic_regression.py 11.4 4.5 (2.5×) 2.4 (4.7×)
models/excess_deaths.py 14.6 3.5 (4.2×) 3.2 (4.5×)
models/GLM_discrete_choice_mixed_logit.py 16.8 5.8 (2.9×) 3.8 (4.4×)
models/bayesian_workflow_logistic.py 9.4 3.9 (2.4×) 2.1 (4.4×)
models/sr15_social_network_giving_receiving.py 14.2 7.9 (1.8×) 3.3 (4.2×)
models/mv_gaussian_random_walk.py 21.1 8.5 (2.5×) 5.2 (4.1×)
models/nyc_bym_traffic.py 18.5 7.3 (2.5×) 5.0 (3.7×)
models/time_series_generative_graph_ar2.py 12.4 3.3 (3.7×) 3.4 (3.6×)
models/sampling_conjugate_step.py 20.6 10.9 (1.9×) 5.8 (3.6×)
models/blackbox_external_likelihood_with_grad.py 8.3 1.9 (4.5×) 2.4 (3.5×)
models/GLM_rolling_regression.py 33.5 10.2 (3.3×) 9.8 (3.4×)
models/gp_log_gaussian_cox.py 28.6 10.5 (2.7×) 8.5 (3.4×)
models/ode_api_enzymatic_reaction.py 17.9 5.0 (3.6×) 6.0 (3.0×)
models/ode_api_freefall.py 15.0 4.3 (3.5×) 5.1 (3.0×)
models/frailty_coxph_individual.py 25.2 11.0 (2.3×) 8.8 (2.9×)
models/binning_hierarchical.py 12.5 4.2 (3.0×) 4.7 (2.7×)
models/ode_api_sir.py 14.0 5.5 (2.5×) 6.5 (2.1×)
models/gp_births_hsgp.py 24.1 14.8 (1.6×) 13.3 (1.8×)
models/bayesian_var_fake.py 17.7 8.9 (2.0×) 9.9 (1.8×)
models/stochastic_volatility.py 26.0 18.8 (1.4×) 16.8 (1.5×)
models/sr18_missing_data_primates.py 49.1 38.9 (1.3×) 36.5 (1.3×)
models/ode_lotka_volterra_pymc_ode.py 27.5 20.8 (1.3×) 21.6 (1.3×)
models/gp_mauna_loa_co2.py 53.5 42.8 (1.2×) 43.2 (1.2×)
models/bayesian_var_hierarchical.py 109.5 118.1 (0.9×) 122.2 (0.9×)
models/probabilistic_matrix_factorization.py 52.8 39.7 (1.3×) 66.2 (0.8×)
models/CFA_SEM_indirect.py 34.4 22.7 (1.5×) 54.9 (0.6×)

Fold the blackjax variant into pathfinder.py, move the test tree under
tests/inference/pathfinder to mirror the package, and fix the CI ignore path.
…ttern

Replace the ProcessPoolExecutor + Manager-pipe + listener-thread setup with raw
worker processes over one bidirectional pipe each, multiplexed with
connection.wait(), capped at `cores` live workers. A worker that errors or dies
yields a failed result instead of aborting the whole fit.
SinglePathfinderFn becomes a Protocol matching the real (random_seed,
progress_callback) signature; progress task ids are typed TaskID.
minimize_streaming calls better_optimize.minimize (fused objective, flat
LBFGSConfig kwargs); the streaming callback takes an OptimizeResult and uses
res.jac. Progress shows bar + iter + step pace + best ELBO.
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 23, 2026

Codecov Report

❌ Patch coverage is 91.05505% with 78 lines in your changes missing coverage. Please review.
✅ Project coverage is 90.93%. Comparing base (86fac3c) to head (4eaa17f).
⚠️ Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
pymc_extras/inference/pathfinder/multipath.py 80.98% 31 Missing ⚠️
pymc_extras/inference/pathfinder/pathfinder.py 51.42% 17 Missing ⚠️
pymc_extras/inference/pathfinder/idata.py 94.14% 13 Missing ⚠️
pymc_extras/inference/pathfinder/bfgs_sample.py 93.54% 6 Missing ⚠️
pymc_extras/inference/pathfinder/single_path.py 93.75% 6 Missing ⚠️
pymc_extras/inference/pathfinder/lbfgs.py 96.55% 3 Missing ⚠️
pymc_extras/inference/idata_utils.py 96.55% 1 Missing ⚠️
...extras/inference/pathfinder/importance_sampling.py 96.66% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main     #681       +/-   ##
===========================================
+ Coverage   51.60%   90.93%   +39.33%     
===========================================
  Files          73       78        +5     
  Lines        8003     7956       -47     
===========================================
+ Hits         4130     7235     +3105     
+ Misses       3873      721     -3152     
Files with missing lines Coverage Δ
pymc_extras/inference/__init__.py 100.00% <100.00%> (ø)
pymc_extras/inference/laplace_approx/idata.py 95.45% <100.00%> (+80.74%) ⬆️
pymc_extras/inference/pathfinder/__init__.py 100.00% <100.00%> (ø)
pymc_extras/inference/pathfinder/results.py 100.00% <100.00%> (ø)
pymc_extras/inference/idata_utils.py 96.55% <96.55%> (ø)
...extras/inference/pathfinder/importance_sampling.py 92.42% <96.66%> (+64.55%) ⬆️
pymc_extras/inference/pathfinder/lbfgs.py 97.79% <96.55%> (+47.79%) ⬆️
pymc_extras/inference/pathfinder/bfgs_sample.py 93.54% <93.54%> (ø)
pymc_extras/inference/pathfinder/single_path.py 93.75% <93.75%> (ø)
pymc_extras/inference/pathfinder/idata.py 94.60% <94.14%> (+94.60%) ⬆️
... and 2 more

... and 27 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.


logp_dlogp_kwargs = {"jacobian": jacobian_correction, **compile_kwargs}

neg_logp_dlogp_func = get_neg_logp_dlogp_of_ravel_inputs(model, **logp_dlogp_kwargs)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this line is the main source of speedup. Here is the pre-refactor version. The .copy triggers a re-compilation on each thread, which ends up as a dominating cost for smaller models.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy does not trigger re-compilation on each thread (checking again...), or it should not. I assume you are on macOS with forkserver by default, and you may be seeing this instead: pymc-devs/pytensor#2108

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok it does call the dispatch stuff again, we should fix that in pytensor. #2108 may also be needed. Not saying we can't use this approach here, but there's no fundamental reason why .copy should be avoided

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing is the share_memory concern, if there are ever RNGs in the compiled function we don't want those reused across the processes. It can happen naturally with deterministics, or minibatch. PyMC reseeds the functions after forking I think?

And are there symbolic minimize ops in this graph or better optimize stuff, and does that have the LRU cache stuff and could that matter / not be correctly deduplicated. This is a pre-existing concern I would have, not specific about your refactor

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok the rng part maybe not new, the share_memory is not what would fix it, but the swap={old_rng: new_rng} argument

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just raise if there are RandomType shared variables in the function? Then I'm not worried about memory (only the better_optimize thing)

logger = logging.getLogger(__name__)


def multipath_pathfinder(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was rewritten to closely copy what we're doing in pymc with the mcmc/smc parallel code. I am working under the assumption that those represent best practices for MP with pytensor.

neg_logp_dlogp_func = get_neg_logp_dlogp_of_ravel_inputs(model, **logp_dlogp_kwargs)

# initial point
# TODO: remove make_initial_points function when feature request is implemented:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this todo is outdated or unclear

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented May 28, 2026

Very customly prompted review from the bot. As per usual it's likely that half is crap:

Introduced by this PR

1. [High] JAX + fork can deadlock — multipath.py:270,366
Both _initialize_multiprocessing_context(mp_ctx, quiet=True) calls omit pymc's mode= argument. pymc's signature is (mp_ctx, *, mode=None, quiet=False) and it auto-switches fork→forkserver/spawn when get_mode(mode).linker is a JAXLinker. compile_kwargs={"mode": "JAX"} is documented (idata.py:302) and reachable, but compile_kwargs is never threaded to these call sites. On Linux (default fork) a parallel JAX run gets the unsafe fork+JAX combo pymc otherwise prevents. Fix: thread mode=compile_kwargs.get("mode") through make_generator/_execute_concurrently, mirroring mcmc.py.

2. [Med] Ctrl-C discards all completed paths — multipath.py:186
results = list(generator) replaced the old except (KeyboardInterrupt, StopIteration) … finally: block that aggregated whatever paths had finished. Now an interrupt mid-run propagates out of multipath_pathfinder and throws away every completed path. Fix: restore partial aggregation around the generator drain.

3. [Med] inv_hessian_diag recomputed and discarded every L-BFGS iteration — bfgs_sample.py:190
The old code had a return_inv_hessian_diag flag and a separate cheap ELBO-only sample fn; this PR merged into one function that always emits inv_hessian_diag. The streaming callback (lbfgs.py:302) discards it (_, logQ, logP, _). In the sparse/large-N branch that's an extra Q @ Lchol (O(N·J²)) per accepted step × up to maxiter × paths × retries — wasted exactly where per-iteration cost matters. Dense branch is ~free. Fix: split the ELBO-path output or gate the 4th output.

4. [Med-Low] Per-path ordering is nondeterministic in parallel — multipath.py:324
_execute_concurrently yields in worker-completion order and discards the chain index it holds; from_path_results concatenates per-path arrays in arrival order. With a fixed seed the posterior is still reproducible (psis pools all draws; seeds are index-fixed), but per-path diagnostics (elbo_argmax, lbfgs/niter, inv_hessian_diag, paths_logP) permute run-to-run, and with importance_sampling=None the chain labels permute. Fix: yield/sort by the held chain index.

5. [Low-Med] PathInvalidLogP is raised inside the retry loop but never retried — single_path.py:191
The inner handler catches only (LBFGSInitFailed, SingleStepPathException), so PathInvalidLogP (and post-loop PathInvalidLogQ) fall through to the outer except PathException with no re-jitter — 2 of 4 jitter-sensitive failure modes silently bypass max_init_retries. Either retry them or move the raise out of the loop to make intent clear.

Pre-existing (carried into touched files — flag to author, not introduced here)

6. [High if intended-raw] In-place mutation of logP/logQimportance_sampling.py:103-104
logP = logP.ravel(); logP -= log_I mutates the caller's contiguous MultiPathfinderResult.logP in place (ravel returns a view), and with_importance_sampling keeps the same object, so every logP/logQ diagnostic in the pathfinder idata group is shifted by -log(num_paths) (reproduced: exactly log(2) for 2 paths). Verbatim in base. Fix: operate on a copy (logP = logP.ravel() - log_I).

7. [Med] num_successful_paths reports the draw count after importance sampling — idata.py:130
result.samples.shape[0] is num_draws once psis/psir/identity collapse samples to (num_draws, N); reproduced as 1000 vs a path coord of 4. The module already has _determine_num_paths (uses lbfgs_niter) for exactly this reason. Pre-existing.

8. [Med] Uncaught ValueError when num_draws exceeds available population — importance_sampling.py:147
For replace=False methods, if successful draws < num_draws, numpy raises Cannot take a larger sample than population; the fallback only matches "Fewer non-zero entries…". Reproduced (pop=100, num_draws=500). Pre-existing.

9. [Low] Wrong nonzero count in the fallback warning — importance_sampling.py:153
np.where(np.nonzero(p)[0], 1, 0).sum() counts indices, not entries (off-by-one when index 0 is nonzero). Use np.count_nonzero(p). Pre-existing, warning-string only.

Cleanup (minor, optional)

  • Dead state introduced: current_elbo (assigned, never read, in the hot callback), with_warnings (defined, never called), best_state["win_idx"] (never consumed), MultiPathfinderResult.num_paths/num_draws fields (always Noneinclude_if no-ops at idata.py:107), and four redundant np.asarray calls on already-ndarray outputs (single_path.py:210-213).
  • LBFGSConfig destructured then passed positionally (single_path.py:97-102,154) — the config object's benefit is undone; let LBFGS accept the config.
  • Reuse: get_param_coords (idata.py:32) duplicates make_unpacked_variable_names from the sibling laplace_approx/idata.py (and produces worse, dimless labels); _default_cores (multipath.py:208) re-derives the min(cores, paths) clamp that setup_cores_blas_cores already returns.
  • Fragility: from_path_results builds the result via positional cls(*[...]) keyed on NUMERIC_ATTRIBUTES order matching the first six dataclass fields — enforced nowhere; a field reorder silently misassigns arrays. Prefer keyword construction.

@ricardoV94 ricardoV94 changed the title Pathfinder refactor Speedup Pathfinder compilation and wall-time May 31, 2026
@ricardoV94
Copy link
Copy Markdown
Member

The failure is fixed by pymc-devs/pytensor#2191

@jessegrabowski
Copy link
Copy Markdown
Member Author

cut a release?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants