6
6
from pydantic import BaseModel
7
7
8
8
9
- def _get_base_type (annotation : type [Any ]) -> type [Any ]:
9
+ def _get_base_type (annotation : Type [Any ]) -> Type [Any ]:
10
10
if getattr (annotation , "__origin__" , None ) is Literal :
11
11
assert hasattr (annotation , "__args__" ) and len (annotation .__args__ ) >= 1 # type: ignore
12
12
return type (annotation .__args__ [0 ]) # type: ignore
13
- if getattr (annotation , "__origin__" , None ) is Union :
13
+ elif getattr (annotation , "__origin__" , None ) is Union :
14
14
assert hasattr (annotation , "__args__" ) and len (annotation .__args__ ) >= 1 # type: ignore
15
- non_optional_args : list [ type [Any ]] = [
15
+ non_optional_args : List [ Type [Any ]] = [
16
16
arg for arg in annotation .__args__ if arg is not type (None ) # type: ignore
17
17
]
18
18
if non_optional_args :
19
19
return _get_base_type (non_optional_args [0 ])
20
20
elif (
21
21
getattr (annotation , "__origin__" , None ) is list
22
- or getattr (annotation , "__origin__" , None ) is list
22
+ or getattr (annotation , "__origin__" , None ) is List
23
23
):
24
24
assert hasattr (annotation , "__args__" ) and len (annotation .__args__ ) >= 1 # type: ignore
25
25
return _get_base_type (annotation .__args__ [0 ]) # type: ignore
26
26
return annotation
27
27
28
28
29
- def _contains_list_type (annotation : type [Any ] | None ) -> bool :
29
+ def _contains_list_type (annotation : Type [Any ] | None ) -> bool :
30
30
origin = getattr (annotation , "__origin__" , None )
31
31
32
- if origin is list or origin is list :
32
+ if origin is list or origin is List :
33
33
return True
34
- if origin in (Literal , Union ):
34
+ elif origin in (Literal , Union ):
35
35
return any (_contains_list_type (arg ) for arg in annotation .__args__ ) # type: ignore
36
- return False
36
+ else :
37
+ return False
37
38
38
39
39
40
def _parse_bool_arg (arg : str | bytes | bool ) -> bool :
@@ -47,12 +48,13 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
47
48
48
49
if arg_str in true_values :
49
50
return True
50
- if arg_str in false_values :
51
+ elif arg_str in false_values :
51
52
return False
52
- raise ValueError (f"Invalid boolean argument: { arg } " )
53
+ else :
54
+ raise ValueError (f"Invalid boolean argument: { arg } " )
53
55
54
56
55
- def add_args_from_model (parser : argparse .ArgumentParser , model : type [BaseModel ]):
57
+ def add_args_from_model (parser : argparse .ArgumentParser , model : Type [BaseModel ]):
56
58
"""Add arguments from a pydantic model to an argparse parser."""
57
59
58
60
for name , field in model .model_fields .items ():
@@ -80,7 +82,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel])
80
82
)
81
83
82
84
83
- T = TypeVar ("T" , bound = type [BaseModel ])
85
+ T = TypeVar ("T" , bound = Type [BaseModel ])
84
86
85
87
86
88
def parse_model_from_args (model : T , args : argparse .Namespace ) -> T :
@@ -90,5 +92,5 @@ def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
90
92
k : v
91
93
for k , v in vars (args ).items ()
92
94
if v is not None and k in model .model_fields
93
- },
95
+ }
94
96
)
0 commit comments