11from collections import OrderedDict
22from collections .abc import Mapping
3- from typing import Tuple
3+ from typing import List , Tuple
44
55
66class Serializable :
7- _state_dict_all_req_keys : Tuple = ()
8- _state_dict_one_of_opt_keys : Tuple = ()
7+ _state_dict_all_req_keys : Tuple [str , ...] = ()
8+ _state_dict_one_of_opt_keys : Tuple [Tuple [str , ...], ...] = ((),)
9+
10+ def __init__ (self ) -> None :
11+ self ._state_dict_user_keys : List [str ] = []
12+
13+ @property
14+ def state_dict_user_keys (self ) -> List :
15+ return self ._state_dict_user_keys
916
1017 def state_dict (self ) -> OrderedDict :
1118 raise NotImplementedError
@@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926 raise ValueError (
2027 f"Required state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
2128 )
22- opts = [k in state_dict for k in self ._state_dict_one_of_opt_keys ]
23- if len (opts ) > 0 and ((not any (opts )) or (all (opts ))):
24- raise ValueError (f"state_dict should contain only one of '{ self ._state_dict_one_of_opt_keys } ' keys" )
29+
30+ # Handle groups of one-of optional keys
31+ for one_of_opt_keys in self ._state_dict_one_of_opt_keys :
32+ if len (one_of_opt_keys ) > 0 :
33+ opts = [k in state_dict for k in one_of_opt_keys ]
34+ num_present = sum (opts )
35+ if num_present == 0 :
36+ raise ValueError (f"state_dict should contain at least one of '{ one_of_opt_keys } ' keys" )
37+ if num_present > 1 :
38+ raise ValueError (f"state_dict should contain only one of '{ one_of_opt_keys } ' keys" )
39+
40+ # Check user keys
41+ if hasattr (self , "_state_dict_user_keys" ) and isinstance (self ._state_dict_user_keys , list ):
42+ for k in self ._state_dict_user_keys :
43+ if k not in state_dict :
44+ raise ValueError (
45+ f"Required user state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
46+ )
0 commit comments