Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
NoAdstock,
WeibullCDFAdstock,
WeibullPDFAdstock,
adstock_from_dict,
)
from pymc_marketing.mmm.components.saturation import (
HillSaturation,
Expand All @@ -35,7 +34,6 @@
SaturationTransformation,
TanhSaturation,
TanhSaturationBaselined,
saturation_from_dict,
)
from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier
from pymc_marketing.mmm.hsgp import (
Expand Down Expand Up @@ -122,7 +120,6 @@
"WeibullCDFAdstock",
"WeibullPDFAdstock",
"YearlyFourier",
"adstock_from_dict",
"approx_hsgp_hyperparams",
"create_complexity_penalizing_prior",
"create_constrained_inverse_gamma_prior",
Expand All @@ -131,7 +128,6 @@
"preprocessing",
"preprocessing_method_X",
"preprocessing_method_y",
"saturation_from_dict",
"validating",
"validation_method_X",
"validation_method_y",
Expand Down
48 changes: 0 additions & 48 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def function(self, x, alpha):

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
Expand Down Expand Up @@ -144,9 +143,6 @@ def from_dict(cls, data: dict) -> AdstockTransformation:
"""Reconstruct an adstock transformation from a dict."""
data = data.copy()
data.pop("__type__", None)
data.pop(
"lookup_name", None
) # TODO(1.0): Remove once Legacy MMM is removed (#2430)

if "priors" in data:
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}
Expand Down Expand Up @@ -455,47 +451,3 @@ def function(self, x, *, dim: str | None = None):
def update_priors(self, priors):
"""Update priors for the no adstock transformation."""
return


# TODO(1.0): Remove this dict once Legacy MMM is removed (see #2430)
ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {
"geometric": GeometricAdstock,
"delayed": DelayedAdstock,
"weibull_cdf": WeibullCDFAdstock,
"weibull_pdf": WeibullPDFAdstock,
"binomial": BinomialAdstock,
"no_adstock": NoAdstock,
}


def adstock_from_dict(data: dict) -> AdstockTransformation:
"""Create an adstock transformation from a dictionary.

.. deprecated:: 0.18.2
`adstock_from_dict` is deprecated and will be removed in 0.20.0.
Use ``from pymc_marketing.serialization import serialization; serialization.deserialize(data)`` instead.
"""
warnings.warn(
"adstock_from_dict is deprecated and will be removed in 0.20.0. "
"Use `from pymc_marketing.serialization import serialization; "
"serialization.deserialize(data)` instead.",
FutureWarning,
stacklevel=2,
)
data = data.copy()
type_key = data.pop("__type__", None)
lookup_name = data.pop("lookup_name", None)

if lookup_name:
cls = ADSTOCK_TRANSFORMATIONS[lookup_name]
elif type_key:
return serialization.deserialize({**data, "__type__": type_key})
else:
raise ValueError(
"Cannot deserialize adstock: missing both 'lookup_name' and '__type__'"
)

if "priors" in data:
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}

return cls(**data)
54 changes: 0 additions & 54 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def function(self, x, b):

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
Expand Down Expand Up @@ -173,13 +172,8 @@ def from_dict(cls, data: dict) -> SaturationTransformation:
"""Reconstruct a saturation transformation from a dict."""
data = data.copy()
data.pop("__type__", None)
data.pop(
"lookup_name", None
) # TODO(1.0): Remove once Legacy MMM is removed (#2430)

if "priors" in data:
from pymc_extras.deserialize import deserialize

data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}

return cls(**data)
Expand Down Expand Up @@ -632,51 +626,3 @@ def function(self, x, beta, *, dim: str | None = None):
return beta * x

default_priors = {"beta": Prior("HalfNormal", sigma=1)}


# TODO(1.0): Remove this dict once Legacy MMM is removed (see #2430)
SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
"logistic": LogisticSaturation,
"inverse_scaled_logistic": InverseScaledLogisticSaturation,
"tanh": TanhSaturation,
"tanh_baselined": TanhSaturationBaselined,
"michaelis_menten": MichaelisMentenSaturation,
"hill": HillSaturation,
"hill_sigmoid": HillSaturationSigmoid,
"root": RootSaturation,
"no_saturation": NoSaturation,
}


def saturation_from_dict(data: dict) -> SaturationTransformation:
"""Get a saturation function from a dictionary.

.. deprecated:: 0.18.2
`saturation_from_dict` is deprecated and will be removed in 0.20.0.
Use ``from pymc_marketing.serialization import serialization; serialization.deserialize(data)`` instead.
"""
warnings.warn(
"saturation_from_dict is deprecated and will be removed in 0.20.0. "
"Use `from pymc_marketing.serialization import serialization; "
"serialization.deserialize(data)` instead.",
FutureWarning,
stacklevel=2,
)
data = data.copy()
type_key = data.pop("__type__", None)
lookup_name = data.pop("lookup_name", None)

if lookup_name:
cls = SATURATION_TRANSFORMATIONS[lookup_name]
elif type_key:
return serialization.deserialize({**data, "__type__": type_key})
else:
raise ValueError(
"Cannot deserialize saturation: missing both 'lookup_name' and '__type__'"
)

