Skip to content

Commit 3122f6e

Browse files
authored
Broaden from_dict applicability to non-Serializable dataclasses (#217)
* [temp] Save local changes Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue with resolving of forward references Signed-off-by: Fabrice Normandin <[email protected]> * Fix logging format string for py37 Signed-off-by: Fabrice Normandin <[email protected]> * Move some test files over to test/helpers Signed-off-by: Fabrice Normandin <[email protected]> * Fuse test_serialization into test_from_dict Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 404f7f3 commit 3122f6e

File tree

6 files changed

+164
-108
lines changed

6 files changed

+164
-108
lines changed

simple_parsing/helpers/serialization/decoding.py

+69-103
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" Functions for decoding dataclass fields from "raw" values (e.g. from json).
22
"""
3+
from __future__ import annotations
4+
35
import inspect
46
import warnings
57
from collections import OrderedDict
@@ -9,14 +11,16 @@
911
from functools import lru_cache, partial
1012
from logging import getLogger
1113
from pathlib import Path
12-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
14+
from typing import Any, Callable, TypeVar
1315

1416
from simple_parsing.annotation_utils.get_field_annotations import (
1517
evaluate_string_annotation,
1618
)
1719
from simple_parsing.utils import (
1820
get_bound,
21+
get_forward_arg,
1922
get_type_arguments,
23+
is_dataclass_type,
2024
is_dict,
2125
is_enum,
2226
is_forward_ref,
@@ -35,7 +39,7 @@
3539
V = TypeVar("V")
3640

3741
# Dictionary mapping from types/type annotations to their decoding functions.
38-
_decoding_fns: Dict[Type[T], Callable[[Any], T]] = {
42+
_decoding_fns: dict[type[T], Callable[[Any], T]] = {
3943
# the 'primitive' types are decoded using the type fn as a constructor.
4044
t: t
4145
for t in [str, float, int, bytes]
@@ -51,7 +55,7 @@ def decode_bool(v: Any) -> bool:
5155
_decoding_fns[bool] = decode_bool
5256

5357

54-
def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[type] = None) -> Any:
58+
def decode_field(field: Field, raw_value: Any, containing_dataclass: type | None = None) -> Any:
5559
"""Converts a "raw" value (e.g. from json file) to the type of the `field`.
5660
5761
When serializing a dataclass to json, all objects are converted to dicts.
@@ -84,7 +88,7 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[ty
8488

8589

8690
@lru_cache(maxsize=100)
87-
def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
91+
def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
8892
"""Fetches/Creates a decoding function for the given type annotation.
8993
9094
This decoding function can then be used to create an instance of the type
@@ -111,67 +115,54 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
111115
A function that decodes a 'raw' value to an instance of type `t`.
112116
113117
"""
114-
# cache_info = get_decoding_fn.cache_info()
115-
# logger.debug(f"called for type {t}! Cache info: {cache_info}")
116-
117-
def _get_potential_keys(annotation: str) -> List[str]:
118-
# Type annotation is a string.
119-
# This can happen when the `from __future__ import annotations` feature is used.
120-
potential_keys: List[Type] = []
121-
for key in _decoding_fns:
122-
if inspect.isclass(key):
123-
if key.__qualname__ == annotation:
124-
# Qualname is more specific, there can't possibly be another match, so break.
125-
potential_keys.append(key)
126-
break
127-
if key.__qualname__ == annotation:
128-
# For just __name__, there could be more than one match.
129-
potential_keys.append(key)
130-
return potential_keys
131-
132-
if isinstance(t, str):
133-
if t in _decoding_fns:
134-
return _decoding_fns[t]
135-
136-
potential_keys = _get_potential_keys(t)
137-
138-
if not potential_keys:
139-
# Try to replace the new-style annotation str with the old style syntax, and see if we
140-
# find a match.
141-
# try:
142-
try:
143-
evaluated_t = evaluate_string_annotation(t)
144-
# NOTE: We now have a 'live'/runtime type annotation object from the typing module.
145-
except (ValueError, TypeError) as err:
146-
logger.error(f"Unable to evaluate the type annotation string {t}: {err}.")
147-
else:
148-
if evaluated_t in _decoding_fns:
149-
return _decoding_fns[evaluated_t]
150-
# If we still don't have this annotation stored in our dict of known functions, we
151-
# recurse, to try to deconstruct this annotation into its parts, and construct the
152-
# decoding function for the annotation. If this doesn't work, we just raise the
153-
# errors.
154-
return get_decoding_fn(evaluated_t)
155-
156-
raise ValueError(
157-
f"Couldn't find a decoding function for the string annotation '{t}'.\n"
158-
f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
159-
f"to work on fixing it.\n"
160-
f"Types with a known decoding function: {list(_decoding_fns.keys())}"
118+
from .serializable import from_dict
119+
120+
logger.debug(f"Getting the decoding function for {type_annotation!r}")
121+
122+
if isinstance(type_annotation, str):
123+
# Check first if there are any matching registered decoding functions.
124+
# TODO: Might be better to actually use the scope of the field, right?
125+
matching_entries = {
126+
key: decoding_fn
127+
for key, decoding_fn in _decoding_fns.items()
128+
if (inspect.isclass(key) and key.__name__ == type_annotation)
129+
}
130+
if len(matching_entries) == 1:
131+
_, decoding_fn = matching_entries.popitem()
132+
return decoding_fn
133+
elif len(matching_entries) > 1:
134+
# Multiple decoding functions match the type. Can't tell.
135+
logger.warning(
136+
RuntimeWarning(
137+
f"More than one potential decoding functions were found for types that match "
138+
f"the string annotation {type_annotation!r}. This will simply try each one "
139+
f"and return the first one that works."
140+
)
161141
)
162-
if len(potential_keys) == 1:
163-
t = potential_keys[0]
142+
return try_functions(*(decoding_fn for _, decoding_fn in matching_entries.items()))
164143
else:
165-
raise ValueError(
166-
f"Multiple decoding functions registered for a type {t}: {potential_keys} \n"
167-
f"This could be a bug, but try to use different names for each type, or add the "
168-
f"modules they come from as a prefix, perhaps?"
169-
)
144+
# Try to evaluate the string annotation.
145+
t = evaluate_string_annotation(type_annotation)
146+
147+
elif is_forward_ref(type_annotation):
148+
forward_arg: str = get_forward_arg(type_annotation)
149+
# Recurse until we've resolved the forward reference.
150+
return get_decoding_fn(forward_arg)
151+
152+
else:
153+
t = type_annotation
154+
155+
logger.debug(f"{type_annotation!r} -> {t!r}")
156+
157+
# T should now be a type or one of the objects from the typing module.
170158

171159
if t in _decoding_fns:
172160
# The type has a dedicated decoding function.
173161
return _decoding_fns[t]
174162

163+
if is_dataclass_type(t):
164+
return partial(from_dict, t)
165+
175166
if t is Any:
176167
logger.debug(f"Decoding an Any type: {t}")
177168
return no_op
@@ -214,31 +205,6 @@ def _get_potential_keys(annotation: str) -> List[str]:
214205
logger.debug(f"Decoding an Enum field: {t}")
215206
return decode_enum(t)
216207

217-
from .serializable import SerializableMixin, get_dataclass_types_from_forward_ref
218-
219-
if is_forward_ref(t):
220-
dcs = get_dataclass_types_from_forward_ref(t)
221-
if len(dcs) == 1:
222-
dc = dcs[0]
223-
return dc.from_dict
224-
if len(dcs) > 1:
225-
logger.warning(
226-
RuntimeWarning(
227-
f"More than one potential Serializable dataclass was found with a name matching "
228-
f"the type annotation {t}. This will simply try each one, and return the "
229-
f"first one that works. Potential classes: {dcs}"
230-
)
231-
)
232-
return try_functions(*[partial(dc.from_dict, drop_extra_fields=False) for dc in dcs])
233-
else:
234-
# No idea what the forward ref refers to!
235-
logger.warning(
236-
f"Unable to find a dataclass that matches the forward ref {t} inside the "
237-
f"registered {SerializableMixin} subclasses. Leaving the value as-is."
238-
f"(Consider using Serializable or FrozenSerializable as a base class?)."
239-
)
240-
return no_op
241-
242208
if is_typevar(t):
243209
bound = get_bound(t)
244210
logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
@@ -256,31 +222,31 @@ def _get_potential_keys(annotation: str) -> List[str]:
256222
return try_constructor(t)
257223

258224

259-
def _register(t: Type, func: Callable) -> None:
225+
def _register(t: type, func: Callable) -> None:
260226
if t not in _decoding_fns:
261227
# logger.debug(f"Registering the type {t} with decoding function {func}")
262228
_decoding_fns[t] = func
263229

264230

265-
def register_decoding_fn(some_type: Type[T], function: Callable[[Any], T]) -> None:
231+
def register_decoding_fn(some_type: type[T], function: Callable[[Any], T]) -> None:
266232
"""Register a decoding function for the type `some_type`."""
267233
_register(some_type, function)
268234

269235

270-
def decode_optional(t: Type[T]) -> Callable[[Optional[Any]], Optional[T]]:
236+
def decode_optional(t: type[T]) -> Callable[[Any | None], T | None]:
271237
decode = get_decoding_fn(t)
272238

273-
def _decode_optional(val: Optional[Any]) -> Optional[T]:
239+
def _decode_optional(val: Any | None) -> T | None:
274240
return val if val is None else decode(val)
275241

276242
return _decode_optional
277243

278244

279-
def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], Union[T, Any]]:
245+
def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], T | Any]:
280246
"""Tries to use the functions in succession, else returns the same value unchanged."""
281247

