Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 104 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
# don't add a return type section, use standard return with type info
typehints_document_rtype = False


# intersphinx configuration to ease linking arviz docs
intersphinx_mapping = {
"arviz": ("https://python.arviz.org/en/latest/", None),
Expand Down Expand Up @@ -337,10 +338,111 @@ def scrub_plotly_mathjax(app, pagename, templatename, context, doctree):
)


def _resolve_field_description(field) -> str:
"""Return ``field.description``, looking into ``Annotated`` metadata."""
import typing

from pydantic.fields import FieldInfo

if field.description:
return field.description

def walk(annotation) -> str | None:
if annotation is None:
return None
for arg in typing.get_args(annotation):
if isinstance(arg, FieldInfo) and arg.description:
return arg.description
found = walk(arg)
if found:
return found
return None

return walk(field.annotation) or ""


def _format_field_type(annotation) -> str:
"""Render a pydantic field annotation as a numpydoc-style type string."""
import types
import typing

if annotation is None or annotation is type(None):
return "None"

origin = typing.get_origin(annotation)
args = typing.get_args(annotation)

# Annotated[X, ...]: use the base type only, drop FieldInfo metadata.
if origin is typing.get_origin(typing.Annotated[int, "x"]):
return _format_field_type(args[0])

# Union / X | None: render as "X, optional" when the only extra is None.
if origin in (types.UnionType, typing.Union):
non_none = [a for a in args if a is not type(None)]
if len(non_none) == 1 and len(args) != len(non_none):
return f"{_format_field_type(non_none[0])}, optional"
return " or ".join(_format_field_type(a) for a in args)

if isinstance(annotation, type):
return annotation.__name__

return str(annotation).replace("typing.", "")


_PYDANTIC_PARAMETERS_INJECT_TARGETS = frozenset(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we need this?

{
"pymc_marketing.hsgp_kwargs.HSGPKwargs",
"pymc_marketing.mmm.scaling.VariableScaling",
"pymc_marketing.mmm.scaling.DataDerivedScaling",
"pymc_marketing.mmm.scaling.FixedScaling",
"pymc_marketing.mmm.scaling.Scaling",
}
)


def inject_pydantic_parameters(app, what, name, obj, options, lines):
"""Append a numpydoc ``Parameters`` block derived from ``model_fields``.

Source of truth for the listed classes is ``Field(description=...)``.
This callback turns those descriptions into the same Parameters block
numpydoc renders for regular classes, so they appear in the published
docs without keeping a duplicated block in the docstring.

Scoped to the targets in ``_PYDANTIC_PARAMETERS_INJECT_TARGETS``: other
pydantic models in the repo have hand-written Parameters blocks or
orphan parameter-like lists that this injector would clash with.
"""
from pydantic_core import PydanticUndefined

if what != "class" or name not in _PYDANTIC_PARAMETERS_INJECT_TARGETS:
return
if not getattr(obj, "model_fields", None):
return

# Emit sphinx-domain :param: / :type: directives, which docutils renders
# as a "Parameters" field list. Numpydoc has already run on the original
# docstring by the time this callback fires, so producing numpydoc-style
# "Parameters\n----------" here would be parsed as a section header
# instead of a field list.
block = [""]
for fname, field in obj.model_fields.items():
description = _resolve_field_description(field).strip()
if field.default is not PydanticUndefined and field.default is not None:
description = f"{description} Default is ``{field.default!r}``.".strip()
type_str = _format_field_type(field.annotation)
desc_lines = description.splitlines() or [""]
block.append(f":param {fname}: {desc_lines[0].strip()}")
for cont in desc_lines[1:]:
block.append(f" {cont.strip()}")
block.append(f":type {fname}: {type_str}")
lines.extend(block)


def setup(app):
"""Configure Sphinx application event handlers.

Connects the Plotly MathJax scrubbing function to the html-page-context event.
Connects the Plotly MathJax scrubbing function to the html-page-context
event and the pydantic-parameters injector to autodoc-process-docstring.
"""
# Connect the scrubbing function to the html-page-context event
app.connect("html-page-context", scrub_plotly_mathjax)
app.connect("autodoc-process-docstring", inject_pydantic_parameters)
24 changes: 5 additions & 19 deletions pymc_marketing/hsgp_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,6 @@ class HSGPKwargs(BaseModel):
.. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_.
.. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_.
.. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_.

Parameters
----------
m : int
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of the rendering here:
https://pymc-marketing--2572.org.readthedocs.build/en/2572/api/generated/pymc_marketing.hsgp_kwargs.HSGPKwargs.html
is not amazing (missing periods / no line breaks, etc)