if "priors" in data:
data["priors"] = {
key: deserialize(value) for key, value in data["priors"].items()
}
return cls(**data)
27 changes: 4 additions & 23 deletions pymc_marketing/mmm/media_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,8 @@
from pymc.distributions.shape_utils import Dims
from pytensor.xtensor.type import XTensorVariable

from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
adstock_from_dict,
)
from pymc_marketing.mmm.components.saturation import (
SaturationTransformation,
saturation_from_dict,
)
from pymc_marketing.mmm.components.adstock import AdstockTransformation
from pymc_marketing.mmm.components.saturation import SaturationTransformation
from pymc_marketing.serialization import serialization


Expand Down Expand Up @@ -246,22 +240,9 @@ def from_dict(cls, data) -> MediaTransformation:
The media transformation created from the dictionary.

"""
adstock_data = data["adstock"]
saturation_data = data["saturation"]

if "__type__" in adstock_data:
adstock = serialization.deserialize(adstock_data)
else:
adstock = adstock_from_dict(adstock_data)

if "__type__" in saturation_data:
saturation = serialization.deserialize(saturation_data)
else:
saturation = saturation_from_dict(saturation_data)

return cls(
adstock=adstock,
saturation=saturation,
adstock=serialization.deserialize(data["adstock"]),
saturation=serialization.deserialize(data["saturation"]),
adstock_first=data["adstock_first"],
dims=data.get("dims"),
)
Expand Down
114 changes: 4 additions & 110 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings

import numpy as np
import pymc as pm
import pytest
import xarray as xr
from pydantic import ValidationError
from pymc_extras.deserialize import (
DESERIALIZERS,
register_deserialization,
)
from pymc_extras.prior import Prior
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.type import XTensorVariable

import pymc_marketing.mmm.components.adstock as adstock_module
from pymc_marketing.mmm.components.adstock import (
ADSTOCK_TRANSFORMATIONS,
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
NoAdstock,
adstock_from_dict,
)
from pymc_marketing.mmm.transformers import ConvMode
from pymc_marketing.serialization import serialization
Expand All @@ -47,12 +40,10 @@


def adstocks() -> list:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return [
pytest.param(adstock(l_max=10), id=name)
for name, adstock in ADSTOCK_TRANSFORMATIONS.items()
]
return [
pytest.param(adstock_cls(l_max=10), id=adstock_cls.__name__)
for adstock_cls in ALL_ADSTOCK_CLASSES
]


@pytest.fixture
Expand Down Expand Up @@ -116,57 +107,6 @@ def test_adstock_sample_curve(adstock: AdstockTransformation) -> None:
assert curve.shape == (1, 500, adstock.l_max)


def test_adstock_from_dict() -> None:
data = {
"lookup_name": "geometric",
"l_max": 10,
"prefix": "test",
"mode": "Before",
"priors": {
"alpha": {
"dist": "Beta",
"kwargs": {
"alpha": 1,
"beta": 2,
},
},
},
}

with pytest.warns(FutureWarning, match="adstock_from_dict is deprecated"):
adstock = adstock_from_dict(data)
assert adstock == GeometricAdstock(
l_max=10,
prefix="test",
priors={
"alpha": Prior("Beta", alpha=1, beta=2),
},
mode=ConvMode.Before,
)


@pytest.mark.parametrize(
"lookup_name, adstock_cls",
list(ADSTOCK_TRANSFORMATIONS.items()),
)
def test_adstock_from_dict_without_priors(
lookup_name: str,
adstock_cls: type[AdstockTransformation],
) -> None:
data = {
"lookup_name": lookup_name,
"l_max": 10,
"prefix": "test",
"mode": "Before",
}

with pytest.warns(FutureWarning, match="adstock_from_dict is deprecated"):
adstock = adstock_from_dict(data)
assert adstock.default_priors == {
k: Prior.from_dict(v) for k, v in adstock.to_dict()["priors"].items()
}


def test_repr() -> None:
assert repr(GeometricAdstock(l_max=10)) == (
"GeometricAdstock(prefix='adstock', l_max=10, "
Expand All @@ -177,52 +117,6 @@ def test_repr() -> None:
)


class ArbitraryObject:
def __init__(self, msg: str, value: int) -> None:
self.msg = msg
self.value = value
self.dims = ()

def create_variable(self, name: str):
return pm.Normal(name, mu=0, sigma=1)


@pytest.fixture
def register_arbitrary_deserialization():
register_deserialization(
lambda data: isinstance(data, dict) and data.keys() == {"msg", "value"},
lambda data: ArbitraryObject(**data),
)

yield

DESERIALIZERS.pop()


def test_deserialization(
register_arbitrary_deserialization,
) -> None:
data = {
"lookup_name": "geometric",
"prefix": "new",
"l_max": 10,
"priors": {
"alpha": {"msg": "hello", "value": 1},
},
}

with pytest.warns(FutureWarning, match="adstock_from_dict is deprecated"):
instance = adstock_from_dict(data)
assert isinstance(instance, GeometricAdstock)
assert instance.prefix == "new"
assert instance.l_max == 10

alpha = instance.function_priors["alpha"]
assert isinstance(alpha, ArbitraryObject)
assert alpha.msg == "hello"
assert alpha.value == 1


class TestAdstockRoundtrips:
"""Every AdstockTransformation subclass round-trips with all params."""

Expand Down
Loading