Skip to content

Commit a1e3bbf

Browse files
committed
[nnx] refactor UpdateContext
1 parent 8442dc8 commit a1e3bbf

File tree

4 files changed

+82
-261
lines changed

4 files changed

+82
-261
lines changed

flax/nnx/extract.py

+2-88
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import typing as tp
1717

1818
import jax
19-
# from jax._src.tree_util import broadcast_prefix
2019

2120
from flax import struct
2221
from flax.nnx.object import Object
@@ -32,91 +31,6 @@
3231
Leaf = tp.Any
3332

3433

35-
class ExtractionIndex(struct.PyTreeNode):
36-
"""Index of a graph node in a Pytree structure."""
37-
38-
index: Index = struct.field(pytree_node=False)
39-
40-
41-
@tp.overload
42-
def extract_graph_nodes(
43-
pytree: A,
44-
/,
45-
*,
46-
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
47-
) -> tuple[A, tuple[tp.Any, ...]]: ...
48-
@tp.overload
49-
def extract_graph_nodes(
50-
pytree: A,
51-
/,
52-
*,
53-
prefix: tp.Any,
54-
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
55-
) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ...
56-
def extract_graph_nodes(
57-
pytree: A,
58-
/,
59-
*,
60-
prefix: tp.Any = Missing,
61-
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
62-
) -> (
63-
tuple[A, tuple[tp.Any, ...]]
64-
| tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
65-
):
66-
"""Extracts all graph nodes from a pytree."""
67-
nodes: dict[tp.Any, Index] = {}
68-
node_prefixes = []
69-
leaves = []
70-
71-
prefix_leaves = broadcast_prefix(
72-
prefix,
73-
pytree,
74-
prefix_is_leaf=lambda x: x is None,
75-
)
76-
key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree)
77-
78-
assert len(key_leaves) == len(prefix_leaves)
79-
80-
for (keypath, leaf), prefix_leaf in zip(key_leaves, prefix_leaves):
81-
if validate_fn:
82-
validate_fn(keypath, prefix_leaf, leaf)
83-
if graph.is_graph_node(leaf):
84-
if leaf not in nodes:
85-
index = nodes[leaf] = len(nodes)
86-
node_prefixes.append(prefix_leaf)
87-
else:
88-
index = nodes[leaf]
89-
# check consistent aliasing
90-
if prefix_leaf != node_prefixes[index]:
91-
path_str = jax.tree_util.keystr(keypath)
92-
raise ValueError(
93-
f'Inconsistent aliasing detected. Node {type(leaf)} at path {path_str} '
94-
f'has different prefixes: {prefix_leaf} and {node_prefixes[index]}.'
95-
)
96-
leaves.append(ExtractionIndex(index))
97-
else:
98-
leaves.append(leaf)
99-
100-
pytree_out = jax.tree.unflatten(treedef, leaves)
101-
102-
if prefix is Missing:
103-
return pytree_out, tuple(nodes) # type: ignore[bad-return-type]
104-
else:
105-
return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type]
106-
107-
108-
def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
109-
"""Inserts graph nodes into a pytree."""
110-
111-
def _maybe_insert(x):
112-
if isinstance(x, ExtractionIndex):
113-
return nodes[x.index]
114-
return x
115-
116-
return jax.tree.map(
117-
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex)
118-
)
119-
12034
class PrefixMapping(abc.ABC):
12135
@abc.abstractmethod
12236
def map_prefix(
@@ -342,7 +256,7 @@ def from_tree(
342256
) -> tp.Any:
343257
if prefix is Missing or prefix is None:
344258
# fast path, no need for prefix broadcasting or consistent aliasing checks
345-
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
259+
with graph.merge_context(ctxtag, is_inner) as merge_ctx:
346260

347261
def maybe_split(x):
348262
if (
@@ -366,7 +280,7 @@ def maybe_split(x):
366280
assert len(leaf_keys) == len(leaf_prefixes)
367281
leaves_out = []
368282

369-
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
283+
with graph.merge_context(ctxtag, is_inner) as merge_ctx:
370284
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
371285
if (
372286
map_non_graph_nodes

flax/nnx/graph.py

+23-151
Original file line numberDiff line numberDiff line change
@@ -1633,9 +1633,9 @@ def unflatten(
16331633
def merge_context(): ...
16341634
@tp.overload
16351635
@contextlib.contextmanager
1636-
def merge_context(inner: bool | None, ctxtag: tp.Hashable | None): ...
1636+
def merge_context(ctxtag: tp.Hashable | None, inner: bool | None): ...
16371637
@contextlib.contextmanager
1638-
def merge_context(inner: bool | None = None, ctxtag: tp.Hashable | None = None):
1638+
def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None):
16391639
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner))
16401640

16411641
try:
@@ -1691,138 +1691,6 @@ def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
16911691
(obj, index) for index, obj in index_ref.items()
16921692
)
16931693

