Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 117 additions & 9 deletions meridian/analysis/review/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,31 @@ def run(self) -> results.PriorPosteriorShiftCheckResult:
# ==============================================================================
# Check: Implausible ROI
# ==============================================================================
def _calculate_spend_share(model_context: context.ModelContext) -> np.ndarray:
"""Calculates the spend share for all paid channels.

Args:
model_context: The ModelContext of the Meridian model.

Returns:
A 1D NumPy array of shape `(n_channels,)` containing the spend share for
each paid channel, or all zeros if the total spend sum is zero.
"""
initial_spend = model_context.input_data.get_total_spend()
# TODO: Verify if we really support 1D spend
spend = (
np.sum(initial_spend, axis=(0, 1))
if initial_spend.ndim == 3
else initial_spend
)

total_spend_sum = np.sum(spend)
if total_spend_sum > 0:
return spend / total_spend_sum
else:
return np.zeros_like(spend)


class ImplausibleROICheck(
BaseCheck[configs.ImplausibleROIConfig, results.ImplausibleROICheckResult]
):
Expand All @@ -829,14 +854,7 @@ class ImplausibleROICheck(
@override
def run(self) -> results.ImplausibleROICheckResult:
# 1. Get spend and calculate spend share
spend = self._model_context.input_data.get_total_spend()
if spend.ndim == 3:
spend = np.sum(spend, axis=(0, 1))

# Per-channel spend for paid channels is validated to be positive during
# ModelContext initialization, so total_spend_sum > 0.
total_spend_sum = np.sum(spend)
spend_share = spend / total_spend_sum
spend_share = _calculate_spend_share(self._model_context)

# 2. Get posterior ROI and channels
posterior_rois = []
Expand All @@ -862,7 +880,12 @@ def run(self) -> results.ImplausibleROICheckResult:
low_roi_channels = []

spend_weighted_roi_all = roi_means * spend_share
reciprocal_spend_weighted_roi_all = roi_means / spend_share
reciprocal_spend_weighted_roi_all = np.divide(
roi_means,
spend_share,
out=np.full_like(roi_means, np.nan),
where=(spend_share != 0),
)

for i, channel in enumerate(channels):
mean = roi_means[i]
Expand Down Expand Up @@ -918,3 +941,88 @@ def run(self) -> results.ImplausibleROICheckResult:
low_roi_channels=low_roi_channels,
aggregate_details=aggregate_details,
)


# ==============================================================================
# Check: High Variance ROI
# ==============================================================================
class HighVarianceCheck(
BaseCheck[configs.HighVarianceConfig, results.HighVarianceCheckResult]
):
"""A check for paid channels with high variance in posterior ROI."""

@override
def run(self) -> results.HighVarianceCheckResult:
# 1. Get spend and calculate spend share
spend_share = _calculate_spend_share(self._model_context)

# 2. Get posterior ROI and channels
posterior_rois = []
channels = []

if constants.MEDIA_CHANNEL in self._inference_data.posterior.coords:
posterior_rois.append(self._inference_data.posterior.roi_m.values)
channels.extend(
self._inference_data.posterior.media_channel.values.tolist()
)

if constants.RF_CHANNEL in self._inference_data.posterior.coords:
posterior_rois.append(self._inference_data.posterior.roi_rf.values)
channels.extend(self._inference_data.posterior.rf_channel.values.tolist())

if not posterior_rois:
raise ValueError("No posterior ROI data found in inference_data.")

posterior_roi_concat = np.concatenate(posterior_rois, axis=-1)
roi_medians = np.median(posterior_roi_concat, axis=(0, 1))

# 3. Compute credible intervals using az.hdi
hdi = az.hdi(posterior_roi_concat, hdi_prob=self._config.hdi_prob)
hdi_lower, hdi_upper = hdi.T

rel_width_post = np.divide(
hdi_upper - hdi_lower,
np.abs(roi_medians),
out=np.zeros_like(roi_medians, dtype=float),
where=(roi_medians != 0),
)

# 4. Compute high variance check
relative_width_ratio = np.divide(
rel_width_post,
self._config.prior_relative_hdi_width,
out=np.zeros_like(rel_width_post, dtype=float),
where=(self._config.prior_relative_hdi_width != 0),
)
spend_weighted_ratio = relative_width_ratio * spend_share

channel_results = []
high_variance_channels = []

for channel, share, ratio, weighted_ratio in zip(
channels, spend_share, relative_width_ratio, spend_weighted_ratio
):
if weighted_ratio > self._config.high_variance_threshold:
case = results.HighVarianceChannelCases.HIGH_VARIANCE
high_variance_channels.append(channel)
else:
case = results.HighVarianceChannelCases.ROI_PASS

channel_results.append(
results.HighVarianceChannelResult(
case=case,
channel_name=channel,
spend_share=share,
relative_width_ratio=ratio,
)
)

return results.HighVarianceCheckResult(
case=(
results.HighVarianceAggregateCases.REVIEW
if high_variance_channels
else results.HighVarianceAggregateCases.PASS
),
channel_results=channel_results,
high_variance_channels=high_variance_channels,
)
Loading
Loading