Skip to content

Commit 4047574

Browse files
committed
[nnx] add more flaxlib types
1 parent 437cba3 commit 4047574

File tree

8 files changed

+477
-201
lines changed

8 files changed

+477
-201
lines changed

benchmarks/nnx_simple_training.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ def test_step_nnx(model: MLP, batch):
116116
loss = jnp.mean((y - y_pred) ** 2)
117117
return {'loss': loss}
118118

119-
cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120-
cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
119+
# cached_train_step_nnx = nnx.cached_partial(train_step_nnx, model, optimizer)
120+
# cached_test_step_nnx = nnx.cached_partial(test_step_nnx, model)
121121

122122
for step, batch in enumerate(dataset(X, Y, batch_size)):
123-
cached_train_step_nnx(batch)
123+
train_step_nnx(model, optimizer, batch)
124124

125125
if step % 1000 == 0:
126-
logs = cached_test_step_nnx((X, Y))
126+
logs = test_step_nnx(model, (X, Y))
127127

128128
if step >= total_steps - 1:
129129
break

flax/nnx/graph.py

+79-40
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import threading
2121
import typing as tp
2222

23+
from flax import config
2324
from flax.nnx import filterlib, reprlib, traversals, variablelib
2425
from flax.nnx import statelib
2526
from flax.nnx.proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
6364
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
6465
return isinstance(x, Variable)
6566

67+
class IndexMap(dict[Index, tp.Any]):
68+
@staticmethod
69+
def from_refmap(refmap: RefMap) -> IndexMap:
70+
indexmap = IndexMap()
71+
indexmap.update((index, value) for value, index in refmap.items())
72+
return indexmap
73+
74+
if config.flax_use_flaxlib:
75+
import flaxlib
76+
77+
globals()['IndexMap'] = flaxlib.IndexMap
78+
6679

6780
# RefMap = dict
68-
class RefMap(tp.MutableMapping[A, B]):
81+
class RefMap(tp.MutableMapping[tp.Any, int]):
6982
"""A mapping that hashes keys by their identity."""
7083

7184
def __init__(
72-
self,
73-
mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] | None = None,
74-
/,
85+
self,
86+
mapping: tp.Mapping[tp.Any, int]
87+
| tp.Iterable[tuple[tp.Any, int]]
88+
| None = None,
89+
/,
7590
):
76-
self._mapping: dict[int, tuple[A, B]] = dict()
91+
self._mapping: dict[int, tuple[tp.Any, int]] = dict()
7792
if mapping is not None:
7893
self.update(mapping)
7994

80-
def __getitem__(self, key: A) -> B:
95+
@staticmethod
96+
def from_indexmap(indexmap: IndexMap) -> RefMap:
97+
refmap = RefMap()
98+
refmap.update((value, index) for index, value in indexmap.items())
99+
return refmap
100+
101+
def __getitem__(self, key: tp.Any) -> int:
81102
return self._mapping[id(key)][1]
82103

83-
def __setitem__(self, key: A, value: B):
104+
def __setitem__(self, key: tp.Any, value: int):
84105
self._mapping[id(key)] = (key, value)
85106

86-
def __delitem__(self, key: A):
107+
def __delitem__(self, key: tp.Any):
87108
del self._mapping[id(key)]
88109

89110
def __len__(self) -> int:
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92113
def __contains__(self, key: tp.Any) -> bool:
93114
return id(key) in self._mapping
94115

95-
def __iter__(self) -> tp.Iterator[A]:
116+
def __iter__(self) -> tp.Iterator[tp.Any]:
96117
for key, _ in self._mapping.values():
97118
yield key
98119

99-
def items(self) -> tp.ItemsView[A, B]:
120+
def items(self) -> tp.ItemsView[tp.Any, int]:
100121
return self._mapping.values() # type: ignore
101122

102123

124+
if config.flax_use_flaxlib:
125+
import flaxlib
126+
127+
globals()['RefMap'] = flaxlib.RefMap
128+
129+
103130
@dataclasses.dataclass(frozen=True, slots=True)
104131
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
105132
type: type[Node]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258285
subtree_renderer=subtree_renderer,
259286
)
260287

288+
if config.flax_use_flaxlib:
289+
import flaxlib
290+
291+
jax.tree_util.register_static(flaxlib.NodeRef)
292+
globals()['NodeRef'] = flaxlib.NodeRef
261293

262294
@jax.tree_util.register_static
263295
@dataclasses.dataclass(frozen=True, repr=False)
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299331
subtree_renderer=subtree_renderer,
300332
)
301333

334+
if config.flax_use_flaxlib:
335+
import flaxlib
336+
337+
jax.tree_util.register_static(flaxlib.VariableDef)
338+
globals()['VariableDef'] = flaxlib.VariableDef
302339

303340
@jax.tree_util.register_static
304341
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331368
metadata=self.metadata,
332369
)
333370

334-
def replace(self, **kwargs):
335-
return dataclasses.replace(self, **kwargs)
336-
337371
def __nnx_repr__(self):
338372
yield reprlib.Object(type=type(self))
339373