282-
def _try_functions(val: Any) -> Union[T, Any]:
283-
e: Optional[Exception] = None
248+
def _try_functions(val: Any) -> T | Any:
249+
e: Exception | None = None
284250
for func in funcs:
285251
try:
286252
return func(val)
@@ -293,30 +259,30 @@ def _try_functions(val: Any) -> Union[T, Any]:
293259
return _try_functions
294260

295261

296-
def decode_union(*types: Type[T]) -> Callable[[Any], Union[T, Any]]:
262+
def decode_union(*types: type[T]) -> Callable[[Any], T | Any]:
297263
types = list(types)
298264
optional = type(None) in types
299265
# Partition the Union into None and non-None types.
300266
while type(None) in types:
301267
types.remove(type(None))
302268

303-
decoding_fns: List[Callable[[Any], T]] = [
269+
decoding_fns: list[Callable[[Any], T]] = [
304270
decode_optional(t) if optional else get_decoding_fn(t) for t in types
305271
]
306272
# Try using each of the non-None types, in succession. Worst case, return the value.
307273
return try_functions(*decoding_fns)
308274

309275

310-
def decode_list(t: Type[T]) -> Callable[[List[Any]], List[T]]:
276+
def decode_list(t: type[T]) -> Callable[[list[Any]], list[T]]:
311277
decode_item = get_decoding_fn(t)
312278

313-
def _decode_list(val: List[Any]) -> List[T]:
279+
def _decode_list(val: list[Any]) -> list[T]:
314280
return [decode_item(v) for v in val]
315281

316282
return _decode_list
317283

318284

319-
def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...]]:
285+
def decode_tuple(*tuple_item_types: type[T]) -> Callable[[list[T]], tuple[T, ...]]:
320286
"""Makes a parsing function for creating tuples.
321287
322288
Can handle tuples with different item types, for instance:
@@ -338,7 +304,7 @@ def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...
338304
# Note, if there are more values than types in the tuple type, then the
339305
# last type is used.
340306

341-
def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
307+
def _decode_tuple(val: tuple[Any, ...]) -> tuple[T, ...]:
342308
if has_ellipsis:
343309
return tuple(decoding_fn(v) for v in val)
344310
else:
@@ -347,7 +313,7 @@ def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
347313
return _decode_tuple
348314

349315

350-
def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
316+
def decode_set(item_type: type[T]) -> Callable[[list[T]], set[T]]:
351317
"""Makes a parsing function for creating sets with items of type `item_type`.
352318
353319
Args:
@@ -359,13 +325,13 @@ def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
359325
# Get the parse fn for a list of items of type `item_type`.
360326
parse_list_fn = decode_list(item_type)
361327

