Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/source/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@
"model_validate", "model_validate_json", "model_validate_strings",
"model_json_schema", "model_post_init", "model_rebuild",
"predict_stream", "load_context",
"construct", "copy", "dict", "from_orm", "json",
"parse_obj", "parse_raw", "parse_file",
"schema", "schema_json", "update_forward_refs", "validate",
"model_parametrized_name",
] %}
{# Pydantic models have their fields documented inline by autopydantic_model
(see the source-read hook in conf.py), so the Attributes summary table is
skipped to avoid duplicating user fields and listing pydantic internals
(model_config, model_fields, model_extra, ...). Detect pydantic by the
presence of `model_fields` in the attributes list, which is reliable in
pydantic v2 and never present on non-pydantic classes. #}
{% set is_pydantic_model = "model_fields" in attributes %}
{{ name | escape | underline}}

.. currentmodule:: {{ module }}
Expand All @@ -48,7 +59,7 @@
{% endblock %}

{% block attributes %}
{% if attributes %}
{% if attributes and not is_pydantic_model %}
.. rubric:: Attributes

.. autosummary::
Expand Down
66 changes: 62 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"sphinx.ext.intersphinx",
# extensions provided by other packages
"sphinx_autodoc_typehints",
"sphinxcontrib.autodoc_pydantic",
"numpydoc",
"matplotlib.sphinxext.plot_directive", # needed to plot in docstrings
"myst_nb",
Expand Down Expand Up @@ -174,6 +175,17 @@
# don't add a return type section, use standard return with type info
typehints_document_rtype = False

# autodoc-pydantic: render Field(description=...) on pydantic models. Keep
# the page free of pydantic internals (JSON schema, Config, validators);
# those add noise without helping the reader.
autodoc_pydantic_model_show_json = False
autodoc_pydantic_model_show_config_summary = False
autodoc_pydantic_model_show_validator_summary = False
autodoc_pydantic_model_show_validator_members = False
autodoc_pydantic_model_show_field_summary = False
autodoc_pydantic_field_list_validators = False


# intersphinx configuration to ease linking arviz docs
intersphinx_mapping = {
"arviz": ("https://python.arviz.org/en/latest/", None),
Expand Down Expand Up @@ -337,10 +349,56 @@ def scrub_plotly_mathjax(app, pagename, templatename, context, doctree):
)


def setup(app):
"""Configure Sphinx application event handlers.
def use_autopydantic_for_pydantic_models(app, docname, source):
    """Swap ``autoclass`` for ``autopydantic_model`` in pydantic autosummary stubs.

    Sphinx 9.x's autosummary picks the template via hard-coded heuristics
    (`_best_object_type_for_member`) that don't respect autodoc-pydantic's
    ``PydanticModelDocumenter`` priority, so generated stubs use
    ``autoclass``. Rewriting at source-read time lets autodoc-pydantic
    render fields (the goal of #1700) for every pydantic model in the repo.

    Parameters
    ----------
    app : object
        Application object passed by the ``source-read`` event. Unused.
    docname : str
        Name of the document being read. Only documents whose name starts
        with ``api/generated/`` are considered.
    source : list of str
        List whose first element is the document text. ``source[0]`` is
        replaced in place when the documented class is a pydantic model;
        otherwise it is left untouched.

    Returns
    -------
    None
        The function communicates only through the in-place edit of
        ``source``.
    """
    if not docname.startswith("api/generated/"):
        return
    text = source[0]
    # Nothing to do without an autoclass directive, and never rewrite twice.
    if ".. autoclass:: " not in text or "autopydantic_model" in text:
        return

    import importlib
    import re

    cm = re.search(r"\.\. currentmodule:: (\S+)", text)
    ac = re.search(r"\.\. autoclass:: (\S+)", text)
    if not cm or not ac:
        return
    module_name, class_name = cm.group(1), ac.group(1)

    try:
        from pydantic import BaseModel
    except ImportError:
        # Without pydantic there is no model to detect; keep the stub as is.
        return

    try:
        cls = getattr(importlib.import_module(module_name), class_name, None)
    except Exception:
        # Deliberately best-effort: importing an arbitrary project module can
        # fail in many ways, and a failed lookup must not abort the read.
        return

    is_pydantic_model = (
        isinstance(cls, type)
        and cls is not BaseModel
        and issubclass(cls, BaseModel)
    )
    if not is_pydantic_model:
        return

    # :inherited-members: BaseModel pulls fields defined on parent
    # pydantic models (e.g. ``dims`` on ``VariableScaling`` shows
    # up on ``FixedScaling``) without leaking BaseModel internals.
    # NOTE(review): the option line uses the conventional three-space reST
    # indent; confirm it matches the directive body in the class.rst template.
    source[0] = text.replace(
        f".. autoclass:: {class_name}",
        (
            f".. autopydantic_model:: {class_name}\n"
            "   :inherited-members: BaseModel"
        ),
        1,
    )


def setup(app):
    """Register this configuration's event handlers on the Sphinx application."""
    # (event name, callback) pairs, connected in the listed order.
    event_handlers = (
        ("html-page-context", scrub_plotly_mathjax),
        ("source-read", use_autopydantic_for_pydantic_models),
    )
    for event_name, callback in event_handlers:
        app.connect(event_name, callback)
57 changes: 20 additions & 37 deletions pymc_marketing/hsgp_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Class to store and validate keyword argument for the Hilbert Space Gaussian Process (HSGP) components."""

from enum import StrEnum
from typing import Annotated, Any
from typing import Any

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -46,51 +46,34 @@ class HSGPKwargs(BaseModel):
.. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_.
.. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_.
.. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_.

Parameters
----------
m : int
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of the rendering here:
https://pymc-marketing--2572.org.readthedocs.build/en/2572/api/generated/pymc_marketing.hsgp_kwargs.HSGPKwargs.html
is not amazing (missing periods / no line breaks, etc)

Number of basis functions. Default is 200.
L : float, optional
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
(considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP).
By default it is None.
eta_lam : float
Exponential prior for the variance. Default is 1.
ls_mu : float
Mean of the inverse gamma prior for the lengthscale. Default is 5.
ls_sigma : float
Standard deviation of the inverse gamma prior for the lengthscale. Default is 5.
cov_func : CovFunc, optional
Covariance function enum. Supported values: ``ExpQuad``, ``Matern52``, ``Matern32``.
By default it is None (resolved to ``Matern52`` at model-build time).
""" # noqa E501

m: int = Field(200, description="Number of basis functions")
L: (
Annotated[
float,
Field(
gt=0,
description="""
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
(considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP)
""",
),
]
| None
) = None
eta_lam: float = Field(1.0, gt=0, description="Exponential prior for the variance")
m: int = Field(200, description="Number of basis functions.")
L: float | None = Field(
None,
gt=0,
description=(
"Extent of basis functions. Set this to reflect the expected range "
"of in+out-of-sample data (considering that time-indices are "
"zero-centered). Defaults to ``X_mid * 2`` (identical to ``c=2`` "
"in HSGP)."
),
)
eta_lam: float = Field(1.0, gt=0, description="Exponential prior for the variance.")
ls_mu: float = Field(
5.0, gt=0, description="Mean of the inverse gamma prior for the lengthscale"
5.0, gt=0, description="Mean of the inverse gamma prior for the lengthscale."
)
ls_sigma: float = Field(
5.0,
gt=0,
description="Standard deviation of the inverse gamma prior for the lengthscale",
description="Standard deviation of the inverse gamma prior for the lengthscale.",
)
cov_func: CovFunc | None = Field(
None, description="Covariance function enum (ExpQuad, Matern52, Matern32)"
None,
description=(
"Covariance function enum. Supported values: ``ExpQuad``, ``Matern52``, "
"``Matern32``. ``None`` is resolved to ``Matern52`` at model-build time."
),
)

def to_dict(self) -> dict[str, Any]:
Expand Down
9 changes: 5 additions & 4 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,10 @@ class BudgetOptimizer(BaseModel):
default=None,
description=(
"Fixed temporal distribution of each budget cell across periods. "
"Must have dims ('date', *budget_dims) where 'date' has length num_periods. "
"Values must sum to 1 along 'date' for every combination of the remaining dims. "
"If None, budget is distributed uniformly (1/num_periods per period)."
"Must have dims ``('date', *budget_dims)`` where 'date' has length "
"num_periods. Values must sum to 1 along 'date' for every combination "
"of the remaining dims. If None, budget is distributed uniformly "
"(1/num_periods per period)."
),
)

Expand All @@ -683,7 +684,7 @@ class BudgetOptimizer(BaseModel):
description=(
"Cost per unit conversion factors for converting budgets from "
"monetary units (dollars) to original units (impressions, clicks). "
"Must have dims (date, *budget_dims) where date has length "
"Must have dims ``(date, *budget_dims)`` where date has length "
"num_periods. If None, budgets are assumed to already be in "
"the model's native units (no conversion applied)."
),
Expand Down
57 changes: 16 additions & 41 deletions pymc_marketing/mmm/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,14 @@ class VariableScaling(SerializableBaseModel, ABC):
(``"max"`` or ``"mean"``), computed at fit time.
- :class:`FixedScaling` -- use a user-supplied constant that stays the
same across model refreshes.

Parameters
----------
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).
"""

dims: str | tuple[str, ...] = Field(
...,
description="The dimensions to perform operation through.",
description=(
"The dimensions to perform the operation through "
'(``"date"`` is always included implicitly).'
),
)

@abstractmethod
Expand All @@ -137,14 +134,6 @@ def _validate_dims(self) -> Self:
class DataDerivedScaling(VariableScaling):
"""Scale by a statistic of the data, computed at fit time.

