@@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
406406
407407def _graph_node_meta_call (cls : tp .Type [P ], * args , ** kwargs ) -> P :
408408 node = cls .__new__ (cls , * args , ** kwargs )
409- vars_obj = vars (node )
410409 object .__setattr__ (node , '_pytree__state' , PytreeState ())
411410 object .__setattr__ (node , '_pytree__nodes' , cls ._pytree__nodes )
412411 cls ._pytree_meta_construct (node , * args , ** kwargs )
@@ -498,6 +497,11 @@ def __init_subclass__(
498497 ** kwargs ,
499498 ) -> None :
500499 super ().__init_subclass__ (** kwargs )
500+ if slots := getattr (cls , '__slots__' , ()):
501+ raise TypeError (
502+ 'Pytree currently does not support __slots__, '
503+ f"found __slots__={ slots } in '{ cls .__name__ } '."
504+ )
501505 cls ._pytree__is_pytree = pytree
502506
503507 graph .register_graph_node_type (
@@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self):
874878 else :
875879 key_fn = None
876880 node_attributes = self ._pytree__nodes
877- node_names : list [str ] = []
881+ node_keys : list [str | int ] = []
878882 node_attrs : list [tuple [tp .Any , tp .Any ]] = []
879- static_attrs : list [tuple [str , tp .Any ]] = []
880- for name , value in sorted (obj_items , key = key_fn ):
881- if name in node_attributes and node_attributes [name ]:
882- node_names .append (name )
883+ static_keys : list [str | int ] = []
884+ static_attrs : list [tp .Any ] = []
885+ for key , value in sorted (obj_items , key = key_fn ):
886+ # get string representation of the key because
887+ # node_attributes keys are strings
888+ key_str = _get_str (key )
889+ if key_str in node_attributes and node_attributes [key_str ]:
890+ node_keys .append (key )
883891 node_attrs .append ((
884- jax .tree_util .GetAttrKey (name )
885- if isinstance (name , str )
886- else jax .tree_util .SequenceKey (name ),
892+ jax .tree_util .GetAttrKey (key )
893+ if isinstance (key , str )
894+ else jax .tree_util .SequenceKey (key ),
887895 value ,
888896 ))
889897 else :
890- static_attrs .append ((name , value ))
898+ static_keys .append (key )
899+ static_attrs .append (value )
891900
892- return node_attrs , (tuple (node_names ), tuple (static_attrs ))
901+ return (
902+ node_attrs ,
903+ (tuple (node_keys ), tuple (static_keys ), tuple (static_attrs )),
904+ )
893905
894906 def _pytree__flatten (self ):
895907 obj_items = vars (self ).items ()
@@ -899,34 +911,43 @@ def _pytree__flatten(self):
899911 else :
900912 key_fn = None
901913 node_attributes = self ._pytree__nodes
902- node_names : list [str ] = []
914+ node_keys : list [str | int ] = []
903915 node_attrs : list [tp .Any ] = []
904- static_attrs : list [tuple [str , tp .Any ]] = []
905- for name , value in sorted (obj_items , key = key_fn ):
906- if name in node_attributes and node_attributes [name ]:
907- node_names .append (name )
916+ static_keys : list [str | int ] = []
917+ static_attrs : list [tp .Any ] = []
918+ for key , value in sorted (obj_items , key = key_fn ):
919+ # get string representation of the key because
920+ # node_attributes keys are strings
921+ key_str = _get_str (key )
922+ if key_str in node_attributes and node_attributes [key_str ]:
923+ node_keys .append (key )
908924 node_attrs .append (value )
909925 else :
910- static_attrs .append ((name , value ))
926+ static_keys .append (key )
927+ static_attrs .append (value )
911928
912- return node_attrs , (tuple (node_names ), tuple (static_attrs ))
929+ return (
930+ node_attrs ,
931+ (tuple (node_keys ), tuple (static_keys ), tuple (static_attrs )),
932+ )
913933
914934 @classmethod
915935 def _pytree__unflatten (
916936 cls ,
917- static : tuple [tuple [str , ...], tuple [tuple [ str , tp .Any ] , ...]],
937+ static : tuple [tuple [str | int , ...], tuple [str | int , ...], tuple [ tp .Any , ...]],
918938 node_attrs : tp .Iterable [tp .Any ],
919939 ):
920- node_names , static_attrs = static
940+ node_keys , static_keys , static_attrs = static
921941 obj = object .__new__ (cls )
922- vars_obj = vars (obj )
923942 if cls ._pytree__has_int_keys :
924- node_names = tuple (
925- str (name ) if isinstance (name , int ) else name for name in node_names
926- )
927- for name , value in zip (node_names , node_attrs , strict = True ):
943+ node_keys_iter = map (_get_str , node_keys )
944+ static_keys_iter = map (_get_str , static_keys )
945+ else :
946+ node_keys_iter = node_keys
947+ static_keys_iter = static_keys
948+ for name , value in zip (node_keys_iter , node_attrs , strict = True ):
928949 object .__setattr__ (obj , name , value )
929- for name , value in static_attrs :
950+ for name , value in zip ( static_keys_iter , static_attrs , strict = True ) :
930951 object .__setattr__ (obj , name , value )
931952 return obj
932953
@@ -946,7 +967,16 @@ def _graph_node_flatten(self):
946967 def _graph_node_set_key (self , key , value : tp .Any ):
947968 if self ._pytree__has_int_keys and isinstance (key , int ):
948969 key = str (key )
949- setattr (self , key , value )
970+ if not isinstance (key , str ):
971+ raise KeyError (f'Invalid key: { key !r} ' )
972+ elif (
973+ hasattr (self , key )
974+ and isinstance (variable := getattr (self , key ), Variable )
975+ and isinstance (value , Variable )
976+ ):
977+ variable .update_from_state (value )
978+ else :
979+ setattr (self , key , value )
950980
951981 def _graph_node_pop_key (self , key ):
952982 if self ._pytree__has_int_keys and isinstance (key , int ):
@@ -978,7 +1008,8 @@ def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
9781008 (str (name ) if isinstance (name , int ) else name , value )
9791009 for name , value in attributes
9801010 )
981- vars (self ).update (attributes )
1011+ for name , value in attributes :
1012+ object .__setattr__ (self , name , value )
9821013
9831014 if tp .TYPE_CHECKING :
9841015 def __call__ (self , * args : tp .Any , ** kwargs : tp .Any ) -> tp .Any : ...
@@ -1000,4 +1031,7 @@ def _maybe_int(x):
10001031 try :
10011032 return int (x )
10021033 except (ValueError , TypeError ):
1003- return x
1034+ return x
1035+
1036+ def _get_str (x ):
1037+ return x if isinstance (x , str ) else str (x )
0 commit comments