20
20
import threading
21
21
import typing as tp
22
22
23
+ from flax import config
23
24
from flax .nnx import filterlib , reprlib , traversals , variablelib
24
25
from flax .nnx import statelib
25
26
from flax .nnx .proxy_caller import (
@@ -63,27 +64,47 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
63
64
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
  """Report whether ``x`` is a ``Variable`` instance."""
  is_variable = isinstance(x, Variable)
  return is_variable
65
66
67
class IndexMap(dict[Index, tp.Any]):
  """A ``dict`` keyed by ``Index`` with a constructor that inverts a ``RefMap``."""

  @staticmethod
  def from_refmap(refmap: RefMap) -> IndexMap:
    """Build an ``IndexMap`` holding ``index -> value`` for each pair in ``refmap``.

    Args:
      refmap: mapping whose ``items()`` yields ``(value, index)`` pairs.

    Returns:
      A new ``IndexMap`` with every pair flipped; if the same index appears
      more than once, the last pair seen wins.
    """
    inverted = IndexMap()
    for value, index in refmap.items():
      inverted[index] = value
    return inverted
73
+
74
if config.flax_use_flaxlib:
  # With the flag enabled, the module-level name ``IndexMap`` is rebound to
  # the flaxlib class. The binding goes through ``globals()`` rather than a
  # plain assignment statement.
  import flaxlib

  globals().update(IndexMap=flaxlib.IndexMap)
78
+
66
79
67
80
# RefMap = dict
68
- class RefMap (tp .MutableMapping [A , B ]):
81
+ class RefMap (tp .MutableMapping [tp . Any , int ]):
69
82
"""A mapping that hashes keys by their identity."""
70
83
71
84
def __init__ (
72
- self ,
73
- mapping : tp .Mapping [A , B ] | tp .Iterable [tuple [A , B ]] | None = None ,
74
- / ,
85
+ self ,
86
+ mapping : tp .Mapping [tp .Any , int ]
87
+ | tp .Iterable [tuple [tp .Any , int ]]
88
+ | None = None ,
89
+ / ,
75
90
):
76
- self ._mapping : dict [int , tuple [A , B ]] = dict ()
91
+ self ._mapping : dict [int , tuple [tp . Any , int ]] = dict ()
77
92
if mapping is not None :
78
93
self .update (mapping )
79
94
80
- def __getitem__ (self , key : A ) -> B :
95
+ @staticmethod
96
+ def from_indexmap (indexmap : IndexMap ) -> RefMap :
97
+ refmap = RefMap ()
98
+ refmap .update ((value , index ) for index , value in indexmap .items ())
99
+ return refmap
100
+
101
+ def __getitem__ (self , key : tp .Any ) -> int :
81
102
return self ._mapping [id (key )][1 ]
82
103
83
- def __setitem__ (self , key : A , value : B ):
104
+ def __setitem__ (self , key : tp . Any , value : int ):
84
105
self ._mapping [id (key )] = (key , value )
85
106
86
- def __delitem__ (self , key : A ):
107
+ def __delitem__ (self , key : tp . Any ):
87
108
del self ._mapping [id (key )]
88
109
89
110
def __len__ (self ) -> int :
@@ -92,14 +113,20 @@ def __len__(self) -> int:
92
113
def __contains__ (self , key : tp .Any ) -> bool :
93
114
return id (key ) in self ._mapping
94
115
95
- def __iter__ (self ) -> tp .Iterator [A ]:
116
+ def __iter__ (self ) -> tp .Iterator [tp . Any ]:
96
117
for key , _ in self ._mapping .values ():
97
118
yield key
98
119
99
- def items (self ) -> tp .ItemsView [A , B ]:
120
+ def items (self ) -> tp .ItemsView [tp . Any , int ]:
100
121
return self ._mapping .values () # type: ignore
101
122
102
123
124
if config.flax_use_flaxlib:
  # With the flag enabled, the module-level name ``RefMap`` is rebound to
  # the flaxlib class. The binding goes through ``globals()`` rather than a
  # plain assignment statement.
  import flaxlib

  globals().update(RefMap=flaxlib.RefMap)
128
+
129
+
103
130
@dataclasses .dataclass (frozen = True , slots = True )
104
131
class NodeImplBase (tp .Generic [Node , Leaf , AuxData ]):
105
132
type : type [Node ]
@@ -258,6 +285,11 @@ def __treescope_repr__(self, path, subtree_renderer):
258
285
subtree_renderer = subtree_renderer ,
259
286
)
260
287
288
if config.flax_use_flaxlib:
  # With the flag enabled, the flaxlib ``NodeRef`` is registered with
  # ``jax.tree_util.register_static`` and then bound to the module-level
  # name ``NodeRef``.
  import flaxlib

  jax.tree_util.register_static(flaxlib.NodeRef)
  globals().update(NodeRef=flaxlib.NodeRef)
