@@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
179179 type : type [Node ]
180180 flatten : tp .Callable [[Node ], tuple [tp .Sequence [tuple [Key , Leaf ]], AuxData ]]
181181
182- def node_dict (self , node : Node ) -> dict [Key , Leaf ]:
182+ def node_dict (self , node : Node ) -> dict [Key , tp . Any ]:
183183 nodes , _ = self .flatten (node )
184- return dict (nodes )
184+ nodes = {
185+ key : node .value if isinstance (node , DataElem | StaticElem ) else node
186+ for key , node in nodes
187+ }
188+ return nodes
185189
186190
187191@dataclasses .dataclass (frozen = True , slots = True )
@@ -524,32 +528,21 @@ def __treescope_repr__(self, path, subtree_renderer):
524528
525529
526530@dataclasses .dataclass (frozen = True , slots = True )
527- class ArrayAttr :
528- pass
529-
530-
531- ARRAY_ATTR = ArrayAttr ()
532-
533-
534- @dataclasses .dataclass (frozen = True , slots = True )
535- class MutableArrayAttr :
531+ class NodeAttr :
536532 pass
537533
538534
539- MUTABLE_ARRAY_ATTR = MutableArrayAttr ()
540-
535+ NODE_ATTR = NodeAttr ()
541536
542537@dataclasses .dataclass (frozen = True , slots = True )
543- class NodeAttr :
538+ class LeafAttr :
544539 pass
545540
546-
547- NODE_ATTR = NodeAttr ()
541+ LEAF_ATTR = LeafAttr ()
548542
549543AttrType = tp .Union [
550544 NodeAttr ,
551- ArrayAttr ,
552- MutableArrayAttr ,
545+ LeafAttr ,
553546 'Static[tp.Any]' ,
554547]
555548
@@ -711,6 +704,14 @@ def flatten( # type: ignore[invalid-annotation]
711704 else :
712705 return graphdef , leaves
713706
707+ @dataclasses .dataclass (frozen = True , slots = True )
708+ class DataElem :
709+ value : tp .Any
710+
711+
712+ @dataclasses .dataclass (frozen = True , slots = True )
713+ class StaticElem :
714+ value : tp .Any
714715
715716def _graph_flatten (
716717 node : Node ,
@@ -828,6 +829,18 @@ def make_mutable_arraydef(value: variablelib.Ref):
828829 nodes .append (nodedef )
829830
830831 for key , value in values :
832+ is_data = None
833+ if isinstance (value , DataElem ):
834+ value = value .value
835+ is_data = True
836+ elif isinstance (value , StaticElem ):
837+ value = value .value
838+ is_data = False
839+
840+ if is_data is False :
841+ attributes .append ((key , Static (value )))
842+ continue
843+
831844 value_node_impl = get_node_impl (value )
832845 if path is not None :
833846 path .append (key )
@@ -845,15 +858,15 @@ def make_mutable_arraydef(value: variablelib.Ref):
845858 paths ,
846859 )
847860 elif variablelib .is_array_ref (value ):
848- attributes .append ((key , MUTABLE_ARRAY_ATTR ))
861+ attributes .append ((key , NODE_ATTR ))
849862 array_refdef , leaf = make_mutable_arraydef (value )
850863 if not isinstance (leaf , Repeated ):
851864 leaves .append (leaf )
852865 if paths is not None :
853866 paths .append (tuple (path )) # type: ignore
854867 nodes .append (array_refdef )
855- elif isinstance (value , (jax .Array , np .ndarray )):
856- attributes .append ((key , ARRAY_ATTR ))
868+ elif isinstance (value , (jax .Array , np .ndarray )) or is_data :
869+ attributes .append ((key , LEAF_ATTR ))
857870 if paths is not None :
858871 paths .append (tuple (path )) # type: ignore
859872 leaves .append (value )
@@ -1093,41 +1106,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
10931106 key , value = next (attribute_iter )
10941107 if type (value ) is Static :
10951108 children .append ((key , value .value )) # type: ignore[attribute-error]
1096- elif type (value ) is MutableArrayAttr :
1097- array_refdef = next (node_iter )
1098- assert (
1099- type (array_refdef ) is ArrayRefDef or type (array_refdef ) is NodeRef
1100- )
1101- if type (array_refdef ) is NodeRef :
1102- array_ref = index_ref [array_refdef .index ]
1103- else :
1104- assert type (array_refdef ) is ArrayRefDef
1109+ elif type (value ) is LeafAttr :
1110+ leaf = next (leaves_iter )
1111+ children .append ((key , leaf ))
1112+ elif type (value ) is NodeAttr :
1113+ node_def = next (node_iter )
1114+ if isinstance (node_def , NodeRef ):
1115+ node = index_ref [node_def .index ]
1116+ elif isinstance (node_def , ArrayRefDef ):
11051117 leaf = next (leaves_iter )
1106- array_ref = get_mutable_array (array_refdef , leaf )
1107- children .append ((key , array_ref ))
1108- elif type (value ) is ArrayAttr :
1109- array = next (leaves_iter )
1110- children .append ((key , array ))
1118+ node = get_mutable_array (node_def , leaf )
1119+ elif isinstance (node_def , NodeDef | VariableDef ):
1120+ value_node_impl = get_node_impl_for_type (node_def .type )
1121+ node = _graph_unflatten (
1122+ node_def ,
1123+ value_node_impl ,
1124+ node_iter ,
1125+ attribute_iter ,
1126+ leaves_iter ,
1127+ index_ref ,
1128+ outer_index_outer_ref ,
1129+ copy_variables ,
1130+ )
1131+ else :
1132+ raise RuntimeError (f'Unknown node definition: { node_def !r} ' )
1133+ children .append ((key , node ))
11111134 elif type (value ) is NodeRef :
11121135 children .append ((key , index_ref [value .index ])) # type: ignore[attribute-error]
1113- elif type (value ) is NodeAttr :
1114- # if the key is a subgraph we create an empty node
1115- subgraphdef = next (node_iter )
1116- if type (subgraphdef ) is NodeDef :
1117- value_node_impl = get_node_impl_for_type (subgraphdef .type ) # type: ignore[attribute-error]
1118- else :
1119- value_node_impl = None
1120- subnode = _graph_unflatten (
1121- subgraphdef ,
1122- value_node_impl ,
1123- node_iter ,
1124- attribute_iter ,
1125- leaves_iter ,
1126- index_ref ,
1127- outer_index_outer_ref ,
1128- copy_variables ,
1129- )
1130- children .append ((key , subnode ))
11311136 else :
11321137 raise RuntimeError (f'Unknown static field: { key !r} ' )
11331138
0 commit comments