Skip to content

Commit 1926665

Browse files
committed
Add deprecations for old NNX view functions
1 parent fbf5dbe commit 1926665

5 files changed

Lines changed: 82 additions & 4 deletions

File tree

flax/nnx/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .module import M as M
5151
from .module import Module as Module
5252
from .module import capture as capture
53-
from .module import with_modules as with_modules
53+
from .module import with_modules as with_modules, view as view
5454
from .module import view_info as view_info
5555
from .module import with_attributes as with_attributes
5656
from .module import iter_children as iter_children, iter_modules as iter_modules
@@ -75,8 +75,8 @@
7575
from .graphlib import MergeContext as MergeContext
7676
from .graphlib import merge_context as merge_context
7777
from .graphlib import variables as variables
78-
from .graphlib import with_vars as with_vars
79-
from .graphlib import as_pure as as_pure
78+
from .graphlib import with_vars as with_vars, vars_as as vars_as
79+
from .graphlib import as_pure as as_pure, pure as pure
8080
from .graphlib import cached_partial as cached_partial
8181
from .graphlib import flatten as flatten
8282
from .graphlib import unflatten as unflatten
@@ -152,7 +152,7 @@
152152
from .spmd import get_named_sharding as get_named_sharding
153153
from .spmd import with_partitioning as with_partitioning
154154
from .spmd import get_abstract_model as get_abstract_model
155-
from .spmd import as_abstract as as_abstract
155+
from .spmd import as_abstract as as_abstract, abstract_with_sharding as abstract_with_sharding
156156
from .statelib import FlatState as FlatState
157157
from .statelib import State as State
158158
from .statelib import to_flat_state as to_flat_state

flax/nnx/deprecations.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import functools
16+
import warnings
17+
from typing import TypeVar
18+
from collections.abc import Callable
19+
20+
F = TypeVar('F', bound=Callable)
21+
22+
23+
def deprecated(new_fn: F) -> F:
24+
"""Creates a deprecated alias for a renamed function.
25+
26+
.. deprecated::
27+
This decorator is for marking functions as deprecated. The returned
28+
wrapper emits a :class:`DeprecationWarning` on every call and then
29+
delegates to ``new_fn``.
30+
31+
The returned callable copies the signature, type annotations, and
32+
docstring of ``new_fn``, with a deprecation notice prepended to the
33+
docstring. This keeps IDE autocomplete and type-checking working while
34+
clearly communicating that callers should migrate.
35+
36+
Args:
37+
new_fn: The current, non-deprecated function to delegate to.
38+
39+
Returns:
40+
A wrapper that emits a :class:`DeprecationWarning` and then calls
41+
``new_fn`` with the same arguments.
42+
43+
Example::
44+
45+
>>> from flax.nnx.deprecations import deprecated
46+
>>> def new_api(x):
47+
... return x * 2
48+
>>> old_api = deprecated(new_api)
49+
>>> old_api(3) # emits DeprecationWarning: use new_api instead
50+
6
51+
52+
"""
53+
54+
@functools.wraps(new_fn)
55+
def wrapper(*args, **kwargs):
56+
warnings.warn(
57+
f'This function is deprecated. Use {new_fn.__qualname__} instead.',
58+
DeprecationWarning,
59+
stacklevel=2,
60+
)
61+
return new_fn(*args, **kwargs)
62+
63+
dep_notice = (
64+
f'.. deprecated::\n'
65+
f' Use :func:`{new_fn.__qualname__}` instead.\n\n'
66+
)
67+
wrapper.__doc__ = dep_notice + (new_fn.__doc__ or '')
68+
69+
return wrapper # type: ignore[return-value]

flax/nnx/graphlib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from flax import config
2727
from flax.nnx import filterlib, reprlib, traversals, variablelib
2828
from flax.nnx import statelib
29+
from flax.nnx.deprecations import deprecated
2930
from flax.nnx.proxy_caller import (
3031
ApplyCaller,
3132
CallableProxy,
@@ -2761,6 +2762,7 @@ def _to_refs(path, x):
27612762
node = graphlib.map(_to_refs, node, graph=graph)
27622763
return node
27632764

2765+
vars_as = deprecated(with_vars)
27642766

27652767
def as_pure(tree: A) -> A:
27662768
"""Returns a new tree with all ``Variable`` objects replaced with inner values.
@@ -2808,6 +2810,7 @@ def _pure_fn(_, x):
28082810

28092811
return map(_pure_fn, tree, auto_create_variables=False)
28102812

2813+
pure = deprecated(with_modules)
28112814

28122815
def call(
28132816
graphdef_state: tuple[GraphDef[A], GraphState], /

flax/nnx/module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from flax.nnx.pytreelib import Pytree, PytreeMeta
3030
from flax.nnx.graphlib import GraphState
3131
from flax.nnx.statelib import split_state, State
32+
from flax.nnx.deprecations import deprecated
3233
import functools as ft
3334
from flax.typing import Key, Path, PathParts
3435
from collections.abc import MutableMapping
@@ -515,6 +516,8 @@ def _set_mode_fn(path, node):
515516

516517
return out
517518

519+
view = deprecated(with_modules)
520+
518521
def with_attributes(
519522
node: A,
520523
/,

flax/nnx/spmd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import flax.core.spmd as core_spmd
1818
from flax.nnx import variablelib, graphlib
19+
from flax.nnx.deprecations import deprecated
1920
from flax.nnx.transforms.transforms import eval_shape
2021
from flax.typing import (
2122
Sharding,
@@ -238,3 +239,5 @@ def add_sharding(_path, x):
238239
return abs_var
239240
return x
240241
return graphlib.map(add_sharding, tree, graph=graph)
242+
243+
abstract_with_sharding = deprecated(as_abstract)

0 commit comments

Comments
 (0)