@@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
406406
407407def _graph_node_meta_call (cls : tp .Type [P ], * args , ** kwargs ) -> P :
408408 node = cls .__new__ (cls , * args , ** kwargs )
409- vars_obj = vars (node )
410409 object .__setattr__ (node , '_pytree__state' , PytreeState ())
411410 object .__setattr__ (node , '_pytree__nodes' , cls ._pytree__nodes )
412411 cls ._pytree_meta_construct (node , * args , ** kwargs )
@@ -498,6 +497,11 @@ def __init_subclass__(
498497 ** kwargs ,
499498 ) -> None :
500499 super ().__init_subclass__ (** kwargs )
500+ if slots := getattr (cls , '__slots__' , ()):
501+ raise TypeError (
502+ 'Pytree currently does not support __slots__, '
503+ f"found __slots__={ slots } in '{ cls .__name__ } '."
504+ )
501505 cls ._pytree__is_pytree = pytree
502506
503507 graph .register_graph_node_type (
@@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self):
874878 else :
875879 key_fn = None
876880 node_attributes = self ._pytree__nodes
877- node_names : list [str ] = []
881+ node_keys : list [str | int ] = []
878882 node_attrs : list [tuple [tp .Any , tp .Any ]] = []
879- static_attrs : list [tuple [str , tp .Any ]] = []
880- for name , value in sorted (obj_items , key = key_fn ):
881- if name in node_attributes and node_attributes [name ]:
882- node_names .append (name )
883+ static_keys : list [str | int ] = []
884+ static_attrs : list [tp .Any ] = []
885+ for key , value in sorted (obj_items , key = key_fn ):
886+ # get string representation of the key because
887+ # node_attributes keys are strings
888+ key_str = _get_str (key )
889+ if key_str in node_attributes and node_attributes [key_str ]:
890+ node_keys .append (key )
883891 node_attrs .append ((
884- jax .tree_util .GetAttrKey (name )
885- if isinstance (name , str )
886- else jax .tree_util .SequenceKey (name ),
892+ jax .tree_util .GetAttrKey (key )
893+ if isinstance (key , str )
894+ else jax .tree_util .SequenceKey (key ),
887895 value ,
888896 ))
889897 else :
890- static_attrs .append ((name , value ))
898+ static_keys .append (key )
899+ static_attrs .append (value )
891900
892- return node_attrs , (tuple (node_names ), tuple (static_attrs ))
901+ return (
902+ node_attrs ,
903+ (tuple (node_keys ), tuple (static_keys ), tuple (static_attrs )),
904+ )
893905
894906 def _pytree__flatten (self ):
895907 obj_items = vars (self ).items ()
@@ -899,35 +911,38 @@ def _pytree__flatten(self):
899911 else :
900912 key_fn = None
901913 node_attributes = self ._pytree__nodes
902- node_names : list [str ] = []
914+ node_keys : list [str | int ] = []
903915 node_attrs : list [tp .Any ] = []
904- static_attrs : list [tuple [str , tp .Any ]] = []
905- for name , value in sorted (obj_items , key = key_fn ):
906- if name in node_attributes and node_attributes [name ]:
907- node_names .append (name )
916+ static_keys : list [str | int ] = []
917+ static_attrs : list [tp .Any ] = []
918+ for key , value in sorted (obj_items , key = key_fn ):
919+ # get string representation of the key because
920+ # node_attributes keys are strings
921+ key_str = _get_str (key )
922+ if key_str in node_attributes and node_attributes [key_str ]:
923+ node_keys .append (key )
908924 node_attrs .append (value )
909925 else :
910- static_attrs .append ((name , value ))
926+ static_keys .append (key )
927+ static_attrs .append (value )
911928
912- return node_attrs , (tuple (node_names ), tuple (static_attrs ))
929+ return (
930+ node_attrs ,
931+ (tuple (node_keys ), tuple (static_keys ), tuple (static_attrs )),
932+ )
913933
914934 @classmethod
915935 def _pytree__unflatten (
916936 cls ,
917- static : tuple [tuple [str , ... ], tuple [ tuple [ str , tp . Any ], ... ]],
937+ static : tuple [tp . Iterable [str | int ], tp . Iterable [ str | int ], tp . Iterable [ tp . Any ]],
918938 node_attrs : tp .Iterable [tp .Any ],
919939 ):
920- node_names , static_attrs = static
940+ node_keys , static_keys , static_attrs = static
921941 obj = object .__new__ (cls )
922- vars_obj = vars (obj )
923- if cls ._pytree__has_int_keys :
924- node_names = tuple (
925- str (name ) if isinstance (name , int ) else name for name in node_names
926- )
927- for name , value in zip (node_names , node_attrs , strict = True ):
928- object .__setattr__ (obj , name , value )
929- for name , value in static_attrs :
930- object .__setattr__ (obj , name , value )
942+ for name , value in zip (node_keys , node_attrs , strict = True ):
943+ object .__setattr__ (obj , _get_str (name ), value )
944+ for name , value in zip (static_keys , static_attrs , strict = True ):
945+ object .__setattr__ (obj , _get_str (name ), value )
931946 return obj
932947
933948 # -------------------------
@@ -946,7 +961,16 @@ def _graph_node_flatten(self):
946961 def _graph_node_set_key (self , key , value : tp .Any ):
947962 if self ._pytree__has_int_keys and isinstance (key , int ):
948963 key = str (key )
949- setattr (self , key , value )
964+ if not isinstance (key , str ):
965+ raise KeyError (f'Invalid key: { key !r} ' )
966+ elif (
967+ hasattr (self , key )
968+ and isinstance (variable := getattr (self , key ), Variable )
969+ and isinstance (value , Variable )
970+ ):
971+ variable .update_from_state (value )
972+ else :
973+ setattr (self , key , value )
950974
951975 def _graph_node_pop_key (self , key ):
952976 if self ._pytree__has_int_keys and isinstance (key , int ):
@@ -972,13 +996,9 @@ def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
972996 def _graph_node_clear (self ):
973997 vars (self ).clear ()
974998
975- def _graph_node_init (self , attributes : tp .Iterable [tuple [str , tp .Any ]]):
976- if self ._pytree__has_int_keys :
977- attributes = (
978- (str (name ) if isinstance (name , int ) else name , value )
979- for name , value in attributes
980- )
981- vars (self ).update (attributes )
999+ def _graph_node_init (self , attributes : tp .Iterable [tuple [str | int , tp .Any ]]):
1000+ for name , value in attributes :
1001+ object .__setattr__ (self , _get_str (name ), value )
9821002
9831003 if tp .TYPE_CHECKING :
9841004 def __call__ (self , * args : tp .Any , ** kwargs : tp .Any ) -> tp .Any : ...
@@ -1000,4 +1020,7 @@ def _maybe_int(x):
10001020 try :
10011021 return int (x )
10021022 except (ValueError , TypeError ):
1003- return x
1023+ return x
1024+
1025+ def _get_str (x ):
1026+ return x if isinstance (x , str ) else str (x )
0 commit comments