362-
def _decode_set(val: List[Any]) -> Set[T]:
328+
def _decode_set(val: list[Any]) -> set[T]:
363329
return set(parse_list_fn(val))
364330

365331
return _decode_set
366332

367333

368-
def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], Dict[K, V]]:
334+
def decode_dict(K_: type[K], V_: type[V]) -> Callable[[list[tuple[Any, Any]]], dict[K, V]]:
369335
"""Creates a decoding function for a dict type. Works with OrderedDict too.
370336
371337
Args:
@@ -379,8 +345,8 @@ def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], D
379345
decode_k = get_decoding_fn(K_)
380346
decode_v = get_decoding_fn(V_)
381347

382-
def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V]:
383-
result: Dict[K, V] = {}
348+
def _decode_dict(val: dict[Any, Any] | list[tuple[Any, Any]]) -> dict[K, V]:
349+
result: dict[K, V] = {}
384350
if isinstance(val, list):
385351
result = OrderedDict()
386352
items = val
@@ -399,7 +365,7 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V
399365
return _decode_dict
400366

401367

402-
def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]:
368+
def decode_enum(item_type: type[Enum]) -> Callable[[str], Enum]:
403369
"""
404370
Creates a decoding function for an enum type.
405371
@@ -428,7 +394,7 @@ def no_op(v: T) -> T:
428394
return v
429395

