Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/gallery/gallery.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ Welcome to the PyMC-Marketing example gallery! This gallery provides visual navi
:link: ../notebooks/mmm/mmm_cost_per_unit.html
:::

:::{grid-item-card} Comparison of Additive, Log, and Log-Log Models
:img-top: ../gallery/images/mmm_multiplicative.png
:link: ../notebooks/mmm/mmm_multiplicative.html
:::

:::{grid-item-card} Parameter Recovery
:img-top: ../gallery/images/mmm_data_generator.png
:link: ../notebooks/mmm/mmm_data_generator.html
Expand Down
3,744 changes: 3,744 additions & 0 deletions docs/source/notebooks/mmm/mmm_multiplicative.ipynb

Large diffs are not rendered by default.

159 changes: 146 additions & 13 deletions pymc_marketing/data/idata/mmm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Literal

import arviz as az
Expand Down Expand Up @@ -395,6 +396,12 @@ def get_avg_cost_per_unit(self) -> xr.DataArray:

# ==================== Contribution Access ====================

@property
def _link(self) -> str:
    """Return the link function name recorded on ``idata``.

    The value is read from the ``"link"`` entry of ``self.idata.attrs``.
    ``"identity"`` is returned when ``idata`` exposes no ``attrs``
    attribute or when ``attrs`` has no ``"link"`` entry.
    """
    idata_attrs = getattr(self.idata, "attrs", {})
    if "link" in idata_attrs:
        return idata_attrs["link"]
    return "identity"

def get_channel_contributions(self, original_scale: bool = True) -> xr.DataArray:
"""Get channel contribution posterior samples.

Expand All @@ -419,7 +426,6 @@ def get_channel_contributions(self, original_scale: bool = True) -> xr.DataArray
include_controls=False,
include_seasonality=False,
)
# Extract from Dataset - xarray preserves coordinate structure
return contributions["channels"]

def get_contributions(
Expand All @@ -431,6 +437,11 @@ def get_contributions(
) -> xr.Dataset:
"""Get all contribution variables in a single dataset.

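For identity-link models, original-scale contributions are obtained by
multiplying the model-scale (scaled-target) values by ``target_scale``.
For log-link models with ``original_scale=True``, a hybrid decomposition
is used: the counterfactual total media lift is split across channels in
proportion to their share of the log-space media contribution.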

Parameters
----------
original_scale : bool, default True
Expand All @@ -452,27 +463,48 @@ def get_contributions(
ValueError
If original_scale=True and target_scale is not found in constant_data
"""
contributions = {}
if self._link == "log" and original_scale:
if not include_controls or not include_seasonality:
warnings.warn(
"For log-link models with original_scale=True, "
"controls and seasonality are embedded in baseline and "
"cannot be separately toggled. "
"Arguments include_controls/include_seasonality are ignored.",
UserWarning,
stacklevel=2,
)
return self._get_contributions_log_link(
include_baseline=include_baseline,
)
return self._get_contributions_identity(
original_scale=original_scale,
include_baseline=include_baseline,
include_controls=include_controls,
include_seasonality=include_seasonality,
)

def _get_contributions_identity(
self,
original_scale: bool = True,
include_baseline: bool = True,
include_controls: bool = True,
include_seasonality: bool = True,
) -> xr.Dataset:
"""Additive decomposition for identity-link models."""
contributions: dict[str, xr.DataArray] = {}

# Channel contributions
# Channels variables - use "channels" (plural) as key to avoid xarray
# dimension/key name conflict (a key matching a dimension name gets
# promoted to a coordinate instead of staying as a data variable)
if original_scale:
if "channel_contribution_original_scale" in self.idata.posterior:
contributions["channels"] = (
self.idata.posterior.channel_contribution_original_scale
)
else:
# Compute on-the-fly
channel_contrib = self.idata.posterior.channel_contribution
target_scale = self.get_target_scale()
# xarray automatically handles broadcasting when dimensions match
contributions["channels"] = channel_contrib * target_scale
else:
contributions["channels"] = self.idata.posterior.channel_contribution

# Baseline/intercept
if include_baseline:
for var in ["intercept_contribution", "intercept_baseline"]:
if var in self.idata.posterior:
Expand All @@ -484,9 +516,6 @@ def get_contributions(
contributions["baseline"] = baseline
break

# Control variables - use "controls" (plural) as key to avoid xarray
# dimension/key name conflict (a key matching a dimension name gets
# promoted to a coordinate instead of staying as a data variable)
if include_controls and "control_contribution" in self.idata.posterior:
control = self.idata.posterior.control_contribution
if original_scale:
Expand All @@ -500,7 +529,6 @@ def get_contributions(
else:
contributions["controls"] = control

# Seasonality
if (
include_seasonality
and "yearly_seasonality_contribution" in self.idata.posterior
Expand All @@ -522,6 +550,111 @@ def get_contributions(

return xr.Dataset(contributions)

def _get_contributions_log_link(
    self,
    include_baseline: bool = True,
) -> xr.Dataset:
    r"""Decompose a log-link model into original-scale contributions.

    With a log link the linear predictor is additive in log-space,

    .. math::

        \mu = \text{intercept} + \sum_c \text{channel}_c
              + \text{controls} + \text{seasonality}

    and the original-scale prediction is
    :math:`y = \exp(\mu) \times \text{target\_scale}`, so the components
    act multiplicatively on :math:`y`.

    **Baseline.** Control and seasonality effects are not separated
    here; isolating each in original scale would need its own
    counterfactual. They are folded into ``baseline``, the prediction
    with the summed log-space media contribution removed:

    .. math::

        \text{baseline} = \exp(\mu - \text{media\_total\_log})
                          \times \text{target\_scale}

    **Channels.** The total media lift (``y_hat - baseline``) is split
    across channels in proportion to each channel's share of the summed
    log-space media contribution:

    .. math::

        \text{channel}_c = \text{total\_media\_lift}
            \times \frac{\text{channel\_contrib}_c}
                        {\sum_c \text{channel\_contrib}_c}

    Both the log-link prediction transform and the guarded share
    computation come from :mod:`pymc_marketing.mmm.decomposition`, the
    shared decomposition helpers for MMM counterfactuals.

    Parameters
    ----------
    include_baseline : bool, default True
        Whether the returned dataset also carries the ``baseline``
        variable (the media-free prediction).

    Returns
    -------
    xr.Dataset
        Always contains ``channels``: the per-channel original-scale
        contributions, i.e. the media lift multiplied by the channel
        shares (expected dims ``(chain, draw, date, channel)`` plus any
        custom dims).

        When *include_baseline* is True it also contains ``baseline``:
        the original-scale prediction without media (expected dims
        ``(chain, draw, date)`` plus any custom dims).

    Notes
    -----
    In contrast to :meth:`_get_contributions_identity`, no ``controls``
    or ``seasonality`` variables are produced; those effects live inside
    ``baseline``.

    Whenever the summed log-space media contribution is large enough
    for the shares to be computed (they then sum to one over
    ``channel``), ``channels.sum("channel") + baseline`` reproduces
    ``exp(mu) * target_scale`` draw by draw. When that sum is
    numerically zero the shares are set to zero and the lift itself is
    (numerically) zero as well.
    """
    # Deferred import: avoid circular import (mmm.py -> mmm_wrapper -> mmm pkg).
    from pymc_marketing.mmm.decomposition import (
        original_scale_prediction_from_mu,
        safe_proportional_share,
    )

    posterior = self.idata.posterior
    scale = self.get_target_scale()

    # Log-space ingredients: full linear predictor and the media part of it.
    log_mu = posterior["mu"]
    log_channels = posterior["channel_contribution"]
    log_media = log_channels.sum(dim="channel")

    # Original-scale prediction with and without the media term.
    prediction = original_scale_prediction_from_mu(log_mu, scale)
    media_free_prediction = original_scale_prediction_from_mu(
        log_mu - log_media, scale
    )

    # Share of each channel in the summed log-space media contribution.
    shares = safe_proportional_share(
        numerator=log_channels,
        denominator=log_media,
    )

    decomposition: dict[str, xr.DataArray] = {
        "channels": (prediction - media_free_prediction) * shares,
    }
    if include_baseline:
        decomposition["baseline"] = media_free_prediction

    return xr.Dataset(decomposition)

def get_elementwise_roas(self, original_scale: bool = True) -> xr.DataArray:
"""Compute element-wise ROAS (Return on Ad Spend) for each channel.

Expand Down
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
HillSaturationSigmoid,
InverseScaledLogisticSaturation,
LogisticSaturation,
LogSaturation,
MichaelisMentenSaturation,
NoSaturation,
RootSaturation,
Expand Down Expand Up @@ -96,6 +97,7 @@
"HillSaturationSigmoid",
"InverseScaledLogisticSaturation",
"LinearTrend",
"LogSaturation",
"LogisticSaturation",
"MMMBuilder",
"MediaConfig",
Expand Down
37 changes: 37 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def function(self, x, b):
from pymc_extras.deserialize import deserialize
from pymc_extras.prior import Prior
from pytensor.xtensor import as_xtensor
from pytensor.xtensor import math as ptxm

from pymc_marketing.mmm.components.base import (
Transformation,
Expand Down Expand Up @@ -477,6 +478,41 @@ def function(self, x, alpha, beta, *, dim: str | None = None):
}


@serialization.register
class LogSaturation(SaturationTransformation):
    r"""Logarithmic saturation, intended for log-log models.

    The transformation is :math:`\beta \, \log(1 + x)`: spend is passed
    through a concave logarithmic curve, so returns diminish as spend
    grows. Used together with ``link="log"`` in the MMM, the coefficient
    :math:`\beta` can be read in an elasticity-like way.

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import LogSaturation

        rng = np.random.default_rng(0)

        saturation = LogSaturation()
        prior = saturation.sample_prior(random_seed=rng)
        curve = saturation.sample_curve(prior)
        saturation.plot_curve(curve, random_seed=rng)
        plt.show()

    """

    def function(self, x, beta, *, dim: str | None = None):
        """Evaluate ``beta * log1p(x)``.

        Both ``x`` and ``beta`` are converted with ``as_xtensor`` before
        the computation. ``dim`` is accepted but not used by this
        transformation.
        """
        spend = as_xtensor(x)
        coefficient = as_xtensor(beta)
        log_spend = ptxm.log1p(spend)
        return coefficient * log_spend

    default_priors = {"beta": Prior("HalfNormal", sigma=1)}
Comment thread
juanitorduz marked this conversation as resolved.


@serialization.register
class NoSaturation(SaturationTransformation):
"""Wrapper around linear saturation function.
Expand Down Expand Up @@ -517,6 +553,7 @@ def function(self, x, beta, *, dim: str | None = None):
"hill": HillSaturation,
"hill_sigmoid": HillSaturationSigmoid,
"root": RootSaturation,
"log_saturation": LogSaturation,
"no_saturation": NoSaturation,
}

Expand Down
56 changes: 56 additions & 0 deletions pymc_marketing/mmm/decomposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2022 - 2026 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared decomposition math helpers for MMM counterfactuals."""

from __future__ import annotations

import numpy as np
import xarray as xr


def original_scale_prediction_from_mu(
    mu: xr.DataArray,
    target_scale: xr.DataArray,
) -> xr.DataArray:
    """Map log-link linear-predictor samples onto the original scale.

    Parameters
    ----------
    mu : xr.DataArray
        Linear predictor samples of a log-link model.
    target_scale : xr.DataArray
        Factor applied after exponentiating ``mu``.

    Returns
    -------
    xr.DataArray
        ``exp(mu) * target_scale``.
    """
    exponentiated = np.exp(mu)
    return exponentiated * target_scale


def log_counterfactual_remove_component(
    mu_total: xr.DataArray,
    component: xr.DataArray,
    target_scale: xr.DataArray,
) -> xr.DataArray:
    """Per-draw log-link counterfactual contribution for a component.

    The contribution is the drop in the original-scale prediction when
    ``component`` is subtracted from the log-space linear predictor.

    Parameters
    ----------
    mu_total : xr.DataArray
        Full linear predictor samples (log-space).
    component : xr.DataArray
        Log-space contribution of the component being removed.
    target_scale : xr.DataArray
        Factor mapping ``exp(mu)`` to the original target scale.

    Returns
    -------
    xr.DataArray
        ``exp(mu_total) * target_scale
        - exp(mu_total - component) * target_scale``.
    """
    y_hat = original_scale_prediction_from_mu(mu_total, target_scale)
    # Use the shared transform for the counterfactual too, so both terms
    # are guaranteed to be mapped to original scale in the same way.
    y_hat_without_component = original_scale_prediction_from_mu(
        mu_total - component, target_scale
    )
    return y_hat - y_hat_without_component


def identity_counterfactual_component(
    component: xr.DataArray,
    target_scale: xr.DataArray,
) -> xr.DataArray:
    """Per-draw identity-link contribution for a component.

    Parameters
    ----------
    component : xr.DataArray
        Contribution samples of the component.
    target_scale : xr.DataArray
        Factor the component is multiplied by.

    Returns
    -------
    xr.DataArray
        ``component * target_scale``.
    """
    rescaled = component * target_scale
    return rescaled


def safe_proportional_share(
    numerator: xr.DataArray,
    denominator: xr.DataArray,
) -> xr.DataArray:
    """Return ``numerator / denominator`` with non-finite results set to zero.

    Denominator entries whose absolute value does not exceed ``1e-12``
    are masked before dividing, so the corresponding shares come out as
    ``0.0`` instead of ``inf``/``NaN``. Any other non-finite quotient is
    replaced by ``0.0`` as well.

    Parameters
    ----------
    numerator : xr.DataArray
        Values whose share is wanted.
    denominator : xr.DataArray
        Total the numerator is divided by.

    Returns
    -------
    xr.DataArray
        The quotient, with every non-finite entry replaced by ``0.0``.
    """
    usable = np.abs(denominator) > 1e-12
    guarded_denominator = xr.where(usable, denominator, np.nan)
    ratio = numerator / guarded_denominator
    # A single finite-mask covers both the NaN produced by the guard and
    # any inf/NaN coming from the inputs.
    return ratio.where(np.isfinite(ratio), 0.0)
Loading