1
1
""" Functions for decoding dataclass fields from "raw" values (e.g. from json).
2
2
"""
3
+ from __future__ import annotations
4
+
3
5
import inspect
4
6
import warnings
5
7
from collections import OrderedDict
9
11
from functools import lru_cache , partial
10
12
from logging import getLogger
11
13
from pathlib import Path
12
- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , TypeVar , Union
14
+ from typing import Any , Callable , TypeVar
13
15
14
16
from simple_parsing .annotation_utils .get_field_annotations import (
15
17
evaluate_string_annotation ,
16
18
)
17
19
from simple_parsing .utils import (
18
20
get_bound ,
21
+ get_forward_arg ,
19
22
get_type_arguments ,
23
+ is_dataclass_type ,
20
24
is_dict ,
21
25
is_enum ,
22
26
is_forward_ref ,
35
39
V = TypeVar ("V" )
36
40
37
41
# Dictionary mapping from types/type annotations to their decoding functions.
38
- _decoding_fns : Dict [ Type [T ], Callable [[Any ], T ]] = {
42
+ _decoding_fns : dict [ type [T ], Callable [[Any ], T ]] = {
39
43
# the 'primitive' types are decoded using the type fn as a constructor.
40
44
t : t
41
45
for t in [str , float , int , bytes ]
@@ -51,7 +55,7 @@ def decode_bool(v: Any) -> bool:
51
55
_decoding_fns [bool ] = decode_bool
52
56
53
57
54
- def decode_field (field : Field , raw_value : Any , containing_dataclass : Optional [ type ] = None ) -> Any :
58
+ def decode_field (field : Field , raw_value : Any , containing_dataclass : type | None = None ) -> Any :
55
59
"""Converts a "raw" value (e.g. from json file) to the type of the `field`.
56
60
57
61
When serializing a dataclass to json, all objects are converted to dicts.
@@ -84,7 +88,7 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[ty
84
88
85
89
86
90
@lru_cache (maxsize = 100 )
87
- def get_decoding_fn (t : Type [T ]) -> Callable [[ Any ] , T ]:
91
+ def get_decoding_fn (type_annotation : type [T ] | str ) -> Callable [... , T ]:
88
92
"""Fetches/Creates a decoding function for the given type annotation.
89
93
90
94
This decoding function can then be used to create an instance of the type
@@ -111,67 +115,54 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
111
115
A function that decodes a 'raw' value to an instance of type `t`.
112
116
113
117
"""
114
- # cache_info = get_decoding_fn.cache_info()
115
- # logger.debug(f"called for type {t}! Cache info: {cache_info}")
116
-
117
- def _get_potential_keys (annotation : str ) -> List [str ]:
118
- # Type annotation is a string.
119
- # This can happen when the `from __future__ import annotations` feature is used.
120
- potential_keys : List [Type ] = []
121
- for key in _decoding_fns :
122
- if inspect .isclass (key ):
123
- if key .__qualname__ == annotation :
124
- # Qualname is more specific, there can't possibly be another match, so break.
125
- potential_keys .append (key )
126
- break
127
- if key .__qualname__ == annotation :
128
- # For just __name__, there could be more than one match.
129
- potential_keys .append (key )
130
- return potential_keys
131
-
132
- if isinstance (t , str ):
133
- if t in _decoding_fns :
134
- return _decoding_fns [t ]
135
-
136
- potential_keys = _get_potential_keys (t )
137
-
138
- if not potential_keys :
139
- # Try to replace the new-style annotation str with the old style syntax, and see if we
140
- # find a match.
141
- # try:
142
- try :
143
- evaluated_t = evaluate_string_annotation (t )
144
- # NOTE: We now have a 'live'/runtime type annotation object from the typing module.
145
- except (ValueError , TypeError ) as err :
146
- logger .error (f"Unable to evaluate the type annotation string { t } : { err } ." )
147
- else :
148
- if evaluated_t in _decoding_fns :
149
- return _decoding_fns [evaluated_t ]
150
- # If we still don't have this annotation stored in our dict of known functions, we
151
- # recurse, to try to deconstruct this annotation into its parts, and construct the
152
- # decoding function for the annotation. If this doesn't work, we just raise the
153
- # errors.
154
- return get_decoding_fn (evaluated_t )
155
-
156
- raise ValueError (
157
- f"Couldn't find a decoding function for the string annotation '{ t } '.\n "
158
- f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
159
- f"to work on fixing it.\n "
160
- f"Types with a known decoding function: { list (_decoding_fns .keys ())} "
118
+ from .serializable import from_dict
119
+
120
+ logger .debug (f"Getting the decoding function for { type_annotation !r} " )
121
+
122
+ if isinstance (type_annotation , str ):
123
+ # Check first if there are any matching registered decoding functions.
124
+ # TODO: Might be better to actually use the scope of the field, right?
125
+ matching_entries = {
126
+ key : decoding_fn
127
+ for key , decoding_fn in _decoding_fns .items ()
128
+ if (inspect .isclass (key ) and key .__name__ == type_annotation )
129
+ }
130
+ if len (matching_entries ) == 1 :
131
+ _ , decoding_fn = matching_entries .popitem ()
132
+ return decoding_fn
133
+ elif len (matching_entries ) > 1 :
134
+ # Multiple decoding functions match the type. Can't tell.
135
+ logger .warning (
136
+ RuntimeWarning (
137
+ f"More than one potential decoding functions were found for types that match "
138
+ f"the string annotation { type_annotation !r} . This will simply try each one "
139
+ f"and return the first one that works."
140
+ )
161
141
)
162
- if len (potential_keys ) == 1 :
163
- t = potential_keys [0 ]
142
+ return try_functions (* (decoding_fn for _ , decoding_fn in matching_entries .items ()))
164
143
else :
165
- raise ValueError (
166
- f"Multiple decoding functions registered for a type { t } : { potential_keys } \n "
167
- f"This could be a bug, but try to use different names for each type, or add the "
168
- f"modules they come from as a prefix, perhaps?"
169
- )
144
+ # Try to evaluate the string annotation.
145
+ t = evaluate_string_annotation (type_annotation )
146
+
147
+ elif is_forward_ref (type_annotation ):
148
+ forward_arg : str = get_forward_arg (type_annotation )
149
+ # Recurse until we've resolved the forward reference.
150
+ return get_decoding_fn (forward_arg )
151
+
152
+ else :
153
+ t = type_annotation
154
+
155
+ logger .debug (f"{ type_annotation !r} -> { t !r} " )
156
+
157
+ # T should now be a type or one of the objects from the typing module.
170
158
171
159
if t in _decoding_fns :
172
160
# The type has a dedicated decoding function.
173
161
return _decoding_fns [t ]
174
162
163
+ if is_dataclass_type (t ):
164
+ return partial (from_dict , t )
165
+
175
166
if t is Any :
176
167
logger .debug (f"Decoding an Any type: { t } " )
177
168
return no_op
@@ -214,31 +205,6 @@ def _get_potential_keys(annotation: str) -> List[str]:
214
205
logger .debug (f"Decoding an Enum field: { t } " )
215
206
return decode_enum (t )
216
207
217
- from .serializable import SerializableMixin , get_dataclass_types_from_forward_ref
218
-
219
- if is_forward_ref (t ):
220
- dcs = get_dataclass_types_from_forward_ref (t )
221
- if len (dcs ) == 1 :
222
- dc = dcs [0 ]
223
- return dc .from_dict
224
- if len (dcs ) > 1 :
225
- logger .warning (
226
- RuntimeWarning (
227
- f"More than one potential Serializable dataclass was found with a name matching "
228
- f"the type annotation { t } . This will simply try each one, and return the "
229
- f"first one that works. Potential classes: { dcs } "
230
- )
231
- )
232
- return try_functions (* [partial (dc .from_dict , drop_extra_fields = False ) for dc in dcs ])
233
- else :
234
- # No idea what the forward ref refers to!
235
- logger .warning (
236
- f"Unable to find a dataclass that matches the forward ref { t } inside the "
237
- f"registered { SerializableMixin } subclasses. Leaving the value as-is."
238
- f"(Consider using Serializable or FrozenSerializable as a base class?)."
239
- )
240
- return no_op
241
-
242
208
if is_typevar (t ):
243
209
bound = get_bound (t )
244
210
logger .debug (f"Decoding a typevar: { t } , bound type is { bound } ." )
@@ -256,31 +222,31 @@ def _get_potential_keys(annotation: str) -> List[str]:
256
222
return try_constructor (t )
257
223
258
224
259
- def _register (t : Type , func : Callable ) -> None :
225
+ def _register (t : type , func : Callable ) -> None :
260
226
if t not in _decoding_fns :
261
227
# logger.debug(f"Registering the type {t} with decoding function {func}")
262
228
_decoding_fns [t ] = func
263
229
264
230
265
- def register_decoding_fn (some_type : Type [T ], function : Callable [[Any ], T ]) -> None :
231
+ def register_decoding_fn (some_type : type [T ], function : Callable [[Any ], T ]) -> None :
266
232
"""Register a decoding function for the type `some_type`."""
267
233
_register (some_type , function )
268
234
269
235
270
- def decode_optional (t : Type [T ]) -> Callable [[Optional [ Any ]], Optional [ T ] ]:
236
+ def decode_optional (t : type [T ]) -> Callable [[Any | None ], T | None ]:
271
237
decode = get_decoding_fn (t )
272
238
273
- def _decode_optional (val : Optional [ Any ] ) -> Optional [ T ] :
239
+ def _decode_optional (val : Any | None ) -> T | None :
274
240
return val if val is None else decode (val )
275
241
276
242
return _decode_optional
277
243
278
244
279
- def try_functions (* funcs : Callable [[Any ], T ]) -> Callable [[Any ], Union [ T , Any ] ]:
245
+ def try_functions (* funcs : Callable [[Any ], T ]) -> Callable [[Any ], T | Any ]:
280
246
"""Tries to use the functions in succession, else returns the same value unchanged."""
281
247
282
- def _try_functions (val : Any ) -> Union [ T , Any ] :
283
- e : Optional [ Exception ] = None
248
+ def _try_functions (val : Any ) -> T | Any :
249
+ e : Exception | None = None
284
250
for func in funcs :
285
251
try :
286
252
return func (val )
@@ -293,30 +259,30 @@ def _try_functions(val: Any) -> Union[T, Any]:
293
259
return _try_functions
294
260
295
261
296
- def decode_union (* types : Type [T ]) -> Callable [[Any ], Union [ T , Any ] ]:
262
+ def decode_union (* types : type [T ]) -> Callable [[Any ], T | Any ]:
297
263
types = list (types )
298
264
optional = type (None ) in types
299
265
# Partition the Union into None and non-None types.
300
266
while type (None ) in types :
301
267
types .remove (type (None ))
302
268
303
- decoding_fns : List [Callable [[Any ], T ]] = [
269
+ decoding_fns : list [Callable [[Any ], T ]] = [
304
270
decode_optional (t ) if optional else get_decoding_fn (t ) for t in types
305
271
]
306
272
# Try using each of the non-None types, in succession. Worst case, return the value.
307
273
return try_functions (* decoding_fns )
308
274
309
275
310
- def decode_list (t : Type [T ]) -> Callable [[List [Any ]], List [T ]]:
276
+ def decode_list (t : type [T ]) -> Callable [[list [Any ]], list [T ]]:
311
277
decode_item = get_decoding_fn (t )
312
278
313
- def _decode_list (val : List [Any ]) -> List [T ]:
279
+ def _decode_list (val : list [Any ]) -> list [T ]:
314
280
return [decode_item (v ) for v in val ]
315
281
316
282
return _decode_list
317
283
318
284
319
- def decode_tuple (* tuple_item_types : Type [T ]) -> Callable [[List [T ]], Tuple [T , ...]]:
285
+ def decode_tuple (* tuple_item_types : type [T ]) -> Callable [[list [T ]], tuple [T , ...]]:
320
286
"""Makes a parsing function for creating tuples.
321
287
322
288
Can handle tuples with different item types, for instance:
@@ -338,7 +304,7 @@ def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...
338
304
# Note, if there are more values than types in the tuple type, then the
339
305
# last type is used.
340
306
341
- def _decode_tuple (val : Tuple [Any , ...]) -> Tuple [T , ...]:
307
+ def _decode_tuple (val : tuple [Any , ...]) -> tuple [T , ...]:
342
308
if has_ellipsis :
343
309
return tuple (decoding_fn (v ) for v in val )
344
310
else :
@@ -347,7 +313,7 @@ def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
347
313
return _decode_tuple
348
314
349
315
350
- def decode_set (item_type : Type [T ]) -> Callable [[List [T ]], Set [T ]]:
316
+ def decode_set (item_type : type [T ]) -> Callable [[list [T ]], set [T ]]:
351
317
"""Makes a parsing function for creating sets with items of type `item_type`.
352
318
353
319
Args:
@@ -359,13 +325,13 @@ def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
359
325
# Get the parse fn for a list of items of type `item_type`.
360
326
parse_list_fn = decode_list (item_type )
361
327
362
- def _decode_set (val : List [Any ]) -> Set [T ]:
328
+ def _decode_set (val : list [Any ]) -> set [T ]:
363
329
return set (parse_list_fn (val ))
364
330
365
331
return _decode_set
366
332
367
333
368
- def decode_dict (K_ : Type [K ], V_ : Type [V ]) -> Callable [[List [ Tuple [Any , Any ]]], Dict [K , V ]]:
334
+ def decode_dict (K_ : type [K ], V_ : type [V ]) -> Callable [[list [ tuple [Any , Any ]]], dict [K , V ]]:
369
335
"""Creates a decoding function for a dict type. Works with OrderedDict too.
370
336
371
337
Args:
@@ -379,8 +345,8 @@ def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], D
379
345
decode_k = get_decoding_fn (K_ )
380
346
decode_v = get_decoding_fn (V_ )
381
347
382
- def _decode_dict (val : Union [ Dict [ Any , Any ], List [ Tuple [Any , Any ]]] ) -> Dict [K , V ]:
383
- result : Dict [K , V ] = {}
348
+ def _decode_dict (val : dict [ Any , Any ] | list [ tuple [Any , Any ]]) -> dict [K , V ]:
349
+ result : dict [K , V ] = {}
384
350
if isinstance (val , list ):
385
351
result = OrderedDict ()
386
352
items = val
@@ -399,7 +365,7 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V
399
365
return _decode_dict
400
366
401
367
402
- def decode_enum (item_type : Type [Enum ]) -> Callable [[str ], Enum ]:
368
+ def decode_enum (item_type : type [Enum ]) -> Callable [[str ], Enum ]:
403
369
"""
404
370
Creates a decoding function for an enum type.
405
371
@@ -428,7 +394,7 @@ def no_op(v: T) -> T:
428
394
return v
429
395
430
396
431
- def try_constructor (t : Type [T ]) -> Callable [[Any ], Union [ T , Any ] ]:
397
+ def try_constructor (t : type [T ]) -> Callable [[Any ], T | Any ]:
432
398
"""Tries to use the type as a constructor. If that fails, returns the value as-is.
433
399
434
400
Args:
0 commit comments