430396

431-
def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]:
397+
def try_constructor(t: type[T]) -> Callable[[Any], T | Any]:
432398
"""Tries to use the type as a constructor. If that fails, returns the value as-is.
433399
434400
Args:

simple_parsing/helpers/serialization/serializable.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,10 @@ class SimpleSerializable(SerializableMixin, decode_into_subclasses=True):
338338
S = TypeVar("S", bound=SerializableMixin)
339339

340340

341-
def get_dataclass_types_from_forward_ref(
341+
def get_serializable_dataclass_types_from_forward_ref(
342342
forward_ref: type, serializable_base_class: type[S] = SerializableMixin
343343
) -> list[type[S]]:
344+
"""Gets all the subclasses of `serializable_base_class` that have the same name as the argument of this forward reference annotation."""
344345
arg = get_forward_arg(forward_ref)
345346
potential_classes: list[type] = []
346347
for serializable_class in serializable_base_class.subclasses:

simple_parsing/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ClassVar,
2525
Container,
2626
Dict,
27+
ForwardRef,
2728
Iterable,
2829
List,
2930
Mapping,
@@ -50,12 +51,12 @@ def get_bound(t):
5051
raise TypeError(f"type is not a `TypeVar`: {t}")
5152

5253

53-
def is_forward_ref(t):
54+
def is_forward_ref(t) -> TypeGuard[typing.ForwardRef]:
5455
return isinstance(t, typing.ForwardRef)
5556

5657

57-
def get_forward_arg(fr):
58-
return getattr(fr, "__forward_arg__", None)
58+
def get_forward_arg(fr: ForwardRef) -> str:
59+
return getattr(fr, "__forward_arg__")
5960

6061

6162
logger = getLogger(__name__)

0 commit comments

Comments
 (0)