261
293
262
294
@jax .tree_util .register_static
263
295
@dataclasses .dataclass (frozen = True , repr = False )
@@ -299,6 +331,11 @@ def __treescope_repr__(self, path, subtree_renderer):
299
331
subtree_renderer = subtree_renderer ,
300
332
)
301
333
334
if config.flax_use_flaxlib:
  # With the flag enabled, the flaxlib ``VariableDef`` is registered with
  # ``jax.tree_util.register_static`` and then bound to the module-level
  # name ``VariableDef``.
  import flaxlib

  jax.tree_util.register_static(flaxlib.VariableDef)
  globals().update(VariableDef=flaxlib.VariableDef)
302
339
303
340
@jax .tree_util .register_static
304
341
@dataclasses .dataclass (frozen = True , repr = False , slots = True )
@@ -331,9 +368,6 @@ def with_same_outer_index(self) -> NodeDef[Node]:
331
368
metadata = self .metadata ,
332
369
)
333
370
334
- def replace (self , ** kwargs ):
335
- return dataclasses .replace (self , ** kwargs )
336
-
337
371
def __nnx_repr__ (self ):
338
372
yield reprlib .Object (type = type (self ))
339
373
@@ -358,6 +392,13 @@ def __treescope_repr__(self, path, subtree_renderer):
358
392
)
359
393
360
394
395
if config.flax_use_flaxlib:
  # With the flag enabled, the flaxlib ``NodeDef`` is registered with
  # ``jax.tree_util.register_static`` and then bound to the module-level
  # name ``NodeDef``.
  import flaxlib

  jax.tree_util.register_static(flaxlib.NodeDef)
  globals().update(NodeDef=flaxlib.NodeDef)