@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358392
)
359393

360394

395+
if config.flax_use_flaxlib:
396+
import flaxlib
397+
398+
jax.tree_util.register_static(flaxlib.NodeDef)
399+
globals()['NodeDef'] = flaxlib.NodeDef
400+
401+
361402
@jax.tree_util.register_static
362403
@dataclasses.dataclass(frozen=True, slots=True)
363404
class ArrayAttr:
@@ -548,7 +589,7 @@ def _graph_flatten(
548589
node: Node,
549590
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
550591
path: list[Key] | None,
551-
ref_index: RefMap[tp.Any, int],
592+
ref_index: RefMap,
552593
ref_outer_index: RefMap | None,
553594
nodes: list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]],
554595
attributes: list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]],
@@ -599,13 +640,13 @@ def _graph_flatten(
599640
values, metadata = node_impl.flatten(node)
600641
num_attributes = len(values)
601642
nodedef = NodeDef(
602-
type=node_impl.type,
603-
index=index,
604-
outer_index=ref_outer_index[node]
643+
node_impl.type,
644+
index,
645+
ref_outer_index[node]
605646
if is_graph_node_ and ref_outer_index and node in ref_outer_index
606647
else None,
607-
num_attributes=num_attributes,
608-
metadata=metadata,
648+
num_attributes,
649+
metadata,
609650
)
610651
nodes.append(nodedef)
611652

@@ -865,8 +906,8 @@ def unflatten(
865906
state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any],
866907
/,
867908
*,
868-
index_ref: dict[Index, tp.Any] | None = None,
869-
outer_index_outer_ref: dict[Index, tp.Any] | None = None,
909+
index_ref: IndexMap | None = None,
910+
outer_index_outer_ref: IndexMap | None = None,
870911
) -> Node:
871912
"""Unflattens a graphdef into a node with the given state.
872913
@@ -892,7 +933,7 @@ def unflatten(
892933
else:
893934
raise ValueError(f'Unsupported state type: {type(state)}')
894935
if index_ref is None:
895-
index_ref = {}
936+
index_ref = IndexMap()
896937

897938
if len(leaves) != graphdef.num_leaves:
898939
raise ValueError(
@@ -936,8 +977,8 @@ def _graph_unflatten(
936977
tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]
937978
],
938979
leaves_iter: tp.Iterator[tp.Any],
939-
index_ref: dict[Index, tp.Any],
940-
outer_index_outer_ref: dict[Index, tp.Any] | None,
980+
index_ref: IndexMap,
981+
outer_index_outer_ref: IndexMap | None,
941982
) -> Node:
942983
"""Recursive helper for graph_unflatten.
943984
@@ -1001,7 +1042,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
10011042
assert type(nodedef) is NodeDef
10021043
if node_impl is None:
10031044
raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
1004-
if nodedef.index in index_ref:
1045+
if nodedef.index is not None and nodedef.index in index_ref:
10051046
raise RuntimeError(f'GraphDef index {nodedef.index} already used.')
10061047

10071048
def _get_children() -> list[tuple[Key, tp.Any]]:
@@ -1214,7 +1255,7 @@ class StaticCache(tp.NamedTuple):
12141255
paths: tuple[PathParts, ...]
12151256
variables: list[Variable[tp.Any]]
12161257
new_ref_index: RefMap
1217-
new_index_ref: dict[Index, tp.Any]
1258+
new_index_ref: IndexMap
12181259