Number of basis functions. Default is 200.
L : float, optional
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
(considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP).
By default it is None.
eta_lam : float
Exponential prior for the variance. Default is 1.
ls_mu : float
Mean of the inverse gamma prior for the lengthscale. Default is 5.
ls_sigma : float
Standard deviation of the inverse gamma prior for the lengthscale. Default is 5.
cov_func : CovFunc, optional
Covariance function enum. Supported values: ``ExpQuad``, ``Matern52``, ``Matern32``.
By default it is None (resolved to ``Matern52`` at model-build time).
""" # noqa E501

m: int = Field(200, description="Number of basis functions")
Expand All @@ -90,7 +72,11 @@ class HSGPKwargs(BaseModel):
description="Standard deviation of the inverse gamma prior for the lengthscale",
)
cov_func: CovFunc | None = Field(
None, description="Covariance function enum (ExpQuad, Matern52, Matern32)"
None,
description=(
"Covariance function enum. Supported values: ``ExpQuad``, ``Matern52``, "
"``Matern32``. ``None`` is resolved to ``Matern52`` at model-build time."
),
)

def to_dict(self) -> dict[str, Any]:
Expand Down
57 changes: 16 additions & 41 deletions pymc_marketing/mmm/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,14 @@ class VariableScaling(SerializableBaseModel, ABC):
(``"max"`` or ``"mean"``), computed at fit time.
- :class:`FixedScaling` -- use a user-supplied constant that stays the
same across model refreshes.

Parameters
----------
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).
"""

dims: str | tuple[str, ...] = Field(
...,
description="The dimensions to perform operation through.",
description=(
"The dimensions to perform the operation through "
'(``"date"`` is always included implicitly).'
),
)

@abstractmethod
Expand All @@ -137,14 +134,6 @@ def _validate_dims(self) -> Self:
class DataDerivedScaling(VariableScaling):
"""Scale by a statistic of the data, computed at fit time.

Parameters
----------
method : ``"max"`` | ``"mean"``
The scaling method.
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).

Examples
--------
Max-absolute scaling (default behaviour):
Expand All @@ -170,22 +159,6 @@ def scaling_description(self) -> str:
class FixedScaling(VariableScaling):
"""Use a user-supplied constant that stays the same across model refreshes.

Parameters
----------
dims : str or tuple of str
The dimensions to perform the operation through (``"date"`` is always
included implicitly).
value : float or dict[str, float] or xarray.DataArray
Fixed scaling constant(s). A single ``float`` applies uniformly.

A ``dict`` maps **coordinate labels along the single remaining
dimension** after reducing over ``date`` and ``dims`` (see the
multidimensional MMM). If more than one non-reduced dimension remains,
use an :class:`xarray.DataArray` whose dimensions broadcast to that
grid (e.g. a vector over ``country`` when the media grid is
``country`` × ``channel``). All values must be positive; NaNs are not
allowed.

Examples
--------
Fixed scalar scaling for production stability:
Expand Down Expand Up @@ -234,7 +207,16 @@ class FixedScaling(VariableScaling):

value: float | dict[str, float] | xr.DataArray = Field(
...,
description="Fixed scaling constant(s). All values must be positive.",
description=(
"Fixed scaling constant(s). A single ``float`` applies uniformly. "
"A ``dict`` maps **coordinate labels along the single remaining "
"dimension** after reducing over ``date`` and ``dims`` (see the "
"multidimensional MMM). If more than one non-reduced dimension "
"remains, use an :class:`xarray.DataArray` whose dimensions "
"broadcast to that grid (e.g. a vector over ``country`` when the "
"media grid is ``country`` × ``channel``). All values must be "
"positive; NaNs are not allowed."
),
)

@property
Expand Down Expand Up @@ -404,13 +386,6 @@ def deserialize_variable_scaling(d: dict[str, Any]) -> VariableScaling:
class Scaling(SerializableBaseModel):
"""Scaling configuration for the MMM.

Parameters
----------
target : VariableScaling
Scaling configuration for the target (response) variable.
channel : VariableScaling
Scaling configuration for the channel (media) variables.

Examples
--------
Data-derived scaling:
Expand All @@ -434,11 +409,11 @@ class Scaling(SerializableBaseModel):

target: VariableScaling = Field(
...,
description="The scaling for the target variable.",
description="Scaling configuration for the target (response) variable.",
)
channel: VariableScaling = Field(
...,
description="The scaling for the channel variable.",
description="Scaling configuration for the channel (media) variables.",
)

@model_validator(mode="before")
Expand Down
Loading