@@ -1633,9 +1633,9 @@ def unflatten(
1633
1633
def merge_context (): ...
1634
1634
@tp .overload
1635
1635
@contextlib .contextmanager
1636
- def merge_context (inner : bool | None , ctxtag : tp . Hashable | None ): ...
1636
+ def merge_context (ctxtag : tp . Hashable | None , inner : bool | None ): ...
1637
1637
@contextlib .contextmanager
1638
- def merge_context (inner : bool | None = None , ctxtag : tp . Hashable | None = None ):
1638
+ def merge_context (ctxtag : tp . Hashable | None = None , inner : bool | None = None ):
1639
1639
GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , {}, inner ))
1640
1640
1641
1641
try :
@@ -1691,138 +1691,6 @@ def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
1691
1691
(obj , index ) for index , obj in index_ref .items ()
1692
1692
)
1693
1693
1694
- @tp .overload
1695
- def split (self , graph_node : A , / ) -> tuple [GraphDef [A ], GraphState ]:
1696
- ...
1697
-
1698
- @tp .overload
1699
- def split (
1700
- self , graph_node : A , first : filterlib .Filter , /
1701
- ) -> tuple [GraphDef [A ], GraphState ]:
1702
- ...
1703
-
1704
- @tp .overload
1705
- def split (
1706
- self ,
1707
- graph_node : A ,
1708
- first : filterlib .Filter ,
1709
- second : filterlib .Filter ,
1710
- / ,
1711
- * filters : filterlib .Filter ,
1712
- ) -> tuple [GraphDef [A ], GraphState , tpe .Unpack [tuple [GraphState , ...]]]:
1713
- ...
1714
-
1715
- def split (
1716
- self , node : A , * filters : filterlib .Filter
1717
- ) -> tuple [GraphDef [A ], GraphState , tpe .Unpack [tuple [GraphState , ...]]]:
1718
- """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
1719
- a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
1720
- contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
1721
- to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
1722
- seamlessly between stateful and stateless representations of the graph.
1723
-
1724
- Example usage::
1725
-
1726
- >>> from flax import nnx
1727
- >>> import jax, jax.numpy as jnp
1728
- ...
1729
- >>> class Foo(nnx.Module):
1730
- ... def __init__(self, rngs):
1731
- ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
1732
- ... self.linear = nnx.Linear(2, 3, rngs=rngs)
1733
- ...
1734
- >>> node = Foo(nnx.Rngs(0))
1735
- >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
1736
- ...
1737
- >>> jax.tree.map(jnp.shape, params)
1738
- State({
1739
- 'batch_norm': {
1740
- 'bias': VariableState(
1741
- type=Param,
1742
- value=(2,)
1743
- ),
1744
- 'scale': VariableState(
1745
- type=Param,
1746
- value=(2,)
1747
- )
1748
- },
1749
- 'linear': {
1750
- 'bias': VariableState(
1751
- type=Param,
1752
- value=(3,)
1753
- ),
1754
- 'kernel': VariableState(
1755
- type=Param,
1756
- value=(2, 3)
1757
- )
1758
- }
1759
- })
1760
- >>> jax.tree.map(jnp.shape, batch_stats)
1761
- State({
1762
- 'batch_norm': {
1763
- 'mean': VariableState(
1764
- type=BatchStat,
1765
- value=(2,)
1766
- ),
1767
- 'var': VariableState(
1768
- type=BatchStat,
1769
- value=(2,)
1770
- )
1771
- }
1772
- })
1773
-
1774
- Arguments:
1775
- node: graph node to split.
1776
- *filters: some optional filters to group the state into mutually exclusive substates.
1777
- Returns:
1778
- :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
1779
- filters are passed, a single :class:`State` is returned.
1780
- """
1781
- ref_index : RefMap = RefMap ()
1782
- graphdef , flat_state = flatten (
1783
- node , ref_index = ref_index , ref_outer_index = self .inner_ref_outer_index
1784
- )
1785
- flat_states = _split_state (flat_state , filters )
1786
- states = _to_nested_state (graphdef , flat_states )
1787
- assert len (states ) >= 1
1788
- self .flatten_end (ref_index )
1789
- return graphdef , * states # type: ignore[return-value]
1790
-
1791
- def merge (
1792
- self ,
1793
- graphdef : GraphDef [A ],
1794
- state : GraphState ,
1795
- * states : GraphState ,
1796
- ) -> A :
1797
- """merge"""
1798
- if not isinstance (graphdef , NodeDef ):
1799
- raise ValueError (
1800
- f'Expected a NodeDef instance, but got { type (graphdef )} .'
1801
- )
1802
- if self .outer_ref_outer_index is None :
1803
- raise ValueError ('Cannot merge without ref_index.' )
1804
-
1805
- if self .outer_ref_outer_index is not None :
1806
- # outer merge (4), create index_ref_cache
1807
- index_ref_cache = self .outer_index_outer_ref
1808
- assert index_ref_cache is not None
1809
- else :
1810
- # inner merge (2)
1811
- index_ref_cache = None
1812
-
1813
- _state = _merge_to_flat_state ((state , * states ))
1814
- index_ref : dict [Index , tp .Any ] = {}
1815
- node = unflatten (
1816
- graphdef ,
1817
- _state ,
1818
- index_ref = index_ref ,
1819
- outer_index_outer_ref = index_ref_cache ,
1820
- )
1821
-
1822
- self .unflatten_end (index_ref , True )
1823
-
1824
- return node
1825
-
1826
1694
1827
1695
jax .tree_util .register_static (UpdateContext )
1828
1696
@@ -1919,16 +1787,20 @@ def update_context(tag: tp.Hashable):
1919
1787
>>> from flax import nnx
1920
1788
...
1921
1789
>>> m1 = nnx.Dict({})
1922
- >>> with nnx.update_context('example') as ctx:
1923
- ... graphdef, state = ctx.split(m1)
1790
+ >>> with nnx.update_context('example'):
1791
+ ... with nnx.split_context('example') as ctx:
1792
+ ... graphdef, state = ctx.split(m1)
1924
1793
... @jax.jit
1925
1794
... def f(graphdef, state):
1926
- ... m2 = ctx.merge(graphdef, state)
1795
+ ... with nnx.merge_context('example', inner=True) as ctx:
1796
+ ... m2 = ctx.merge(graphdef, state)
1927
1797
... m2.a = 1
1928
1798
... m2.ref = m2 # create a reference cycle
1929
- ... return ctx.split(m2)
1799
+ ... with nnx.split_context('example') as ctx:
1800
+ ... return ctx.split(m2)
1930
1801
... graphdef_out, state_out = f(graphdef, state)
1931
- ... m3 = ctx.merge(graphdef_out, state_out)
1802
+ ... with nnx.merge_context('example', inner=False) as ctx:
1803
+ ... m3 = ctx.merge(graphdef_out, state_out)
1932
1804
...
1933
1805
>>> assert m1 is m3
1934
1806
>>> assert m1.a == 1
@@ -1937,36 +1809,36 @@ def update_context(tag: tp.Hashable):
1937
1809
Note that ``update_context`` takes in a ``tag`` argument which is used
1938
1810
primarily as a safety mechanism to reduce the risk of accidentally using the
1939
1811
wrong UpdateContext when using :func:`current_update_context` to access the
1940
- current active context. current_update_context can be used as a way of
1941
- accessing the current active context without having to pass it as a capture::
1812
+ current active context. ``update_context`` can also be used as a
1813
+ decorator that creates/activates an UpdateContext context for the
1814
+ duration of the function::
1942
1815
1943
1816
>>> from flax import nnx
1944
1817
...
1945
1818
>>> m1 = nnx.Dict({})
1946
1819
>>> @jax.jit
1947
1820
... def f(graphdef, state):
1948
- ... ctx = nnx.current_update_context ('example')
1949
- ... m2 = ctx.merge(graphdef, state)
1821
+ ... with nnx.merge_context ('example', inner=True) as ctx:
1822
+ ... m2 = ctx.merge(graphdef, state)
1950
1823
... m2.a = 1 # insert static attribute
1951
1824
... m2.ref = m2 # create a reference cycle
1952
- ... return ctx.split(m2)
1825
+ ... with nnx.split_context('example') as ctx:
1826
+ ... return ctx.split(m2)
1953
1827
...
1954
1828
>>> @nnx.update_context('example')
1955
1829
... def g(m1):
1956
- ... ctx = nnx.current_update_context ('example')
1957
- ... graphdef, state = ctx.split(m1)
1830
+ ... with nnx.split_context ('example') as ctx:
1831
+ ... graphdef, state = ctx.split(m1)
1958
1832
... graphdef_out, state_out = f(graphdef, state)
1959
- ... return ctx.merge(graphdef_out, state_out)
1833
+ ... with nnx.merge_context('example', inner=False) as ctx:
1834
+ ... return ctx.merge(graphdef_out, state_out)
1960
1835
...
1961
1836
>>> m3 = g(m1)
1962
1837
>>> assert m1 is m3
1963
1838
>>> assert m1.a == 1
1964
1839
>>> assert m1.ref is m1
1965
1840
1966
- As shown in the code above, ``update_context`` can also be used as a
1967
- decorator that creates/activates an UpdateContext context for the
1968
- duration of the function. The context can be accessed using
1969
- :func:`current_update_context`.
1841
+ The context can be accessed using :func:`current_update_context`.
1970
1842
1971
1843
Args:
1972
1844
tag: A string tag to identify the context.
0 commit comments