Skip to content

Commit d5baa46

Browse files
santoso-wijayaThe Meridian Authors
authored andcommitted
[refactor] Move beta-x computations to equations.py
PiperOrigin-RevId: 841954697
1 parent fca4cdb commit d5baa46

File tree

4 files changed

+577
-144
lines changed

4 files changed

+577
-144
lines changed

meridian/model/equations.py

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, model_context: context.ModelContext):
4242

4343
def adstock_hill_media(
4444
self,
45+
*,
4546
media: backend.Tensor,
4647
alpha: backend.Tensor,
4748
ec: backend.Tensor,
@@ -102,6 +103,7 @@ def adstock_hill_media(
102103

103104
def adstock_hill_rf(
104105
self,
106+
*,
105107
reach: backend.Tensor,
106108
frequency: backend.Tensor,
107109
alpha: backend.Tensor,
@@ -245,28 +247,172 @@ def compute_non_media_treatments_baseline(
245247

246248
def linear_predictor_counterfactual_difference_media(
247249
self,
250+
*,
248251
media_transformed: backend.Tensor,
249252
alpha_m: backend.Tensor,
250253
ec_m: backend.Tensor,
251254
slope_m: backend.Tensor,
252255
) -> backend.Tensor:
253-
raise NotImplementedError
256+
"""Calculates linear predictor counterfactual difference for non-RF media.
257+
258+
For non-RF media variables (paid or organic), this function calculates the
259+
linear predictor difference between the treatment variable and its
260+
counterfactual. "Linear predictor" refers to the output of the hill/adstock
261+
function, which is multiplied by the geo-level coefficient.
262+
263+
This function does the calculation efficiently by only calculating calling
264+
the hill/adstock function if the prior counterfactual is not all zeros.
265+
266+
Args:
267+
media_transformed: The output of the hill/adstock function for actual
268+
historical media data.
269+
alpha_m: The adstock alpha parameter values.
270+
ec_m: The adstock ec parameter values.
271+
slope_m: The adstock hill slope parameter values.
272+
273+
Returns:
274+
The linear predictor difference between the treatment variable and its
275+
counterfactual.
276+
"""
277+
if self._context.media_tensors.prior_media_scaled_counterfactual is None:
278+
return media_transformed
279+
media_transformed_counterfactual = self.adstock_hill_media(
280+
media=self._context.media_tensors.prior_media_scaled_counterfactual,
281+
alpha=alpha_m,
282+
ec=ec_m,
283+
slope=slope_m,
284+
decay_functions=self._context.adstock_decay_spec.media,
285+
)
286+
# Absolute values is needed because the difference is negative for mROI
287+
# priors and positive for ROI and contribution priors.
288+
return backend.absolute(
289+
media_transformed - media_transformed_counterfactual
290+
)
254291

255292
def linear_predictor_counterfactual_difference_rf(
256293
self,
294+
*,
257295
rf_transformed: backend.Tensor,
258296
alpha_rf: backend.Tensor,
259297
ec_rf: backend.Tensor,
260298
slope_rf: backend.Tensor,
261299
) -> backend.Tensor:
262-
raise NotImplementedError
300+
"""Calculates linear predictor counterfactual difference for RF media.
301+
302+
For RF media variables (paid or organic), this function calculates the
303+
linear predictor difference between the treatment variable and its
304+
counterfactual. "Linear predictor" refers to the output of the hill/adstock
305+
function, which is multiplied by the geo-level coefficient.
306+
307+
This function does the calculation efficiently by only calculating calling
308+
the hill/adstock function if the prior counterfactual is not all zeros.
309+
310+
Args:
311+
rf_transformed: The output of the hill/adstock function for actual
312+
historical media data.
313+
alpha_rf: The adstock alpha parameter values.
314+
ec_rf: The adstock ec parameter values.
315+
slope_rf: The adstock hill slope parameter values.
316+
317+
Returns:
318+
The linear predictor difference between the treatment variable and its
319+
counterfactual.
320+
"""
321+
if self._context.rf_tensors.prior_reach_scaled_counterfactual is None:
322+
return rf_transformed
323+
rf_transformed_counterfactual = self.adstock_hill_rf(
324+
reach=self._context.rf_tensors.prior_reach_scaled_counterfactual,
325+
frequency=self._context.rf_tensors.frequency,
326+
alpha=alpha_rf,
327+
ec=ec_rf,
328+
slope=slope_rf,
329+
decay_functions=self._context.adstock_decay_spec.rf,
330+
)
331+
# Absolute values is needed because the difference is negative for mROI
332+
# priors and positive for ROI and contribution priors.
333+
return backend.absolute(rf_transformed - rf_transformed_counterfactual)
263334

264335
def calculate_beta_x(
265336
self,
337+
*,
266338
is_non_media: bool,
267339
incremental_outcome_x: backend.Tensor,
268340
linear_predictor_counterfactual_difference: backend.Tensor,
269341
eta_x: backend.Tensor,
270342
beta_gx_dev: backend.Tensor,
271343
) -> backend.Tensor:
272-
raise NotImplementedError
344+
"""Calculates coefficient mean parameter for any treatment variable type.
345+
346+
The "beta_x" in the function name refers to the coefficient mean parameter
347+
of any treatment variable. The "x" can represent "m", "rf", "om", or "orf".
348+
This function can also be used to calculate "gamma_n" for any non-media
349+
treatments.
350+
351+
Args:
352+
is_non_media: Boolean indicating whether the treatment variable is a
353+
non-media treatment. This argument is used to determine whether the
354+
coefficient random effects are normal or log-normal. If `True`, then
355+
random effects are assumed to be normal. Otherwise, the distribution is
356+
inferred from `self._context.media_effects_dist`.
357+
incremental_outcome_x: The incremental outcome of the treatment variable,
358+
which depends on the parameter values of a particular prior or posterior
359+
draw. The "_x" indicates that this is a tensor with length equal to the
360+
dimension of the treatment variable.
361+
linear_predictor_counterfactual_difference: The difference between the
362+
treatment variable and its counterfactual on the linear predictor scale.
363+
"Linear predictor" refers to the quantity that is multiplied by the
364+
geo-level coefficient. For media variables, this is the output of the
365+
hill/adstock transformation function. For non-media treatments, this is
366+
simply the treatment variable after centering/scaling transformations.
367+
This tensor has dimensions for geo, time, and channel.
368+
eta_x: The random effect standard deviation parameter values. For media
369+
variables, the "x" represents "m", "rf", "om", or "orf". For non-media
370+
treatments, this argument should be set to `xi_n`, which is analogous to
371+
"eta".
372+
beta_gx_dev: The latent standard normal parameter values of the geo-level
373+
coefficients. For media variables, the "x" represents "m", "rf", "om",
374+
or "orf". For non-media treatments, this argument should be set to
375+
`gamma_gn_dev`, which is analogous to "beta_gx_dev".
376+
377+
Returns:
378+
The coefficient mean parameter of the treatment variable, which has
379+
dimension equal to the number of treatment channels..
380+
"""
381+
if is_non_media:
382+
random_effects_normal = True
383+
else:
384+
random_effects_normal = (
385+
self._context.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
386+
)
387+
if self._context.revenue_per_kpi is None:
388+
revenue_per_kpi = backend.ones(
389+
[self._context.n_geos, self._context.n_times], dtype=backend.float32
390+
)
391+
else:
392+
revenue_per_kpi = self._context.revenue_per_kpi
393+
incremental_outcome_gx_over_beta_gx = backend.einsum(
394+
"...gtx,gt,g,->...gx",
395+
linear_predictor_counterfactual_difference,
396+
revenue_per_kpi,
397+
self._context.population,
398+
self._context.kpi_transformer.population_scaled_stdev,
399+
)
400+
if random_effects_normal:
401+
numerator_term_x = backend.einsum(
402+
"...gx,...gx,...x->...x",
403+
incremental_outcome_gx_over_beta_gx,
404+
beta_gx_dev,
405+
eta_x,
406+
)
407+
denominator_term_x = backend.einsum(
408+
"...gx->...x", incremental_outcome_gx_over_beta_gx
409+
)
410+
return (incremental_outcome_x - numerator_term_x) / denominator_term_x
411+
# For log-normal random effects, beta_x and eta_x are not mean & std.
412+
# The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
413+
denominator_term_x = backend.einsum(
414+
"...gx,...gx->...x",
415+
incremental_outcome_gx_over_beta_gx,
416+
backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
417+
)
418+
return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)

0 commit comments

Comments
 (0)