From bdef692b1ec7f3e96e570474d3c38c2d3a71a7dd Mon Sep 17 00:00:00 2001 From: Santi Hernandez Date: Thu, 28 May 2026 22:06:34 -0600 Subject: [PATCH] Add docstrings to nnx filterlib filter classes --- flax/nnx/filterlib.py | 45 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 824eb4315..54229ede2 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -78,6 +78,14 @@ def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: @dataclasses.dataclass(frozen=True) class WithTag: + """Filter that matches values with a string ``tag`` attribute equal to the given tag. + + Used by ``RngKey`` and ``RngCount``. See + `Using Filters `__. + + Attributes: + tag: The string tag to match against a value's ``tag`` attribute. + """ tag: str def __call__(self, path: PathParts, x: tp.Any): @@ -89,6 +97,15 @@ def __repr__(self): @dataclasses.dataclass(frozen=True) class PathContains: + """Filter that matches values whose associated path contains the given key. + + See `Using Filters `__. + + Attributes: + key: The key to look for in the value's path. + exact: If ``True``, ``key`` must equal a path element exactly. If ``False``, + matches when ``key`` is contained in the string form of any path element. + """ key: Key | str exact: bool = True @@ -121,6 +138,14 @@ def __hash__(self): @dataclasses.dataclass(frozen=True) class OfType: + """Filter that matches values that are instances of ``type``, or that have a + ``type`` attribute that is an instance of ``type``. + + See `Using Filters `__. + + Attributes: + type: The type to match values against. + """ type: type def __call__(self, path: PathParts, x: tp.Any): @@ -131,6 +156,10 @@ def __repr__(self): class Any: + """Filter that matches values matching any of the inner filters. + + See `Using Filters `__. + """ def __init__(self, *filters: Filter): self.predicates = tuple( to_predicate(collection_filter) for collection_filter in filters @@ -150,6 +179,10 @@ def __hash__(self): class All: + """Filter that matches values matching all of the inner filters. + + See `Using Filters `__. + """ def __init__(self, *filters: Filter): self.predicates = tuple( to_predicate(collection_filter) for collection_filter in filters @@ -169,6 +202,10 @@ def __hash__(self): class Not: + """Filter that matches values that do not match the inner filter. + + See `Using Filters `__. + """ def __init__(self, collection_filter: Filter, /): self.predicate = to_predicate(collection_filter) @@ -186,6 +223,10 @@ def __hash__(self): class Everything: + """Filter that matches all values. + + See `Using Filters `__. + """ def __call__(self, path: PathParts, x: tp.Any): return True @@ -200,6 +241,10 @@ def __hash__(self): class Nothing: + """Filter that matches no values. + + See `Using Filters `__. + """ def __call__(self, path: PathParts, x: tp.Any): return False