Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]:

@dataclasses.dataclass(frozen=True)
class WithTag:
"""Filter that matches values with a string ``tag`` attribute equal to the given tag.

Used by ``RngKey`` and ``RngCount``. See
`Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.

Attributes:
tag: The string tag to match against a value's ``tag`` attribute.
"""
tag: str

def __call__(self, path: PathParts, x: tp.Any):
Expand All @@ -89,6 +97,15 @@ def __repr__(self):

@dataclasses.dataclass(frozen=True)
class PathContains:
"""Filter that matches values whose associated path contains the given key.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.

Attributes:
key: The key to look for in the value's path.
exact: If ``True``, ``key`` must equal a path element exactly. If ``False``,
matches when ``key`` is contained in the string form of any path element.
"""
key: Key | str
exact: bool = True

Expand Down Expand Up @@ -121,6 +138,14 @@ def __hash__(self):

@dataclasses.dataclass(frozen=True)
class OfType:
"""Filter that matches values that are instances of ``type``, or that have a
``type`` attribute that is an instance of ``type``.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.

Attributes:
type: The type to match values against.
"""
type: type

def __call__(self, path: PathParts, x: tp.Any):
Expand All @@ -131,6 +156,10 @@ def __repr__(self):


class Any:
"""Filter that matches values matching any of the inner filters.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
"""
def __init__(self, *filters: Filter):
self.predicates = tuple(
to_predicate(collection_filter) for collection_filter in filters
Expand All @@ -150,6 +179,10 @@ def __hash__(self):


class All:
"""Filter that matches values matching all of the inner filters.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
"""
def __init__(self, *filters: Filter):
self.predicates = tuple(
to_predicate(collection_filter) for collection_filter in filters
Expand All @@ -169,6 +202,10 @@ def __hash__(self):


class Not:
"""Filter that matches values that do not match the inner filter.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
"""
def __init__(self, collection_filter: Filter, /):
self.predicate = to_predicate(collection_filter)

Expand All @@ -186,6 +223,10 @@ def __hash__(self):


class Everything:
"""Filter that matches all values.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
"""
def __call__(self, path: PathParts, x: tp.Any):
return True

Expand All @@ -200,6 +241,10 @@ def __hash__(self):


class Nothing:
"""Filter that matches no values.

See `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
"""
def __call__(self, path: PathParts, x: tp.Any):
return False

Expand Down
Loading