refactor(budget-optimizer): BudgetOptimizer accepts pm.Model + DataTree directly #2654
williambdean wants to merge 1 commit into
Replaces the wrapper class requirement with direct pm.Model + DataTree
inputs. Adds configurable variable/dimension names, standalone merge
utilities, and a model_validator for backward compatibility.
- BudgetOptimizer: model + idata fields replace mmm_model wrapper
- New fields: adstock_periods, channel_scales, channel_data_var,
channel_contribution_var, date_dim, mu_effects (Sequence)
- model_validator('before') unpacks legacy wrappers transparently
- New: merge_inference_data() and merge_models_and_idata() functions
- New: MMM.create_optimization_model(start_date, end_date)
- New: MMM.budget_optimizer(start_date, end_date, **kwargs)
- Protocol: OptimizerCompatibleModelWrapper -> OptimizerCompatibleModel
with optimization_model() method
- Deprecated: BuildMergedModel, BudgetOptimizerWrapper, CustomModelWrapper
- Type improvements: Sequence instead of list, DataTree typing
- Test coverage: 105 passed, 2 new API tests
Closes #2425
Codecov Report❌ Patch coverage is
Additional details and impacted files

Coverage Diff

| | v1.0.0 | #2654 | +/- |
|---|---|---|---|
| Coverage | 94.06% | 93.52% | -0.55% |
| Files | 97 | 97 | |
| Lines | 14629 | 14826 | +197 |
| Hits | 13761 | 13866 | +105 |
| Misses | 868 | 960 | +92 |
Please review @daimon-pymclabs
daimon-pymclabs
left a comment
Review
Overall this is a solid, well-structured refactor and the direction is right: decoupling BudgetOptimizer from the MMM wrapper so it accepts a plain pm.Model + DataTree, making the variable/dimension names configurable (channel_data_var, date_dim, channel_contribution_var), and keeping every legacy path alive through a transparent model_validator + deprecation warnings. The configurable names and the merge_models_and_idata convenience are genuine flexibility/DX wins — a non-MMM PyMC model can now be optimized without subclassing, and the multi-model merge workflow is a single call. Docstrings and examples are excellent.
A few things to address before merge, roughly in order of importance.
1. frozen_deterministics is silently dropped — regression for HSGP / time-varying models
BudgetOptimizer.extract_response_distribution previously passed frozen_deterministics=getattr(self.mmm_model, "frozen_deterministics", None). The refactor drops that argument entirely:

    return extract_response_distribution(
        pymc_model=self._pymc_model,
        idata=_extract_dataset(self.idata, "posterior"),
        response_variable=response_variable,
    )

But extract_response_distribution still accepts frozen_deterministics, and its docstring is explicit: "Some models (e.g. those containing HSGP) need this to obtain a valid conditional posterior graph." MMM.frozen_deterministics also still exists and is used elsewhere in mmm.py. The new mmm.budget_optimizer() entry point doesn't thread it through either. So for any MMM with frozen deterministics (HSGP, time-varying intercept/media), the optimizer will now build the response graph without freezing them, risking either a RuntimeError("RVs found in the extracted graph") or silently different optimization results. Suggest adding a frozen_deterministics: list[str] | None field on BudgetOptimizer and having mmm.budget_optimizer() pass self.frozen_deterministics.
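To make the suggested fix concrete, here is a minimal sketch of the threading, written against stand-ins rather than the real classes. `SketchOptimizer`, `SketchMMM`, `extract_response_distribution_stub`, and the `"intercept_latent"` name are all hypothetical; the only things taken from the PR are the `frozen_deterministics` name and the idea that both the optimizer field and the `budget_optimizer()` entry point must forward it.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


def extract_response_distribution_stub(
    pymc_model, idata, response_variable, frozen_deterministics=None
):
    """Hypothetical stand-in: records the arguments instead of building a graph."""
    return {
        "response_variable": response_variable,
        "frozen_deterministics": frozen_deterministics,
    }


@dataclass
class SketchOptimizer:
    """Hypothetical stand-in for BudgetOptimizer with the proposed field."""

    model: Any
    idata: Any
    frozen_deterministics: list[str] | None = None  # proposed new field

    def extract_response_distribution(self, response_variable: str) -> dict:
        # Forward the field instead of dropping it.
        return extract_response_distribution_stub(
            pymc_model=self.model,
            idata=self.idata,
            response_variable=response_variable,
            frozen_deterministics=self.frozen_deterministics,
        )


@dataclass
class SketchMMM:
    """Hypothetical stand-in for MMM; only carries frozen_deterministics."""

    frozen_deterministics: list[str] | None = None

    def budget_optimizer(self) -> SketchOptimizer:
        # The entry point passes its own frozen_deterministics through.
        return SketchOptimizer(
            model=object(),
            idata=object(),
            frozen_deterministics=self.frozen_deterministics,
        )


optimizer = SketchMMM(frozen_deterministics=["intercept_latent"]).budget_optimizer()
result = optimizer.extract_response_distribution("total_contribution")
```

A regression test along these lines (asserting the value reaches the helper for an HSGP-style model) would also guard against the argument being dropped again.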
2. mu_effects budget slots are half-wired — allocate_budget will break once an optimizable effect is present
model_post_init enlarges the optimization vector to size_budgets + size_mu_effect_budgets and slices out _effect_budget_xtensors. Two gaps:
- _effect_budget_xtensors is built but never consumed anywhere in the file (no replace_for_optimization call, not referenced in the objective graph); it is currently dead scaffolding.
- allocate_budget sizes x0 and bounds from budgets_size = self.budgets_to_optimize.sum() (i.e. size_budgets only), while the compiled objective expects _budgets_flat of length size_budgets + size_mu_effect_budgets. So self._budgets_flat.type.filter(x0) will raise on a shape mismatch (and bounds would be the wrong length) the moment an effect with replace_for_optimization is passed.
This is dormant today because stock mu_effects lack replace_for_optimization, so _optimizable_mu_effects is empty — but it's untested, fragile, and will bite #2621. Either fully wire it (consume the xtensors + extend x0/bounds/per-effect bounds) or gate it behind a guard that raises a clear "not yet supported" error until #2621 lands.
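The "gate it behind a guard" option could look roughly like the sketch below. It assumes an effect counts as optimizable when it exposes a `replace_for_optimization` attribute, which is inferred from the description above rather than read from the diff; `check_mu_effects_supported`, `StockEffect`, and `SpendDrivenEffect` are hypothetical names.

```python
def optimizable_effects(mu_effects):
    """Return the effects that expose replace_for_optimization (assumed criterion)."""
    return [e for e in mu_effects if hasattr(e, "replace_for_optimization")]


def check_mu_effects_supported(mu_effects):
    """Fail early and clearly instead of hitting a shape mismatch in allocate_budget."""
    found = optimizable_effects(mu_effects)
    if found:
        names = ", ".join(type(e).__name__ for e in found)
        raise NotImplementedError(
            f"Optimizable mu_effects are not yet supported: {names}"
        )


class StockEffect:
    """Hypothetical effect without replace_for_optimization."""


class SpendDrivenEffect:
    """Hypothetical effect that would add budget slots."""

    def replace_for_optimization(self):
        return None


# Stock effects pass through untouched.
check_mu_effects_supported([StockEffect()])

# An optimizable effect is rejected with an explicit message.
try:
    check_mu_effects_supported([StockEffect(), SpendDrivenEffect()])
    guard_message = None
except NotImplementedError as err:
    guard_message = str(err)
```

Calling such a guard from model_post_init would keep the enlarged optimization vector out of the compiled objective until x0, bounds, and the xtensors are wired end to end.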
3. New merge utilities have no unit tests
Codecov flags 59% patch coverage / ~95 missing lines in budget_optimizer.py, and merge_inference_data / merge_models_and_idata have non-trivial logic (variable/dim prefixing, shared_dims derivation, draw thinning, xr.merge alignment) but no direct tests in the diff. The xr.merge step in particular aligns on shared chain/draw indexes across independently-sampled models — worth a test asserting the merged shapes/coords are what you expect (and that thinning + prefixing round-trip correctly), since silent outer-join NaNs there would be hard to catch downstream.
4. idata argument to extract_response_distribution changed type
It now receives _extract_dataset(self.idata, "posterior") (a bare xr.Dataset) instead of the full DataTree/InferenceData it got before. extract_response_distribution runs az.extract(idata). Please confirm az.extract behaves identically on a bare posterior Dataset vs the previous container across the supported arviz range — az.extract historically expects an InferenceData/DataTree with a posterior group and a group= arg, so passing the inner Dataset may rely on convert-to-dataset fallthrough.
5. Minor
- merge_models_and_idata docstring/behavior mismatch: the docstring (and Raises) say it requires ≥2 models, but the body only checks len(models) == len(idatas). Either enforce len(models) >= 2 or correct the docstring.
- "optimization_model" in type(model).__dict__ breaks inheritance (used in the validator and in BuildMergedModel.optimization_model): a subclass that inherits optimization_model without redefining it won't be found in type(model).__dict__ and will fall back to the deprecated _set_predictors_for_optimization. Consider walking the MRO or using a sentinel that distinguishes "real impl" from the Protocol stub.
- Redundant work in mmm.budget_optimizer: create_zero_dataset(...) runs once inside create_optimization_model() and again directly just to compute num_dates. You can derive num_periods from the built model's date coord instead of rebuilding the zero dataset.
- DRY: the idata → DataTree.from_dict({g: getattr(idata, g) ...}) conversion appears three times in _handle_legacy_model_arg; pull it into a small helper (it already overlaps with _extract_dataset).
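For the inheritance point, one possible shape of the MRO walk is sketched below. The `OptimizerCompatibleModel` class here is a local stand-in defined only for the example, not the one in the PR, and `has_real_optimization_model` plus the three model classes are hypothetical names.

```python
from typing import Protocol


class OptimizerCompatibleModel(Protocol):
    """Local stand-in for the protocol; the method body is only a stub."""

    def optimization_model(self): ...


_STUB = vars(OptimizerCompatibleModel)["optimization_model"]


def has_real_optimization_model(model) -> bool:
    """Walk the MRO and report whether the first definition found is not the stub."""
    for klass in type(model).__mro__:
        if "optimization_model" in vars(klass):
            return vars(klass)["optimization_model"] is not _STUB
    return False


class StubOnly(OptimizerCompatibleModel):
    """Subclasses the protocol explicitly but never implements the method."""


class Impl:
    def optimization_model(self):
        return "built"


class Child(Impl):
    """Inherits optimization_model without redefining it."""


# The type(model).__dict__ check misses the inherited method...
missed_by_dict_check = "optimization_model" not in type(Child()).__dict__
# ...while the MRO walk finds it and still rejects the bare stub.
found_by_mro = has_real_optimization_model(Child())
stub_rejected = not has_real_optimization_model(StubOnly())
```

The MRO walk keeps the "is this a real implementation?" question in one helper, so the validator and BuildMergedModel.optimization_model could share it.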
Happy to dig into any of these further — #1 and #2 are the only two I'd consider blocking.
Closes #2425.
BudgetOptimizer no longer requires a wrapper object. Pass a pre-built pm.Model and its DataTree directly. All existing wrapper paths continue to work via a transparent backward-compat validator.
Configurable variable and dimension names (channel_data_var, date_dim, etc.) replace the old hardcoded strings. A mu_effects field is forward-compatible with future OptimizableMuEffect PRs (#2621).
New standalone utilities for multi-model merging: merge_inference_data() and merge_models_and_idata().
BuildMergedModel, BudgetOptimizerWrapper, and CustomModelWrapper now emit DeprecationWarning on instantiation.
📚 Documentation preview 📚: https://pymc-marketing--2654.org.readthedocs.build/en/2654/