34
34
logger = get_logger (__name__ )
35
35
36
36
37
- def is_param_model (annotated_type : Any ) -> bool :
38
- """Check if the given annotated type is for a custom model
37
+ def has_param_model (annotated_type : Any ) -> bool :
38
+ """Check if the given annotated type contains a custom param model
39
39
40
- :param annotated_type: Annotated type for a field to check whether it is a param model or not
40
+ :param annotated_type: Annotated type for a field to check whether it contains a param model or not
41
41
"""
42
42
43
43
def _is_param_model (obj : Any ) -> bool :
44
- return inspect .isclass (obj ) and ( issubclass (obj , ParamModel ) )
44
+ return inspect .isclass (obj ) and issubclass (obj , ParamModel )
45
45
46
46
inner_type = param_type_util .get_inner_type (annotated_type )
47
- if get_origin (inner_type ) in [ Union , UnionType ] :
47
+ if param_type_util . is_union_type (inner_type ):
48
48
return any (_is_param_model (o ) for o in get_args (inner_type ))
49
49
else :
50
50
return _is_param_model (inner_type )
@@ -56,9 +56,9 @@ def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None
56
56
:param annotated_type: Annotated type
57
57
"""
58
58
inner_type = param_type_util .get_inner_type (annotated_type )
59
- if is_param_model (inner_type ):
60
- if get_origin (inner_type ) in [ Union , UnionType ] :
61
- return list ( get_args (inner_type ))
59
+ if has_param_model (inner_type ):
60
+ if param_type_util . is_union_type (inner_type ):
61
+ return [ x for x in get_args (inner_type ) if has_param_model ( x )]
62
62
else :
63
63
return inner_type
64
64
@@ -160,7 +160,7 @@ def generate_imports_code(obj_type: Any):
160
160
[generate_imports_code (m ) for m in get_args (obj_type )]
161
161
else :
162
162
raise NotImplementedError (f"Unsupported typing origin: { typing_origin } " )
163
- elif is_param_model (obj_type ):
163
+ elif has_param_model (obj_type ):
164
164
if not exclude_nested_models :
165
165
api_cls_module , model_file_name = api_class .__module__ .rsplit ("." , 1 )
166
166
module_and_name_pairs .append (
@@ -226,7 +226,7 @@ def get_param_models(model: type[EndpointModel | ParamModel], recursive: bool =
226
226
227
227
def collect_param_models (model : type [EndpointModel | ParamModel ]):
228
228
for field_name , field_obj in model .__dataclass_fields__ .items ():
229
- if is_param_model (field_obj .type ):
229
+ if has_param_model (field_obj .type ):
230
230
model_name = generate_model_name (field_name , field_obj .type )
231
231
param_def = ParamDef .from_param_obj (field_obj .metadata )
232
232
param_model = create_model_from_param_def (model_name , param_def )
0 commit comments