diff --git a/lightweight_mmm/lightweight_mmm.py b/lightweight_mmm/lightweight_mmm.py index b22d2cc..54e88d8 100644 --- a/lightweight_mmm/lightweight_mmm.py +++ b/lightweight_mmm/lightweight_mmm.py @@ -64,7 +64,8 @@ _NAMES_TO_MODEL_TRANSFORMS = immutabledict.immutabledict({ "hill_adstock": models.transform_hill_adstock, "adstock": models.transform_adstock, - "carryover": models.transform_carryover + "carryover": models.transform_carryover, + "hill_carryover": models.transform_hill_carryover }) _MODEL_FUNCTION = models.media_mix_model diff --git a/lightweight_mmm/models.py b/lightweight_mmm/models.py index 70158b3..1b22d10 100644 --- a/lightweight_mmm/models.py +++ b/lightweight_mmm/models.py @@ -87,7 +87,9 @@ def __call__( "adstock": frozenset((_EXPONENT, _LAG_WEIGHT)), "hill_adstock": - frozenset((_LAG_WEIGHT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE)) + frozenset((_LAG_WEIGHT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE)), + "hill_carryover": + frozenset((_AD_EFFECT_RETENTION_RATE, _PEAK_EFFECT_DELAY, _EXPONENT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE)), }) GEO_ONLY_PRIORS = frozenset((_COEF_SEASONALITY,)) @@ -134,8 +136,21 @@ def _get_transform_default_priors() -> Mapping[str, Prior]: dist.Gamma(concentration=1., rate=1.), _SLOPE: dist.Gamma(concentration=1., rate=1.) + }), + "hill_carryover": + immutabledict.immutabledict({ + _AD_EFFECT_RETENTION_RATE: + dist.Beta(concentration1=1., concentration0=1.), + _PEAK_EFFECT_DELAY: + dist.HalfNormal(scale=2.), + _EXPONENT: + dist.Beta(concentration1=9., concentration0=1.), + _HALF_MAX_EFFECTIVE_CONCENTRATION: + dist.Gamma(concentration=1., rate=1.), + _SLOPE: + dist.Gamma(concentration=1., rate=1.) + }), }) - }) def transform_adstock(media_data: jnp.ndarray, @@ -280,6 +295,75 @@ def transform_carryover(media_data: jnp.ndarray, return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent) +def transform_hill_carryover(media_data: jnp.ndarray, + custom_priors: MutableMapping[str, Prior], + number_lags: int = 13) -> jnp.ndarray: + + """Transforms the input data with the carryover and hill function. + + Args: + media_data: Media data to be transformed. It is expected to have 2 dims for + national models and 3 for geo models. + custom_priors: The custom priors we want the model to take instead of the + default ones. The possible names of parameters for carryover and exponent + are "ad_effect_retention_rate_plate", "peak_effect_delay_plate" and + "exponent". + number_lags: Number of lags for the carryover function. + + Returns: + The transformed media data. + """ + transform_default_priors = _get_transform_default_priors()["hill_carryover"] + with numpyro.plate(name=f"{_HALF_MAX_EFFECTIVE_CONCENTRATION}_plate", + size=media_data.shape[1]): + half_max_effective_concentration = numpyro.sample( + name=_HALF_MAX_EFFECTIVE_CONCENTRATION, + fn=custom_priors.get( + _HALF_MAX_EFFECTIVE_CONCENTRATION, + transform_default_priors[_HALF_MAX_EFFECTIVE_CONCENTRATION])) + + with numpyro.plate(name=f"{_SLOPE}_plate", + size=media_data.shape[1]): + slope = numpyro.sample( + name=_SLOPE, + fn=custom_priors.get(_SLOPE, transform_default_priors[_SLOPE])) + + with numpyro.plate(name=f"{_AD_EFFECT_RETENTION_RATE}_plate", + size=media_data.shape[1]): + ad_effect_retention_rate = numpyro.sample( + name=_AD_EFFECT_RETENTION_RATE, + fn=custom_priors.get( + _AD_EFFECT_RETENTION_RATE, + transform_default_priors[_AD_EFFECT_RETENTION_RATE])) + + with numpyro.plate(name=f"{_PEAK_EFFECT_DELAY}_plate", + size=media_data.shape[1]): + peak_effect_delay = numpyro.sample( + name=_PEAK_EFFECT_DELAY, + fn=custom_priors.get( + _PEAK_EFFECT_DELAY, transform_default_priors[_PEAK_EFFECT_DELAY])) + + with numpyro.plate(name=f"{_EXPONENT}_plate", + size=media_data.shape[1]): + exponent = numpyro.sample( + name=_EXPONENT, + fn=custom_priors.get(_EXPONENT, + transform_default_priors[_EXPONENT])) + + half_max_effective_concentration = jnp.array(half_max_effective_concentration) + slope = jnp.array(slope) + carryover = media_transforms.hill(media_transforms.carryover( + data=media_data, + ad_effect_retention_rate=ad_effect_retention_rate, + peak_effect_delay=peak_effect_delay, + number_lags=number_lags),half_max_effective_concentration=half_max_effective_concentration, + slope=slope) + + if media_data.ndim == 3: + exponent = jnp.expand_dims(exponent, axis=-1) + return carryover + + def media_mix_model( media_data: jnp.ndarray, target_data: jnp.ndarray,