@@ -42,6 +42,7 @@ def __init__(self, model_context: context.ModelContext):
4242
4343 def adstock_hill_media (
4444 self ,
45+ * ,
4546 media : backend .Tensor ,
4647 alpha : backend .Tensor ,
4748 ec : backend .Tensor ,
@@ -102,6 +103,7 @@ def adstock_hill_media(
102103
103104 def adstock_hill_rf (
104105 self ,
106+ * ,
105107 reach : backend .Tensor ,
106108 frequency : backend .Tensor ,
107109 alpha : backend .Tensor ,
@@ -245,28 +247,172 @@ def compute_non_media_treatments_baseline(
245247
246248 def linear_predictor_counterfactual_difference_media (
247249 self ,
250+ * ,
248251 media_transformed : backend .Tensor ,
249252 alpha_m : backend .Tensor ,
250253 ec_m : backend .Tensor ,
251254 slope_m : backend .Tensor ,
252255 ) -> backend .Tensor :
253- raise NotImplementedError
256+ """Calculates linear predictor counterfactual difference for non-RF media.
257+
258+ For non-RF media variables (paid or organic), this function calculates the
259+ linear predictor difference between the treatment variable and its
260+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
261+ function, which is multiplied by the geo-level coefficient.
262+
263+ This function does the calculation efficiently by only calculating calling
264+ the hill/adstock function if the prior counterfactual is not all zeros.
265+
266+ Args:
267+ media_transformed: The output of the hill/adstock function for actual
268+ historical media data.
269+ alpha_m: The adstock alpha parameter values.
270+ ec_m: The adstock ec parameter values.
271+ slope_m: The adstock hill slope parameter values.
272+
273+ Returns:
274+ The linear predictor difference between the treatment variable and its
275+ counterfactual.
276+ """
277+ if self ._context .media_tensors .prior_media_scaled_counterfactual is None :
278+ return media_transformed
279+ media_transformed_counterfactual = self .adstock_hill_media (
280+ media = self ._context .media_tensors .prior_media_scaled_counterfactual ,
281+ alpha = alpha_m ,
282+ ec = ec_m ,
283+ slope = slope_m ,
284+ decay_functions = self ._context .adstock_decay_spec .media ,
285+ )
286+ # Absolute values is needed because the difference is negative for mROI
287+ # priors and positive for ROI and contribution priors.
288+ return backend .absolute (
289+ media_transformed - media_transformed_counterfactual
290+ )
254291
255292 def linear_predictor_counterfactual_difference_rf (
256293 self ,
294+ * ,
257295 rf_transformed : backend .Tensor ,
258296 alpha_rf : backend .Tensor ,
259297 ec_rf : backend .Tensor ,
260298 slope_rf : backend .Tensor ,
261299 ) -> backend .Tensor :
262- raise NotImplementedError
300+ """Calculates linear predictor counterfactual difference for RF media.
301+
302+ For RF media variables (paid or organic), this function calculates the
303+ linear predictor difference between the treatment variable and its
304+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
305+ function, which is multiplied by the geo-level coefficient.
306+
307+ This function does the calculation efficiently by only calculating calling
308+ the hill/adstock function if the prior counterfactual is not all zeros.
309+
310+ Args:
311+ rf_transformed: The output of the hill/adstock function for actual
312+ historical media data.
313+ alpha_rf: The adstock alpha parameter values.
314+ ec_rf: The adstock ec parameter values.
315+ slope_rf: The adstock hill slope parameter values.
316+
317+ Returns:
318+ The linear predictor difference between the treatment variable and its
319+ counterfactual.
320+ """
321+ if self ._context .rf_tensors .prior_reach_scaled_counterfactual is None :
322+ return rf_transformed
323+ rf_transformed_counterfactual = self .adstock_hill_rf (
324+ reach = self ._context .rf_tensors .prior_reach_scaled_counterfactual ,
325+ frequency = self ._context .rf_tensors .frequency ,
326+ alpha = alpha_rf ,
327+ ec = ec_rf ,
328+ slope = slope_rf ,
329+ decay_functions = self ._context .adstock_decay_spec .rf ,
330+ )
331+ # Absolute values is needed because the difference is negative for mROI
332+ # priors and positive for ROI and contribution priors.
333+ return backend .absolute (rf_transformed - rf_transformed_counterfactual )
263334
264335 def calculate_beta_x (
265336 self ,
337+ * ,
266338 is_non_media : bool ,
267339 incremental_outcome_x : backend .Tensor ,
268340 linear_predictor_counterfactual_difference : backend .Tensor ,
269341 eta_x : backend .Tensor ,
270342 beta_gx_dev : backend .Tensor ,
271343 ) -> backend .Tensor :
272- raise NotImplementedError
344+ """Calculates coefficient mean parameter for any treatment variable type.
345+
346+ The "beta_x" in the function name refers to the coefficient mean parameter
347+ of any treatment variable. The "x" can represent "m", "rf", "om", or "orf".
348+ This function can also be used to calculate "gamma_n" for any non-media
349+ treatments.
350+
351+ Args:
352+ is_non_media: Boolean indicating whether the treatment variable is a
353+ non-media treatment. This argument is used to determine whether the
354+ coefficient random effects are normal or log-normal. If `True`, then
355+ random effects are assumed to be normal. Otherwise, the distribution is
356+ inferred from `self._context.media_effects_dist`.
357+ incremental_outcome_x: The incremental outcome of the treatment variable,
358+ which depends on the parameter values of a particular prior or posterior
359+ draw. The "_x" indicates that this is a tensor with length equal to the
360+ dimension of the treatment variable.
361+ linear_predictor_counterfactual_difference: The difference between the
362+ treatment variable and its counterfactual on the linear predictor scale.
363+ "Linear predictor" refers to the quantity that is multiplied by the
364+ geo-level coefficient. For media variables, this is the output of the
365+ hill/adstock transformation function. For non-media treatments, this is
366+ simply the treatment variable after centering/scaling transformations.
367+ This tensor has dimensions for geo, time, and channel.
368+ eta_x: The random effect standard deviation parameter values. For media
369+ variables, the "x" represents "m", "rf", "om", or "orf". For non-media
370+ treatments, this argument should be set to `xi_n`, which is analogous to
371+ "eta".
372+ beta_gx_dev: The latent standard normal parameter values of the geo-level
373+ coefficients. For media variables, the "x" represents "m", "rf", "om",
374+ or "orf". For non-media treatments, this argument should be set to
375+ `gamma_gn_dev`, which is analogous to "beta_gx_dev".
376+
377+ Returns:
378+ The coefficient mean parameter of the treatment variable, which has
379+ dimension equal to the number of treatment channels..
380+ """
381+ if is_non_media :
382+ random_effects_normal = True
383+ else :
384+ random_effects_normal = (
385+ self ._context .media_effects_dist == constants .MEDIA_EFFECTS_NORMAL
386+ )
387+ if self ._context .revenue_per_kpi is None :
388+ revenue_per_kpi = backend .ones (
389+ [self ._context .n_geos , self ._context .n_times ], dtype = backend .float32
390+ )
391+ else :
392+ revenue_per_kpi = self ._context .revenue_per_kpi
393+ incremental_outcome_gx_over_beta_gx = backend .einsum (
394+ "...gtx,gt,g,->...gx" ,
395+ linear_predictor_counterfactual_difference ,
396+ revenue_per_kpi ,
397+ self ._context .population ,
398+ self ._context .kpi_transformer .population_scaled_stdev ,
399+ )
400+ if random_effects_normal :
401+ numerator_term_x = backend .einsum (
402+ "...gx,...gx,...x->...x" ,
403+ incremental_outcome_gx_over_beta_gx ,
404+ beta_gx_dev ,
405+ eta_x ,
406+ )
407+ denominator_term_x = backend .einsum (
408+ "...gx->...x" , incremental_outcome_gx_over_beta_gx
409+ )
410+ return (incremental_outcome_x - numerator_term_x ) / denominator_term_x
411+ # For log-normal random effects, beta_x and eta_x are not mean & std.
412+ # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
413+ denominator_term_x = backend .einsum (
414+ "...gx,...gx->...x" ,
415+ incremental_outcome_gx_over_beta_gx ,
416+ backend .exp (beta_gx_dev * eta_x [..., backend .newaxis , :]),
417+ )
418+ return backend .log (incremental_outcome_x ) - backend .log (denominator_term_x )
0 commit comments