1
- import functools
2
1
import sys
3
2
import typing
4
- from enum import Enum
5
- from functools import partial
6
- from pathlib import PurePath
7
3
from typing import Any
8
- from typing import Callable
9
4
from typing import Dict
10
5
from typing import Iterable
11
6
from typing import List
12
7
from typing import Literal
13
- from typing import Optional
8
+ from typing import Mapping
9
+ from typing import Protocol
14
10
from typing import Tuple
15
11
from typing import Type
16
12
from typing import Union
17
- from typing import Literal
18
- from typing import Protocol
19
- if sys .version_info >= (3 , 12 ):
13
+
14
+ if sys .version_info >= (3 , 10 ):
20
15
from typing import TypeAlias
21
16
else :
22
17
from typing_extensions import TypeAlias
45
40
from attr import fields_dict as get_attr_fields_dict
46
41
47
42
Attribute = attr .Attribute
48
-
43
+ # dataclasses and attr have internal tokens for missing values, join into a set so that we can
44
+ # check if a value is missing without knowing the type of backing class
49
45
MISSING = {DATACLASSES_MISSING , ATTR_NOTHING }
50
46
except ImportError :
51
47
_use_attr = False
52
48
attr = None
53
49
ATTR_NOTHING = None
54
- Attribute = TypeVar ("Attribute" , bound = object ) # type: ignore
50
+ Attribute = TypeVar ("Attribute" , bound = object ) # type: ignore[misc, assignment]
51
+
52
+ # define empty placeholders for getting attr fields as a tuple or dict. They will never be
53
+ # called because the import failed; but they're here to ensure that the function is defined in
54
+ # sections of code that don't know if the import was successful or not.
55
55
56
- def get_attr_fields (cls : type ) -> Tuple [dataclasses .Field , ...]: # type: ignore
56
+ def get_attr_fields (cls : type ) -> Tuple [dataclasses .Field , ...]: # type: ignore[misc]
57
+ """Get tuple of fields for attr class. attrs isn't imported so return empty tuple."""
57
58
return ()
58
59
59
- def get_attr_fields_dict (cls : type ) -> Dict [str , dataclasses .Field ]: # type: ignore
60
+ def get_attr_fields_dict (cls : type ) -> Dict [str , dataclasses .Field ]: # type: ignore[misc]
61
+ """Get dict of name->field for attr class. attrs isn't imported so return empty dict."""
60
62
return {}
61
63
64
+ # for consistency with successful import of attr, create a set for missing values
62
65
MISSING = {DATACLASSES_MISSING }
63
66
64
67
if TYPE_CHECKING :
@@ -73,22 +76,33 @@ class DataclassesProtocol(Protocol):
73
76
from attr import AttrsInstance
74
77
else :
75
78
76
- class AttrsInstance (Protocol ): # type: ignore
79
+ class AttrsInstance (Protocol ): # type: ignore[no-redef]
77
80
__attrs_attrs__ : Dict [str , Any ]
78
81
79
82
80
- def is_attr_class (cls : type ) -> bool : # type: ignore
83
+ def is_attr_class (cls : type ) -> bool : # type: ignore[arg-type]
84
+ """Return True if the class is an attr class, and False otherwise"""
81
85
return hasattr (cls , "__attrs_attrs__" )
82
86
83
87
84
- MISSING_OR_NONE = {* MISSING , None }
85
- DataclassesOrAttrClass = Union [DataclassesProtocol , AttrsInstance ]
86
- FieldType : TypeAlias = Union [dataclasses .Field , attr .Attribute ]
88
+ _MISSING_OR_NONE = {* MISSING , None }
89
+ """Set of values that are considered missing or None for dataclasses or attr classes"""
90
+ _DataclassesOrAttrClass : TypeAlias = Union [DataclassesProtocol , AttrsInstance ]
91
+ """
92
+ TypeAlias for dataclasses or attr classes. Mostly nonsense because they are not true types, they
93
+ are traits, but there is no python trait-tester.
94
+ """
95
+ FieldType : TypeAlias = Union [dataclasses .Field , Attribute ]
96
+ """
97
+ TypeAlias for dataclass Fields or attrs Attributes. It will correspond to the correct type for the
98
+ corresponding _DataclassesOrAttrClass
99
+ """
87
100
88
101
89
- def get_dataclasses_fields_dict (
102
+ def _get_dataclasses_fields_dict (
90
103
class_or_instance : Union [DataclassesProtocol , Type [DataclassesProtocol ]],
91
104
) -> Dict [str , dataclasses .Field ]:
105
+ """Get a dict from field name to Field for a dataclass class or instance."""
92
106
return {field .name : field for field in get_dataclasses_fields (class_or_instance )}
93
107
94
108
@@ -372,46 +386,49 @@ def get_parser() -> partial:
372
386
return parser
373
387
374
388
375
- def get_fields_dict (cls : Type [DataclassesOrAttrClass ]) -> Dict [str , FieldType ]:
376
- """
377
- Get the fields dict from either a dataclasses or attr dataclass.
389
+ def get_fields_dict (
390
+ cls : Union [_DataclassesOrAttrClass , Type [_DataclassesOrAttrClass ]]
391
+ ) -> Mapping [str , FieldType ]:
392
+ """Get the fields dict from either a dataclasses or attr dataclass (or instance)"""
393
+ if is_dataclasses_class (cls ):
394
+ return _get_dataclasses_fields_dict (cls ) # type: ignore[arg-type]
395
+ elif is_attr_class (cls ): # type: ignore[arg-type]
396
+ return get_attr_fields_dict (cls ) # type: ignore[arg-type]
397
+ else :
398
+ raise TypeError ("cls must a dataclasses or attr class" )
378
399
379
- Combine results in case someone chooses to mix them through inheritance.
380
- """
381
- if not (is_dataclasses_class (cls ) or is_attr_class (cls )):
382
- raise ValueError ("cls must a dataclasses or attr class" )
383
- return {
384
- ** (get_dataclasses_fields_dict (cls ) if is_dataclasses_class (cls ) else {}),
385
- ** (get_attr_fields_dict (cls ) if is_attr_class (cls ) else {}), # type: ignore
386
- }
387
-
388
-
389
- def get_fields (cls : Type [DataclassesOrAttrClass ]) -> Tuple [FieldType , ...]:
390
- if not (is_dataclasses_class (cls ) or is_attr_class (cls )):
391
- raise ValueError ("cls must a dataclasses or attr class" )
392
- return (get_dataclasses_fields (cls ) if is_dataclasses_class (cls ) else ()) + (
393
- get_attr_fields (cls ) if is_attr_class (cls ) else () # type: ignore
394
- )
400
+
401
+ def get_fields (
402
+ cls : Union [_DataclassesOrAttrClass , Type [_DataclassesOrAttrClass ]]
403
+ ) -> Tuple [FieldType , ...]:
404
+ """Get the fields tuple from either a dataclasses or attr dataclass (or instance)"""
405
+ if is_dataclasses_class (cls ):
406
+ return get_dataclasses_fields (cls ) # type: ignore[arg-type]
407
+ elif is_attr_class (cls ): # type: ignore[arg-type]
408
+ return get_attr_fields (cls ) # type: ignore[arg-type]
409
+ else :
410
+ raise TypeError ("cls must a dataclasses or attr class" )
395
411
396
412
397
- AttrFromType = TypeVar ("AttrFromType" )
413
+ _AttrFromType = TypeVar ("_AttrFromType" )
414
+ """TypeVar to allow attr_from to be used with either an attr class or a dataclasses class"""
398
415
399
416
400
417
def attr_from (
401
- cls : Type [AttrFromType ],
418
+ cls : Type [_AttrFromType ],
402
419
kwargs : Dict [str , str ],
403
420
parsers : Optional [Dict [type , Callable [[str ], Any ]]] = None ,
404
- ) -> AttrFromType :
405
- """Builds an attr class from key-word arguments
421
+ ) -> _AttrFromType :
422
+ """Builds an attr or dataclasses class from key-word arguments
406
423
407
424
Args:
408
- cls: the attr class to be built
425
+ cls: the attr or dataclasses class to be built
409
426
kwargs: a dictionary of keyword arguments
410
427
parsers: a dictionary of parser functions to apply to specific types
411
428
412
429
"""
413
430
return_values : Dict [str , Any ] = {}
414
- for attribute in get_fields (cls ): # type: ignore
431
+ for attribute in get_fields (cls ): # type: ignore[arg-type]
415
432
return_value : Any
416
433
if attribute .name in kwargs :
417
434
str_value : str = kwargs [attribute .name ]
@@ -447,7 +464,7 @@ def attr_from(
447
464
set_value
448
465
), f"Do not know how to convert string to { attribute .type } for value: { str_value } "
449
466
else : # no value, check for a default
450
- assert attribute .default is not None or attribute_is_optional (
467
+ assert attribute .default is not None or _attribute_is_optional (
451
468
attribute
452
469
), f"No value given and no default for attribute `{ attribute .name } `"
453
470
return_value = attribute .default
@@ -460,13 +477,13 @@ def attr_from(
460
477
return cls (** return_values )
461
478
462
479
463
- def attribute_is_optional (attribute : FieldType ) -> bool :
480
+ def _attribute_is_optional (attribute : FieldType ) -> bool :
464
481
"""Returns True if the attribute is optional, False otherwise"""
465
482
return typing .get_origin (attribute .type ) is Union and isinstance (
466
483
None , typing .get_args (attribute .type )
467
484
)
468
485
469
486
470
- def attribute_has_default (attribute : FieldType ) -> bool :
487
+ def _attribute_has_default (attribute : FieldType ) -> bool :
471
488
"""Returns True if the attribute has a default value, False otherwise"""
472
- return attribute .default not in MISSING_OR_NONE or attribute_is_optional (attribute )
489
+ return attribute .default not in _MISSING_OR_NONE or _attribute_is_optional (attribute )
0 commit comments