Parameters
----------
method : ``"max"`` | ``"mean"``
The scaling method.
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).

Examples
--------
Max-absolute scaling (default behaviour):
Expand All @@ -170,22 +159,6 @@ def scaling_description(self) -> str:
class FixedScaling(VariableScaling):
"""Use a user-supplied constant that stays the same across model refreshes.

Parameters
----------
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).
value : float or dict[str, float] or xarray.DataArray
Fixed scaling constant(s). A single ``float`` applies uniformly.

A ``dict`` maps **coordinate labels along the single remaining
dimension** after reducing over ``date`` and ``dims`` (see the
multidimensional MMM). If more than one non-reduced dimension remains,
use an :class:`xarray.DataArray` whose dimensions broadcast to that
grid (e.g. a vector over ``country`` when the media grid is
``country`` × ``channel``). All values must be positive; NaNs are not
allowed.

Examples
--------
Fixed scalar scaling for production stability:
Expand Down Expand Up @@ -234,7 +207,16 @@ class FixedScaling(VariableScaling):

value: float | dict[str, float] | xr.DataArray = Field(
...,
description="Fixed scaling constant(s). All values must be positive.",
description=(
"Fixed scaling constant(s). A single ``float`` applies uniformly. "
"A ``dict`` maps **coordinate labels along the single remaining "
"dimension** after reducing over ``date`` and ``dims`` (see the "
"multidimensional MMM). If more than one non-reduced dimension "
"remains, use an :class:`xarray.DataArray` whose dimensions "
"broadcast to that grid (e.g. a vector over ``country`` when the "
"media grid is ``country`` × ``channel``). All values must be "
"positive; NaNs are not allowed."
),
)

@property
Expand Down Expand Up @@ -404,13 +386,6 @@ def deserialize_variable_scaling(d: dict[str, Any]) -> VariableScaling:
class Scaling(SerializableBaseModel):
"""Scaling configuration for the MMM.

Parameters
----------
target : VariableScaling
Scaling configuration for the target (response) variable.
channel : VariableScaling
Scaling configuration for the channel (media) variables.

Examples
--------
Data-derived scaling:
Expand All @@ -434,11 +409,11 @@ class Scaling(SerializableBaseModel):

target: VariableScaling = Field(
...,
description="The scaling for the target variable.",
description="Scaling configuration for the target (response) variable.",
)
channel: VariableScaling = Field(
...,
description="The scaling for the channel variable.",
description="Scaling configuration for the channel (media) variables.",
)

@model_validator(mode="before")
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ docs = [
"myst-nb>=1.1.2",
"myst-parser",
"numba",
"autodoc-pydantic",
"numpydoc",
"numpyro",
"nutpie>=0.16.7",
Expand Down
Loading