@@ -196,6 +196,8 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
196196@dataclasses .dataclass (frozen = True , slots = True )
197197class PytreeNodeImpl (NodeImplBase [Node , Leaf , AuxData ]):
198198 unflatten : tp .Callable [[tp .Sequence [tuple [Key , Leaf ]], AuxData ], Node ]
199+ set_key : tp .Callable [[Node , Key , Leaf ], None ] | None
200+ pop_key : tp .Callable [[Node , Key ], Leaf ] | None
199201
200202
201203NodeImpl = tp .Union [
@@ -234,12 +236,19 @@ def register_pytree_node_type(
234236 type : type ,
235237 flatten : tp .Callable [[Node ], tuple [tp .Sequence [tuple [Key , Leaf ]], AuxData ]],
236238 unflatten : tp .Callable [[tp .Sequence [tuple [Key , Leaf ]], AuxData ], Node ],
239+ * ,
240+ set_key : tp .Callable [[Node , Key , Leaf ], None ] | None = None ,
241+ pop_key : tp .Callable [[Node , Key ], Leaf ] | None = None ,
237242):
238243 if type in PYTREE_REGISTRY :
239244 raise ValueError (f'Node type { type } is already registered.' )
240245
241246 PYTREE_REGISTRY [type ] = PytreeNodeImpl (
242- type = type , flatten = flatten , unflatten = unflatten
247+ type = type ,
248+ flatten = flatten ,
249+ unflatten = unflatten ,
250+ set_key = set_key ,
251+ pop_key = pop_key ,
243252 )
244253
245254
@@ -1146,16 +1155,16 @@ def _graph_unflatten(
11461155) -> Node :
11471156 """Recursive helper for graph_unflatten.
11481157
1149- Args:
1150- nodedef: A GraphDef instance or an index to a node in the cache.
1151- state: A mapping from attribute names to variables or subgraphs.
1152- index_to_ref : A mapping from indexes to nodes that have been traversed.
1153- If a node is already in the cache, it won't be traversed again.
1154- f0f6619b-dde6-4466-b699-61c47f268d6b index_ref_cache : A mapping from indexes to existing nodes that can be reused.
1155- When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
1156- object in an empty state and then filled by the unflatten process, as a result
1157- existing graph nodes are mutated to have the new content/topology
1158- specified by the nodedef.
1158+ Args:
1159+ nodedef: A GraphDef instance or an index to a node in the cache.
1160+ state: A mapping from attribute names to variables or subgraphs.
1161+ index_ref : A mapping from indexes to nodes that have been traversed.
1162+ If a node is already in the cache, it won't be traversed again.
1163+ outer_index_outer_ref : A mapping from indexes to existing nodes that can be reused.
1164+ When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
1165+ object in an empty state and then filled by the unflatten process, as a result
1166+ existing graph nodes are mutated to have the new content/topology
1167+ specified by the nodedef.
11591168 """
11601169
11611170 def get_mutable_array (array_refdef : ArrayRefDef , leaf ):
@@ -1168,13 +1177,9 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
11681177 # if array ref exists, update it
11691178 array_ref = outer_index_outer_ref [array_refdef .outer_index ]
11701179 if not variablelib .is_array_ref (array_ref ):
1171- raise RuntimeError (
1172- f'Expected a ArrayRef type but got { array_ref } .'
1173- )
1180+ raise RuntimeError (f'Expected a ArrayRef type but got { array_ref } .' )
11741181 if type (leaf ) is not NoUpdate :
1175- raise RuntimeError (
1176- f'Expected a no update for ArrayRef but got { leaf } .'
1177- )
1182+ raise RuntimeError (f'Expected a no update for ArrayRef but got { leaf } .' )
11781183 elif type (leaf ) in (NoUpdate , Repeated ):
11791184 raise ValueError (
11801185 f"Expected a ArrayRefOutput type but got '{ leaf .value } .'"
@@ -1206,9 +1211,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12061211 if isinstance (value , Variable ):
12071212 value = value .copy () if copy_variables else value
12081213 inner_value = value .raw_value
1209- array_ref = get_mutable_array (
1210- variabledef .array_refdef , inner_value
1211- )
1214+ array_ref = get_mutable_array (variabledef .array_refdef , inner_value )
12121215 if array_ref is not inner_value :
12131216 value .raw_value = array_ref
12141217 else :
@@ -1268,8 +1271,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
12681271 elif type (value ) is MutableArrayAttr :
12691272 array_refdef = next (node_iter )
12701273 assert (
1271- type (array_refdef ) is ArrayRefDef
1272- or type (array_refdef ) is NodeRef
1274+ type (array_refdef ) is ArrayRefDef or type (array_refdef ) is NodeRef
12731275 )
12741276 if type (array_refdef ) is NodeRef :
12751277 array_ref = index_ref [array_refdef .index ]
@@ -1392,7 +1394,7 @@ def _graph_pop(
13921394
13931395 for state , predicate in zip (flat_states , predicates ):
13941396 if predicate (node_path , value ):
1395- if isinstance ( node_impl , PytreeNodeImpl ) :
1397+ if node_impl . pop_key is None :
13961398 raise ValueError (
13971399 f'Cannot pop key { name !r} from node of type { type (node ).__name__ } '
13981400 )
@@ -1441,7 +1443,7 @@ def _update_variable(node: Variable, value):
14411443 for key , value in state .items ():
14421444 # case 1: new state is being added
14431445 if key not in node_dict :
1444- if isinstance ( node_impl , PytreeNodeImpl ) :
1446+ if node_impl . set_key is None :
14451447 raise ValueError (
14461448 f'Cannot set key { key !r} on immutable node of '
14471449 f'type { type (node ).__name__ } '
@@ -1460,22 +1462,16 @@ def _update_variable(node: Variable, value):
14601462 if is_node_leaf (value ):
14611463 raise ValueError (f'Expected a subgraph for { key !r} , but got: { value !r} ' )
14621464 _graph_update_dynamic (current_value , value )
1463- else :
1464- if isinstance (current_value , jax .Array | np .ndarray ):
1465- if isinstance (node_impl , PytreeNodeImpl ):
1466- raise ValueError (
1467- f'Cannot set key { key !r} on immutable node of '
1468- f'type { type (node ).__name__ } '
1469- )
1470- node_impl .set_key (node , key , value )
1471- continue
1472- elif not isinstance (current_value , Variable ):
1473- # case 3: state leaf is being updated
1474- raise ValueError (
1475- f'Trying to update a non-Variable attribute { key !r} with a Variable: '
1476- f'{ value !r} '
1477- )
1465+ elif isinstance (current_value , Variable ):
14781466 _update_variable (current_value , value )
1467+ elif node_impl .set_key is not None :
1468+ node_impl .set_key (node , key , value )
1469+ else :
1470+ raise ValueError (
1471+ f'Cannot set key { key !r} on immutable node of '
1472+ f'type { type (node ).__name__ } '
1473+ )
1474+
14791475
14801476
14811477# --------------------------------------------------------
@@ -2621,9 +2617,7 @@ def find_duplicates(
26212617
26222618
26232619def _mutable_like (path , x ):
2624- return (
2625- isinstance (x , Variable ) and x .has_ref
2626- ) or variablelib .is_array_ref (x )
2620+ return (isinstance (x , Variable ) and x .has_ref ) or variablelib .is_array_ref (x )
26272621
26282622
26292623def to_arrays (
@@ -3021,6 +3015,8 @@ def _unflatten_pytree(
30213015 type = GenericPytree ,
30223016 flatten = _flatten_pytree ,
30233017 unflatten = _unflatten_pytree , # type: ignore
3018+ set_key = None ,
3019+ pop_key = None ,
30243020)
30253021
30263022# common pytrees
@@ -3036,15 +3032,36 @@ def _unflatten_pytree(
30363032 flatten = lambda x : (list (enumerate (x )), None ),
30373033 unflatten = lambda nodes , _ : tuple (value for _ , value in nodes ), # type: ignore
30383034)
3035+
3036+
3037+ def _mutable_mapping_set_key (x : tp .MutableMapping [Key , tp .Any ], key : Key , value : tp .Any ):
3038+ x [key ] = value
3039+
3040+
3041+ def _mutable_mapping_pop_key (x : tp .MutableMapping [Key , tp .Any ], key : Key ):
3042+ x .pop (key )
3043+
3044+
30393045# dict
30403046register_pytree_node_type (
30413047 dict ,
30423048 flatten = lambda x : (sorted (x .items ()), None ),
3043- unflatten = lambda nodes , _ : {key : value for key , value in nodes }, # type: ignore
3049+ unflatten = lambda nodes , _ : dict (nodes ), # type: ignore
3050+ set_key = _mutable_mapping_set_key ,
3051+ pop_key = _mutable_mapping_pop_key ,
3052+ )
3053+ # State
3054+ register_pytree_node_type (
3055+ State ,
3056+ flatten = lambda x : (sorted (x .raw_mapping .items ()), None ),
3057+ unflatten = lambda nodes , _ : State (nodes ), # type: ignore
3058+ set_key = _mutable_mapping_set_key ,
3059+ pop_key = _mutable_mapping_pop_key ,
30443060)
30453061# None
30463062register_pytree_node_type (
30473063 type (None ),
30483064 flatten = lambda x : ([], None ),
30493065 unflatten = lambda _ , __ : None , # type: ignore
30503066)
3067+
0 commit comments