Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ def __init__(
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration.
Class-default defined by the user default_model_config method.
A ``UserWarning`` is raised for any key not present in
``default_model_config``, since such keys are ignored by the model.
sampler_config : Dictionary, optional
dictionary of parameters that initialise sampler configuration.
Class-default defined by the user default_sampler_config method.
Expand All @@ -615,10 +617,23 @@ class MyModel(ModelBuilder): ...
self.sampler_config = (
self.default_sampler_config | sampler_config
) # Parameters for fit sampling
default_model_config = self.default_model_config
self.model_config = (
self.default_model_config | model_config
default_model_config | model_config
) # parameters for priors etc.

# Warn about model_config keys that the model does not use, so that
# typos (e.g. "alphaa" instead of "alpha") don't silently get ignored.
unused_model_config_keys = set(model_config) - set(default_model_config)
if unused_model_config_keys:
warnings.warn(
"The following model_config keys are not used by the model "
f"and will be ignored: {sorted(unused_model_config_keys)}. "
f"Valid keys are: {sorted(default_model_config)}.",
UserWarning,
stacklevel=2,
)

self.model: pm.Model
self.idata: az.InferenceData | None = None # idata is generated during fitting
self.is_fitted_ = False
Expand Down
12 changes: 12 additions & 0 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,18 @@ def test_model_configuration(model_class, expected_type, test_config):
assert nondefault.sampler_config == default.sampler_config | {"draws": 42}


def test_model_config_warns_on_unused_keys():
"""Unknown model_config keys should warn so typos are not silently ignored."""
with pytest.warns(UserWarning, match="not used by the model"):
ModelBuilderTest(model_config={"mu_loc": 5, "typo_key": 1})


def test_model_config_no_warning_for_valid_keys(recwarn):
"""No unused-key warning is raised when every key is a valid default key."""
ModelBuilderTest(model_config={"mu_loc": 5})
assert not [w for w in recwarn if "not used by the model" in str(w.message)]


@pytest.mark.parametrize(
"test_case,model_class,method,expected_error,args",
[
Expand Down
Loading