1694-
@tp.overload
1695-
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]:
1696-
...
1697-
1698-
@tp.overload
1699-
def split(
1700-
self, graph_node: A, first: filterlib.Filter, /
1701-
) -> tuple[GraphDef[A], GraphState]:
1702-
...
1703-
1704-
@tp.overload
1705-
def split(
1706-
self,
1707-
graph_node: A,
1708-
first: filterlib.Filter,
1709-
second: filterlib.Filter,
1710-
/,
1711-
*filters: filterlib.Filter,
1712-
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
1713-
...
1714-
1715-
def split(
1716-
self, node: A, *filters: filterlib.Filter
1717-
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
1718-
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
1719-
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
1720-
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
1721-
to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
1722-
seamlessly between stateful and stateless representations of the graph.
1723-
1724-
Example usage::
1725-
1726-
>>> from flax import nnx
1727-
>>> import jax, jax.numpy as jnp
1728-
...
1729-
>>> class Foo(nnx.Module):
1730-
... def __init__(self, rngs):
1731-
... self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
1732-
... self.linear = nnx.Linear(2, 3, rngs=rngs)
1733-
...
1734-
>>> node = Foo(nnx.Rngs(0))
1735-
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
1736-
...
1737-
>>> jax.tree.map(jnp.shape, params)
1738-
State({
1739-
'batch_norm': {
1740-
'bias': VariableState(
1741-
type=Param,
1742-
value=(2,)
1743-
),
1744-
'scale': VariableState(
1745-
type=Param,
1746-
value=(2,)
1747-
)
1748-
},
1749-
'linear': {
1750-
'bias': VariableState(
1751-
type=Param,
1752-
value=(3,)
1753-
),
1754-
'kernel': VariableState(
1755-
type=Param,
1756-
value=(2, 3)
1757-
)
1758-
}
1759-
})
1760-
>>> jax.tree.map(jnp.shape, batch_stats)
1761-
State({
1762-
'batch_norm': {
1763-
'mean': VariableState(
1764-
type=BatchStat,
1765-
value=(2,)
1766-
),
1767-
'var': VariableState(
1768-
type=BatchStat,
1769-
value=(2,)
1770-
)
1771-
}
1772-
})
1773-
1774-
Arguments:
1775-
node: graph node to split.
1776-
*filters: some optional filters to group the state into mutually exclusive substates.
1777-
Returns:
1778-
:class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
1779-
filters are passed, a single :class:`State` is returned.
1780-
"""
1781-
ref_index: RefMap = RefMap()
1782-
graphdef, flat_state = flatten(
1783-
node, ref_index=ref_index, ref_outer_index=self.inner_ref_outer_index
1784-
)
1785-
flat_states = _split_state(flat_state, filters)
1786-
states = _to_nested_state(graphdef, flat_states)
1787-
assert len(states) >= 1
1788-
self.flatten_end(ref_index)
1789-
return graphdef, *states # type: ignore[return-value]
1790-
1791-
def merge(
1792-
self,
1793-
graphdef: GraphDef[A],
1794-
state: GraphState,
1795-
*states: GraphState,
1796-
) -> A:
1797-
"""merge"""
1798-
if not isinstance(graphdef, NodeDef):
1799-
raise ValueError(
1800-
f'Expected a NodeDef instance, but got {type(graphdef)}.'
1801-
)
1802-
if self.outer_ref_outer_index is None:
1803-
raise ValueError('Cannot merge without ref_index.')
1804-
1805-
if self.outer_ref_outer_index is not None:
1806-
# outer merge (4), create index_ref_cache
1807-
index_ref_cache = self.outer_index_outer_ref
1808-
assert index_ref_cache is not None
1809-
else:
1810-
# inner merge (2)
1811-
index_ref_cache = None
1812-
1813-
_state = _merge_to_flat_state((state, *states))
1814-
index_ref: dict[Index, tp.Any] = {}
1815-
node = unflatten(
1816-
graphdef,
1817-
_state,
1818-
index_ref=index_ref,
1819-
outer_index_outer_ref=index_ref_cache,
1820-
)
1821-
1822-
self.unflatten_end(index_ref, True)
1823-
1824-
return node
1825-
18261694

18271695
jax.tree_util.register_static(UpdateContext)
18281696

@@ -1919,16 +1787,20 @@ def update_context(tag: tp.Hashable):
19191787
>>> from flax import nnx
19201788
...
19211789
>>> m1 = nnx.Dict({})
1922-
>>> with nnx.update_context('example') as ctx:
1923-
... graphdef, state = ctx.split(m1)
1790+
>>> with nnx.update_context('example'):
1791+
... with nnx.split_context('example') as ctx:
1792+
... graphdef, state = ctx.split(m1)
19241793
... @jax.jit
19251794
... def f(graphdef, state):
1926-
... m2 = ctx.merge(graphdef, state)
1795+
... with nnx.merge_context('example', inner=True) as ctx:
1796+
... m2 = ctx.merge(graphdef, state)
19271797
... m2.a = 1
19281798
... m2.ref = m2 # create a reference cycle
1929-
... return ctx.split(m2)
1799+
... with nnx.split_context('example') as ctx:
1800+
... return ctx.split(m2)
19301801
... graphdef_out, state_out = f(graphdef, state)
1931-
... m3 = ctx.merge(graphdef_out, state_out)
1802+
... with nnx.merge_context('example', inner=False) as ctx:
1803+
... m3 = ctx.merge(graphdef_out, state_out)
19321804
...
19331805
>>> assert m1 is m3
19341806
>>> assert m1.a == 1
@@ -1937,36 +1809,36 @@ def update_context(tag: tp.Hashable):
19371809
Note that ``update_context`` takes in a ``tag`` argument which is used
19381810
primarily as a safety mechanism to reduce the risk of accidentally using the
19391811
wrong UpdateContext when using :func:`current_update_context` to access the
1940-
current active context. current_update_context can be used as a way of
1941-
accessing the current active context without having to pass it as a capture::
1812+
current active context. ``update_context`` can also be used as a
1813+
decorator that creates/activates an UpdateContext context for the
1814+
duration of the function::
19421815
19431816
>>> from flax import nnx
19441817
...
19451818
>>> m1 = nnx.Dict({})
19461819
>>> @jax.jit
19471820
... def f(graphdef, state):
1948-
... ctx = nnx.current_update_context('example')
1949-
... m2 = ctx.merge(graphdef, state)
1821+
... with nnx.merge_context('example', inner=True) as ctx:
1822+
... m2 = ctx.merge(graphdef, state)
19501823
... m2.a = 1 # insert static attribute
19511824
... m2.ref = m2 # create a reference cycle
1952-
... return ctx.split(m2)
1825+
... with nnx.split_context('example') as ctx:
1826+
... return ctx.split(m2)
19531827
...
19541828
>>> @nnx.update_context('example')
19551829
... def g(m1):
1956-
... ctx = nnx.current_update_context('example')
1957-
... graphdef, state = ctx.split(m1)
1830+
... with nnx.split_context('example') as ctx:
1831+
... graphdef, state = ctx.split(m1)
19581832
... graphdef_out, state_out = f(graphdef, state)
1959-
... return ctx.merge(graphdef_out, state_out)
1833+
... with nnx.merge_context('example', inner=False) as ctx:
1834+
... return ctx.merge(graphdef_out, state_out)
19601835
...
19611836
>>> m3 = g(m1)
19621837
>>> assert m1 is m3
19631838
>>> assert m1.a == 1
19641839
>>> assert m1.ref is m1
19651840
1966-
As shown in the code above, ``update_context`` can also be used as a
1967-
decorator that creates/activates an UpdateContext context for the
1968-
duration of the function. The context can be accessed using
1969-
:func:`current_update_context`.
1841+
The context can be accessed using :func:`current_update_context`.
19701842
19711843
Args:
19721844
tag: A string tag to identify the context.

tests/nnx/graph_utils_test.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,25 @@ def test_split_merge_context(self):
592592
self.assertFalse(hasattr(ctx, 'index_ref'))
593593
self.assertFalse(hasattr(ctx, 'ctxtag'))
594594

595+
def test_split_merge_context_example(self):
596+
m1 = nnx.Dict({})
597+
with nnx.update_context('example'):
598+
with nnx.split_context('example') as ctx:
599+
graphdef, state = ctx.split(m1)
600+
601+
@jax.jit
602+
def f(graphdef, state):
603+
with nnx.merge_context('example', True) as ctx:
604+
m2 = ctx.merge(graphdef, state)
605+
m2.a = 1
606+
m2.ref = m2 # create a reference cycle
607+
with nnx.split_context('example') as ctx:
608+
return ctx.split(m2)
609+
610+
graphdef_out, state_out = f(graphdef, state)
611+
with nnx.merge_context('example', False) as ctx:
612+
m3 = ctx.merge(graphdef_out, state_out)
613+
595614
def test_split_merge_context_nested(self):
596615
m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
597616
m1 = nnx.Sequential(m2)
@@ -636,7 +655,7 @@ def __init__(self):
636655

637656
@jax.jit
638657
def f(graphdef1, state1, graphdef2, state2):
639-
with nnx.graph.merge_context(True, ctxtag) as ctx:
658+
with nnx.graph.merge_context(ctxtag, True) as ctx:
640659
m1 = ctx.merge(graphdef1, state1)
641660
m2 = ctx.merge(graphdef2, state2)
642661

@@ -657,7 +676,7 @@ def f(graphdef1, state1, graphdef2, state2):
657676
graphdef1, state1, graphdef2, state2
658677
)
659678

660-
with nnx.graph.merge_context(False, ctxtag) as ctx:
679+
with nnx.graph.merge_context(ctxtag, False) as ctx:
661680
m1_out = ctx.merge(graphdef1, state1)
662681
m2_out = ctx.merge(graphdef2, state2)
663682

0 commit comments

Comments
 (0)