400
+
401
+
361
402
@jax .tree_util .register_static
362
403
@dataclasses .dataclass (frozen = True , slots = True )
363
404
class ArrayAttr :
@@ -548,7 +589,7 @@ def _graph_flatten(
548
589
node : Node ,
549
590
node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
550
591
path : list [Key ] | None ,
551
- ref_index : RefMap [ tp . Any , int ] ,
592
+ ref_index : RefMap ,
552
593
ref_outer_index : RefMap | None ,
553
594
nodes : list [NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ]],
554
595
attributes : list [tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]],
@@ -599,13 +640,13 @@ def _graph_flatten(
599
640
values , metadata = node_impl .flatten (node )
600
641
num_attributes = len (values )
601
642
nodedef = NodeDef (
602
- type = node_impl .type ,
603
- index = index ,
604
- outer_index = ref_outer_index [node ]
643
+ node_impl .type ,
644
+ index ,
645
+ ref_outer_index [node ]
605
646
if is_graph_node_ and ref_outer_index and node in ref_outer_index
606
647
else None ,
607
- num_attributes = num_attributes ,
608
- metadata = metadata ,
648
+ num_attributes ,
649
+ metadata ,
609
650
)
610
651
nodes .append (nodedef )
611
652
@@ -865,8 +906,8 @@ def unflatten(
865
906
state : State [Key , tp .Any ] | FlatState [tp .Any ] | list [tp .Any ],
866
907
/ ,
867
908
* ,
868
- index_ref : dict [ Index , tp . Any ] | None = None ,
869
- outer_index_outer_ref : dict [ Index , tp . Any ] | None = None ,
909
+ index_ref : IndexMap | None = None ,
910
+ outer_index_outer_ref : IndexMap | None = None ,
870
911
) -> Node :
871
912
"""Unflattens a graphdef into a node with the given state.
872
913
@@ -892,7 +933,7 @@ def unflatten(
892
933
else :
893
934
raise ValueError (f'Unsupported state type: { type (state )} ' )
894
935
if index_ref is None :
895
- index_ref = {}
936
+ index_ref = IndexMap ()
896
937
897
938
if len (leaves ) != graphdef .num_leaves :
898
939
raise ValueError (
@@ -936,8 +977,8 @@ def _graph_unflatten(
936
977
tuple [Key , NodeAttr | ArrayAttr | Static [tp .Any ]]
937
978
],
938
979
leaves_iter : tp .Iterator [tp .Any ],
939
- index_ref : dict [ Index , tp . Any ] ,
940
- outer_index_outer_ref : dict [ Index , tp . Any ] | None ,
980
+ index_ref : IndexMap ,
981
+ outer_index_outer_ref : IndexMap | None ,
941
982
) -> Node :
942
983
"""Recursive helper for graph_unflatten.
943
984
@@ -1001,7 +1042,7 @@ def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
1001
1042
assert type (nodedef ) is NodeDef
1002
1043
if node_impl is None :
1003
1044
raise RuntimeError (f'Unsupported type: { nodedef .type } , this is a bug.' )
1004
- if nodedef .index in index_ref :
1045
+ if nodedef .index is not None and nodedef . index in index_ref :
1005
1046
raise RuntimeError (f'GraphDef index { nodedef .index } already used.' )
1006
1047
1007
1048
def _get_children () -> list [tuple [Key , tp .Any ]]:
@@ -1214,7 +1255,7 @@ class StaticCache(tp.NamedTuple):
1214
1255
paths : tuple [PathParts , ...]
1215
1256
variables : list [Variable [tp .Any ]]
1216
1257
new_ref_index : RefMap
1217
- new_index_ref : dict [ Index , tp . Any ]
1258
+ new_index_ref : IndexMap
1218
1259
1219
1260
@staticmethod
1220
1261
def create (
@@ -1223,7 +1264,7 @@ def create(
1223
1264
variables : list [Variable [tp .Any ]],
1224
1265
new_ref_index : RefMap ,
1225
1266
):
1226
- new_index_ref = { index : obj for obj , index in new_ref_index . items ()}
1267
+ new_index_ref = IndexMap . from_refmap ( new_ref_index )
1227
1268
final_graphdef : GraphDef [tp .Any ]
1228
1269
final_graphdef = graphdef .with_same_outer_index ()
1229
1270
return StaticCache (
@@ -1243,15 +1284,15 @@ class GraphContext(threading.local):
1243
1284
)
1244
1285
ref_index_stack : list [SplitContext ] = dataclasses .field (default_factory = list )
1245
1286
index_ref_stack : list [MergeContext ] = dataclasses .field (default_factory = list )
1246
- tmp_static_cache : RefMap [ tp . Any , StaticCache ] | None = None
1287
+ tmp_static_cache : RefMap | None = None
1247
1288
caching : bool = False
1248
1289
1249
1290
1250
1291
GRAPH_CONTEXT = GraphContext ()
1251
1292
1252
1293
1253
1294
@contextlib .contextmanager
1254
- def static_cache (static_cache : RefMap [ tp . Any , StaticCache ] ):
1295
+ def static_cache (static_cache : RefMap ):
1255
1296
if GRAPH_CONTEXT .caching :
1256
1297
yield
1257
1298
return
@@ -1314,9 +1355,9 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
1314
1355
Returns:
1315
1356
A partial function expecting the remaining arguments to the original function.
1316
1357
"""
1317
- cache : RefMap [ tp . Any , StaticCache ] = RefMap ()
1358
+ cache : RefMap = RefMap ()
1318
1359
original_ref_index : RefMap = RefMap ()
1319
- index_ref : dict [ Index , tp . Any ] = {}
1360
+ index_ref : IndexMap = IndexMap ()
1320
1361
cached_ref_index : RefMap = RefMap ()
1321
1362
1322
1363
def create_static_cache (x ):
@@ -1542,7 +1583,7 @@ def split_context(ctxtag: tp.Hashable | None = None):
1542
1583
@dataclasses .dataclass
1543
1584
class MergeContext :
1544
1585
ctxtag : tp .Hashable | None
1545
- index_ref : dict [ Index , tp . Any ]
1586
+ index_ref : IndexMap
1546
1587
is_inner : bool | None
1547
1588
1548
1589
def merge (
@@ -1668,7 +1709,7 @@ def merge_context(): ...
1668
1709
def merge_context (ctxtag : tp .Hashable | None , inner : bool | None ): ...
1669
1710
@contextlib .contextmanager
1670
1711
def merge_context (ctxtag : tp .Hashable | None = None , inner : bool | None = None ):
1671
- GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , {} , inner ))
1712
+ GRAPH_CONTEXT .index_ref_stack .append (MergeContext (ctxtag , IndexMap () , inner ))
1672
1713
1673
1714
try :
1674
1715
yield GRAPH_CONTEXT .index_ref_stack [- 1 ]
@@ -1691,11 +1732,11 @@ class UpdateContext:
1691
1732
1692
1733
tag : tp .Hashable
1693
1734
outer_ref_outer_index : RefMap | None
1694
- outer_index_inner_ref : dict [ Index , tp . Any ] | None
1735
+ outer_index_inner_ref : IndexMap | None
1695
1736
# reverse caches
1696
- outer_index_outer_ref : dict [ Index , tp . Any ] | None
1737
+ outer_index_outer_ref : IndexMap | None
1697
1738
inner_ref_outer_index : RefMap | None
1698
- static_cache : RefMap [ tp . Any , StaticCache ] | None
1739
+ static_cache : RefMap | None
1699
1740
1700
1741
# define hash and eq to make this an opaque object
1701
1742
def __hash__ (self ):
@@ -1716,13 +1757,11 @@ def flatten_end(self, ref_index: RefMap):
1716
1757
self .outer_index_inner_ref = None
1717
1758
self .inner_ref_outer_index = None
1718
1759
1719
def unflatten_end(self, index_ref: IndexMap, inner_merge: bool):
  """Record the result of an unflatten when it was an inner merge.

  Args:
    index_ref: index-to-object map to record.
    inner_merge: when false, nothing is recorded.
  """
  if not inner_merge:
    return
  # inner merge (2)
  self.outer_index_inner_ref = index_ref
  # Keep the inverse mapping (object -> index) alongside it.
  self.inner_ref_outer_index = RefMap.from_indexmap(index_ref)
1726
1765
1727
1766
1728
1767
@dataclasses .dataclass
0 commit comments