Skip to content

Commit 6bd60b8

Browse files
authored
Support fixed/user-defined scaling factors for stable model refreshes (#2479)
* Support fixed/user-defined scaling factors for stable model refreshes (#2478) Add method="fixed" to VariableScaling so users can pin scaling constants that remain stable across production model refreshes, eliminating scale-induced prior drift when data distributions change. Made-with: Cursor * Address review feedback: serialization, dict-valued scaling, key validation - Fix legacy MMM serialization round-trip: use model_dump() in create_idata_attrs and pass value through in _deserialize_scaling so fixed scaling survives save/load. - Support dict-valued fixed channel scaling in legacy MMM (per-channel constants keyed by column name); reject dict-valued target with a clear error since the 1-D MMM has a single target. - Validate dict keys against coordinate labels in multidimensional MMM: report missing and unexpected keys explicitly instead of late KeyError or silent ignore. - Add regression tests for all three fixes. Made-with: Cursor * Improve patch coverage: add tests, remove dead code - Remove dead `scaling is None` branch in create_idata_attrs (scaling is always set in __init__) - Replace isinstance guard with cast() for mypy narrowing - Add test for dict-valued scaling with wrong remaining dims count - Add tests for VariableScaling dims validation (date, duplicates) Made-with: Cursor * Restore legacy MMM scaling type check This preserves the legacy TypeError contract for non-fixed scaling so the shard 5 edge-case test passes after the fixed-scaling refactor. Made-with: Cursor * Refactor VariableScaling into a two-class hierarchy Split the single VariableScaling class into an abstract base with two concrete subclasses (DataDerivedScaling and FixedScaling) to eliminate the awkward optional value field and its cross-field validation, per code review feedback. - VariableScaling(ABC) keeps dims + _validate_dims - DataDerivedScaling has method: Literal["max", "mean"] - FixedScaling has value: float | dict[str, float] (non-optional) - Scaling.to_dict/from_dict use the __type__ registry for polymorphic serialization; legacy format (no __type__) still loads correctly - Updated legacy MMM and multidimensional MMM to use isinstance checks and the new concrete classes - All 844 tests pass across test_scaling, test_mmm, test_multidimensional Made-with: Cursor * Extract _build_fixed_scale to eliminate multiple returns Address review comment by refactoring _compute_scale_for_variable into two methods: the main method now has a single return, dispatching to a dedicated _build_fixed_scale helper for FixedScaling logic. Made-with: Cursor * Validate dict-valued FixedScaling keys at init, not build_model Move the missing/extra key checks for dict-valued channel FixedScaling from build_model() to MMM.__init__() in both legacy and multidimensional implementations, so users get immediate feedback on key mismatches. Made-with: Cursor * feat(mmm): multidimensional fixed scaling with DataArray and init fix - Validate dict channel keys at init only when remaining dim is channel - Support FixedScaling.value as xarray.DataArray (broadcast to reduced grid) - JSON-safe DataArray serialization; from_long_dataframe helper - Clearer dict error when multiple remaining dims; target dict keys at build - Tests: broadcast, full grids, 3D panel, misaligned coords, idata round-trip - Doc examples for panel fixed scaling Made-with: Cursor * fix(test): add fastprogress to test extras for BlackJAX mlflow autolog BlackJAX sampling uses jax.io_callback for progress bars, which requires fastprogress. CI installs only .[test]; without it, test_autolog_pymc_model fails with CpuCallback errors on Linux runners. Made-with: Cursor * revert: drop fastprogress from test extras (unrelated to fixed scaling PR) Restore prior test dependency set; BlackJAX/mlflow env gap can be addressed separately on main. Made-with: Cursor
1 parent 7dfb966 commit 6bd60b8

7 files changed

Lines changed: 1546 additions & 125 deletions

File tree

pymc_marketing/mmm/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
preprocessing_method_X,
6464
preprocessing_method_y,
6565
)
66+
from pymc_marketing.mmm.scaling import (
67+
DataDerivedScaling,
68+
FixedScaling,
69+
Scaling,
70+
VariableScaling,
71+
)
6672
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
6773
from pymc_marketing.mmm.time_slice_cross_validation import (
6874
TimeSliceCrossValidationResult,
@@ -78,8 +84,10 @@
7884
"BaseValidateMMM",
7985
"BinomialAdstock",
8086
"CovFunc",
87+
"DataDerivedScaling",
8188
"DelayedAdstock",
8289
"FancyLinearRegression",
90+
"FixedScaling",
8391
"GeometricAdstock",
8492
"HSGPPeriodic",
8593
"HillSaturation",
@@ -99,12 +107,14 @@
99107
"PeriodicCovFunc",
100108
"RootSaturation",
101109
"SaturationTransformation",
110+
"Scaling",
102111
"SensitivityAnalysis",
103112
"SoftPlusHSGP",
104113
"TanhSaturation",
105114
"TanhSaturationBaselined",
106115
"TimeSliceCrossValidationResult",
107116
"TimeSliceCrossValidator",
117+
"VariableScaling",
108118
"WeeklyFourier",
109119
"WeibullCDFAdstock",
110120
"WeibullPDFAdstock",

pymc_marketing/mmm/mmm.py

Lines changed: 88 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import warnings
1919
from collections.abc import Sequence
20-
from typing import Annotated, Any, Literal
20+
from typing import Annotated, Any, Literal, cast
2121

2222
import arviz as az
2323
import matplotlib.pyplot as plt
@@ -54,7 +54,13 @@
5454
scale_lift_measurements,
5555
)
5656
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
57-
from pymc_marketing.mmm.scaling import Scaling, VariableScaling
57+
from pymc_marketing.mmm.scaling import (
58+
DataDerivedScaling,
59+
FixedScaling,
60+
Scaling,
61+
deserialize_variable_scaling,
62+
validate_fixed_scaling_keys,
63+
)
5864
from pymc_marketing.mmm.tvp import create_time_varying_gp_multiplier, infer_time_index
5965
from pymc_marketing.mmm.utility import UtilityFunctionType, average_response
6066
from pymc_marketing.mmm.utils import (
@@ -205,22 +211,32 @@ def __init__(
205211
self.validate_data = validate_data
206212
self.adstock_first = adstock_first
207213

208-
# Initialize scaling configuration similar to multidimensional MMM
209214
if isinstance(scaling, dict):
210215
scaling = scaling.copy()
211216

212217
if "channel" not in scaling:
213-
scaling["channel"] = VariableScaling(method="max", dims=())
218+
scaling["channel"] = DataDerivedScaling(method="max", dims=())
214219
if "target" not in scaling:
215-
scaling["target"] = VariableScaling(method="max", dims=())
220+
scaling["target"] = DataDerivedScaling(method="max", dims=())
216221

217222
scaling = Scaling(**scaling)
218223

219224
self.scaling: Scaling = scaling or Scaling(
220-
target=VariableScaling(method="max", dims=()),
221-
channel=VariableScaling(method="max", dims=()),
225+
target=DataDerivedScaling(method="max", dims=()),
226+
channel=DataDerivedScaling(method="max", dims=()),
222227
)
223228

229+
validate_fixed_scaling_keys(self.scaling.channel, channel_columns, "channel")
230+
231+
if isinstance(self.scaling.target, FixedScaling) and isinstance(
232+
self.scaling.target.value, dict
233+
):
234+
raise ValueError(
235+
"Dict-valued fixed target scaling is not supported in the "
236+
"legacy MMM (single target). Use a scalar value or switch "
237+
"to the multidimensional MMM."
238+
)
239+
224240
model_config = model_config or {}
225241
model_config = parse_model_config(
226242
model_config, # type: ignore
@@ -450,24 +466,59 @@ def _compute_scale_for_data(
450466

451467
def _compute_scales(self) -> None:
452468
"""Compute and save scaling factors for channels and target."""
453-
# Get raw data
454-
X_data = self.preprocessed_data["X"]
455-
if not isinstance(X_data, pd.DataFrame):
456-
raise TypeError("X data must be a DataFrame for scaling computation")
457-
458-
# Use pandas/numpy efficient operations - avoid redundant .values call
459-
channel_data = X_data[self.channel_columns].to_numpy()
460-
target_data = np.atleast_1d(np.asarray(self.preprocessed_data["y"]))
469+
channel_scaling = self.scaling.channel
470+
target_scaling = self.scaling.target
471+
472+
channel_scale: np.ndarray | float
473+
if isinstance(channel_scaling, FixedScaling):
474+
if isinstance(channel_scaling.value, dict):
475+
channel_scale = np.array(
476+
[channel_scaling.value[c] for c in self.channel_columns],
477+
dtype=float,
478+
)
479+
elif isinstance(channel_scaling.value, DataArray):
480+
raise ValueError(
481+
"DataArray-valued FixedScaling is not supported by the "
482+
"legacy MMM. Use a scalar or dict value, or switch to "
483+
"MultidimensionalMMM."
484+
)
485+
else:
486+
n_channels = len(self.channel_columns)
487+
channel_scale = np.full(n_channels, channel_scaling.value)
488+
else:
489+
X_data = self.preprocessed_data["X"]
490+
if not isinstance(X_data, pd.DataFrame):
491+
raise TypeError("X data must be a DataFrame for scaling computation")
461492

462-
# Compute scales based on scaling configuration
463-
self.channel_scale = self._compute_scale_for_data(
464-
channel_data, self.scaling.channel.method, axis=0
465-
)
466-
target_scale = self._compute_scale_for_data(
467-
target_data, self.scaling.target.method, axis=None
468-
)
469-
# Ensure target_scale is a Python float (convert from numpy scalar if needed)
470-
self.target_scale = float(target_scale)
493+
X_data = cast(pd.DataFrame, X_data)
494+
channel_data = X_data[self.channel_columns].to_numpy()
495+
channel_scale = self._compute_scale_for_data(
496+
channel_data, channel_scaling.method, axis=0
497+
)
498+
self.channel_scale = channel_scale
499+
500+
target_scale: float
501+
if isinstance(target_scaling, FixedScaling):
502+
if isinstance(target_scaling.value, DataArray):
503+
raise ValueError(
504+
"DataArray-valued FixedScaling is not supported by the "
505+
"legacy MMM. Use a scalar or dict value, or switch to "
506+
"MultidimensionalMMM."
507+
)
508+
if not isinstance(target_scaling.value, (int, float)):
509+
raise TypeError(
510+
f"Expected scalar FixedScaling value for target, "
511+
f"got {type(target_scaling.value).__name__}."
512+
)
513+
target_scale = float(target_scaling.value)
514+
else:
515+
target_data = np.atleast_1d(np.asarray(self.preprocessed_data["y"]))
516+
target_scale = float(
517+
self._compute_scale_for_data(
518+
target_data, target_scaling.method, axis=None
519+
)
520+
)
521+
self.target_scale = target_scale
471522

472523
def create_idata_attrs(self) -> dict[str, str]:
473524
"""Create attributes for the inference data.
@@ -493,22 +544,9 @@ def create_idata_attrs(self) -> dict[str, str]:
493544
attrs["treatment_nodes"] = json.dumps(self.treatment_nodes)
494545
attrs["outcome_node"] = json.dumps(self.outcome_node)
495546

496-
# Serialize scaling configuration
497-
if hasattr(self, "scaling") and self.scaling is not None:
498-
attrs["scaling"] = json.dumps(
499-
{
500-
"target": {
501-
"method": self.scaling.target.method,
502-
"dims": self.scaling.target.dims,
503-
},
504-
"channel": {
505-
"method": self.scaling.channel.method,
506-
"dims": self.scaling.channel.dims,
507-
},
508-
}
509-
)
510-
else:
511-
attrs["scaling"] = json.dumps(None)
547+
from pymc_marketing.serialization import serialization as _serialization
548+
549+
attrs["scaling"] = json.dumps(_serialization.serialize(self.scaling))
512550

513551
return attrs
514552

@@ -1254,6 +1292,9 @@ def _data_setter(
12541292
def _deserialize_scaling(cls, scaling_dict: dict | None) -> Scaling | None:
12551293
"""Deserialize scaling configuration from JSON.
12561294
1295+
Handles both new format (with ``__type__`` keys from the serialization
1296+
registry) and legacy format (flat dicts with ``method``/``dims``).
1297+
12571298
Parameters
12581299
----------
12591300
scaling_dict : dict | None
@@ -1267,15 +1308,14 @@ def _deserialize_scaling(cls, scaling_dict: dict | None) -> Scaling | None:
12671308
if scaling_dict is None:
12681309
return None
12691310

1311+
if "__type__" in scaling_dict:
1312+
from pymc_marketing.serialization import serialization as _serialization
1313+
1314+
return _serialization.deserialize(scaling_dict)
1315+
12701316
return Scaling(
1271-
target=VariableScaling(
1272-
method=scaling_dict["target"]["method"],
1273-
dims=tuple(scaling_dict["target"]["dims"]),
1274-
),
1275-
channel=VariableScaling(
1276-
method=scaling_dict["channel"]["method"],
1277-
dims=tuple(scaling_dict["channel"]["dims"]),
1278-
),
1317+
target=deserialize_variable_scaling(scaling_dict["target"]),
1318+
channel=deserialize_variable_scaling(scaling_dict["channel"]),
12791319
)
12801320

12811321
@classmethod

0 commit comments

Comments
 (0)