12191260
@staticmethod
12201261
def create(
@@ -1223,7 +1264,7 @@ def create(
12231264
variables: list[Variable[tp.Any]],
12241265
new_ref_index: RefMap,
12251266
):
1226-
new_index_ref = {index: obj for obj, index in new_ref_index.items()}
1267+
new_index_ref = IndexMap.from_refmap(new_ref_index)
12271268
final_graphdef: GraphDef[tp.Any]
12281269
final_graphdef = graphdef.with_same_outer_index()
12291270
return StaticCache(
@@ -1243,15 +1284,15 @@ class GraphContext(threading.local):
12431284
)
12441285
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
12451286
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
1246-
tmp_static_cache: RefMap[tp.Any, StaticCache] | None = None
1287+
tmp_static_cache: RefMap | None = None
12471288
caching: bool = False
12481289

12491290

12501291
GRAPH_CONTEXT = GraphContext()
12511292

12521293

12531294
@contextlib.contextmanager
1254-
def static_cache(static_cache: RefMap[tp.Any, StaticCache]):
1295+
def static_cache(static_cache: RefMap):
12551296
if GRAPH_CONTEXT.caching:
12561297
yield
12571298
return
@@ -1314,9 +1355,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
13141355
Returns:
13151356
A partial function expecting the remaining arguments to the original function.
13161357
"""
1317-
cache: RefMap[tp.Any, StaticCache] = RefMap()
1358+
cache: RefMap = RefMap()
13181359
original_ref_index: RefMap = RefMap()
1319-
index_ref: dict[Index, tp.Any] = {}
1360+
index_ref: IndexMap = IndexMap()
13201361
cached_ref_index: RefMap = RefMap()
13211362

13221363
def create_static_cache(x):
@@ -1542,7 +1583,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
15421583
@dataclasses.dataclass
15431584
class MergeContext:
15441585
ctxtag: tp.Hashable | None
1545-
index_ref: dict[Index, tp.Any]
1586+
index_ref: IndexMap
15461587
is_inner: bool | None
15471588

15481589
def merge(
@@ -1668,7 +1709,7 @@ def merge_context(): ...
16681709
def merge_context(ctxtag: tp.Hashable | None, inner: bool | None): ...
16691710
@contextlib.contextmanager
16701711
def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None):
1671-
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner))
1712+
GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, IndexMap(), inner))
16721713

16731714
try:
16741715
yield GRAPH_CONTEXT.index_ref_stack[-1]
@@ -1691,11 +1732,11 @@ class UpdateContext:
16911732

16921733
tag: tp.Hashable
16931734
outer_ref_outer_index: RefMap | None
1694-
outer_index_inner_ref: dict[Index, tp.Any] | None
1735+
outer_index_inner_ref: IndexMap | None
16951736
# reverse caches
1696-
outer_index_outer_ref: dict[Index, tp.Any] | None
1737+
outer_index_outer_ref: IndexMap | None
16971738
inner_ref_outer_index: RefMap | None
1698-
static_cache: RefMap[tp.Any, StaticCache] | None
1739+
static_cache: RefMap | None
16991740

17001741
# define hash and eq to make this an opaque object
17011742
def __hash__(self):
@@ -1716,13 +1757,11 @@ def flatten_end(self, ref_index: RefMap):
17161757
self.outer_index_inner_ref = None
17171758
self.inner_ref_outer_index = None
17181759

1719-
def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
1760+
def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
17201761
if inner_merge:
17211762
# inner merge (2)
17221763
self.outer_index_inner_ref = index_ref
1723-
self.inner_ref_outer_index = RefMap(
1724-
(obj, index) for index, obj in index_ref.items()
1725-
)
1764+
self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
17261765

17271766

17281767
@dataclasses.dataclass

flaxlib_src/src/flaxlib/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
# limitations under the License.
1414

1515
from .flaxlib_cpp import RefMap as RefMap
16-
from .flaxlib_cpp import _graph_fingerprint as _graph_fingerprint
16+
from .flaxlib_cpp import IndexMap as IndexMap
17+
from .flaxlib_cpp import NodeDef as NodeDef
18+
from .flaxlib_cpp import VariableDef as VariableDef
19+
from .flaxlib_cpp import NodeRef as NodeRef

flaxlib_src/src/flaxlib/flaxlib_cpp.pyi

+47-7
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,51 @@
1515
import typing as tp
1616

1717
RefMap = tp.MutableMapping[tp.Any, int]
18+
IndexMap = dict[int, tp.Any]
19+
20+
class NodeDef:
21+
type: type
22+
index: int | None
23+
outer_index: int | None
24+
num_attributes: int
25+
metadata: tp.Any
26+
27+
def with_no_outer_index(self) -> NodeDef: ...
28+
def with_same_outer_index(self) -> NodeDef: ...
29+
def __eq__(self, other: tp.Any) -> bool: ...
30+
def __hash__(self) -> int: ...
31+
def __getstate__(
32+
self,
33+
) -> tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any]: ...
34+
@staticmethod
35+
def __setstate__(
36+
nodedef: NodeDef, state: tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any]
37+
) -> None: ...
38+
39+
class VariableDef:
40+
type: type
41+
index: int
42+
outer_index: int | None
43+
metadata: tp.Any
44+
45+
def with_no_outer_index(self) -> VariableDef: ...
46+
def with_same_outer_index(self) -> VariableDef: ...
47+
def __eq__(self, other: tp.Any) -> bool: ...
48+
def __hash__(self) -> int: ...
49+
def __getstate__(
50+
self,
51+
) -> tuple[tp.Any, int, tp.Any, tp.Any]: ...
52+
@staticmethod
53+
def __setstate__(
54+
variabledef: 'VariableDef', state: tuple[tp.Any, int, tp.Any, tp.Any]
55+
) -> None: ...
56+
57+
class NodeRef:
58+
index: int
59+
60+
def __eq__(self, other: tp.Any) -> bool: ...
61+
def __hash__(self) -> int: ...
62+
def __getstate__(self) -> tuple[int]: ...
63+
@staticmethod
64+
def __setstate__(noderef: NodeRef, state: tuple[int]) -> None: ...
1865

19-
def _graph_fingerprint(
20-
node,
21-
node_impl,
22-
ref_index: RefMap,
23-
new_ref_index: RefMap,
24-
next_index: int,
25-
) -> tuple[tuple[tp.Any, ...], int]: ...

0 commit comments

Comments
 (0)