Skip to content

Commit 75d3437

Browse files
committed
Fix handling union types inside typing.Optional
1 parent 3bb1998 commit 75d3437

File tree

2 files changed

+23
-28
lines changed

2 files changed

+23
-28
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,17 @@
3434
logger = get_logger(__name__)
3535

3636

37-
def is_param_model(annotated_type: Any) -> bool:
38-
"""Check if the given annotated type is for a custom model
37+
def has_param_model(annotated_type: Any) -> bool:
38+
"""Check if the given annotated type contains a custom param model
3939
40-
:param annotated_type: Annotated type for a field to check whether it is a param model or not
40+
:param annotated_type: Annotated type for a field to check whether it contains a param model or not
4141
"""
4242

4343
def _is_param_model(obj: Any) -> bool:
44-
return inspect.isclass(obj) and (issubclass(obj, ParamModel))
44+
return inspect.isclass(obj) and issubclass(obj, ParamModel)
4545

4646
inner_type = param_type_util.get_inner_type(annotated_type)
47-
if get_origin(inner_type) in [Union, UnionType]:
47+
if param_type_util.is_union_type(inner_type):
4848
return any(_is_param_model(o) for o in get_args(inner_type))
4949
else:
5050
return _is_param_model(inner_type)
@@ -56,9 +56,9 @@ def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None
5656
:param annotated_type: Annotated type
5757
"""
5858
inner_type = param_type_util.get_inner_type(annotated_type)
59-
if is_param_model(inner_type):
60-
if get_origin(inner_type) in [Union, UnionType]:
61-
return list(get_args(inner_type))
59+
if has_param_model(inner_type):
60+
if param_type_util.is_union_type(inner_type):
61+
return [x for x in get_args(inner_type) if has_param_model(x)]
6262
else:
6363
return inner_type
6464

@@ -160,7 +160,7 @@ def generate_imports_code(obj_type: Any):
160160
[generate_imports_code(m) for m in get_args(obj_type)]
161161
else:
162162
raise NotImplementedError(f"Unsupported typing origin: {typing_origin}")
163-
elif is_param_model(obj_type):
163+
elif has_param_model(obj_type):
164164
if not exclude_nested_models:
165165
api_cls_module, model_file_name = api_class.__module__.rsplit(".", 1)
166166
module_and_name_pairs.append(
@@ -226,7 +226,7 @@ def get_param_models(model: type[EndpointModel | ParamModel], recursive: bool =
226226

227227
def collect_param_models(model: type[EndpointModel | ParamModel]):
228228
for field_name, field_obj in model.__dataclass_fields__.items():
229-
if is_param_model(field_obj.type):
229+
if has_param_model(field_obj.type):
230230
model_name = generate_model_name(field_name, field_obj.type)
231231
param_def = ParamDef.from_param_obj(field_obj.metadata)
232232
param_model = create_model_from_param_def(model_name, param_def)

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ def resolve(param_type: str, param_format: str | None = None):
159159
# code will be valid. So ignoring the warning here.
160160
type_annotation = Literal[*enum] # type: ignore
161161
elif isinstance(param_def, ParamDef.UnknownType):
162-
logger.exception(
162+
logger.warning(
163163
f"Unable to locate a parameter type for parameter '{param_name}'. Type '{Any}' will be applied.\n"
164-
f"Failed parameter object:\n{param_def.param_obj}"
164+
f"Failed parameter object: {param_def.param_obj}"
165165
)
166166
type_annotation = Any
167167
else:
@@ -193,42 +193,37 @@ def resolve(param_type: str, param_format: str | None = None):
193193
return type_annotation
194194

195195

196-
def get_inner_type(tp: Any, return_if_container_type: bool = False) -> Any | tuple[Any] | list[Any]:
196+
def get_inner_type(tp: Any, return_if_container_type: bool = False) -> Any | list[Any]:
197197
"""Get the inner type (=actual type) from the type annotation
198198
199199
eg:
200200
Optional[str] -> str
201+
Optional[str | int] -> str | int
201202
Annotated[str, "metadata"] -> str
202203
Literal[1,2,3] -> Literal[1,2,3]
203204
list[str] -> str (list[str] if return_if_container_type=True)
204-
str | int -> (str, int)
205+
str | int -> str | int
205206
206207
NOTE: This should also work with the combination of above
207208
208209
:param tp: Type annotation
209210
:param return_if_container_type: Consider container type like list and tuple as the inner type
210211
"""
211212
if origin_type := get_origin(tp):
212-
args = get_args(tp)
213+
if origin_type not in [Optional, Union, UnionType, Annotated, Literal, list, dict]:
214+
raise RuntimeError(f"Found unexpected origin type in '{tp}': {origin_type}")
215+
213216
if is_union_type(tp):
214-
if is_optional_type(tp):
215-
return get_inner_type(args[0])
216-
else:
217-
return args
217+
args_without_nonetype = [x for x in get_args(tp) if x is not NoneType]
218+
return generate_union_type([get_inner_type(x) for x in args_without_nonetype])
218219
elif origin_type is Annotated:
219220
return get_inner_type(tp.__origin__)
220-
elif origin_type in [list, tuple]:
221+
elif origin_type is list:
221222
if return_if_container_type:
222223
return tp
223224
else:
224-
if origin_type is list:
225-
return get_inner_type(args[0])
226-
else:
227-
return get_inner_type(args)
228-
else:
229-
return tp
230-
else:
231-
return tp
225+
return get_inner_type(get_args(tp)[0])
226+
return tp
232227

233228

234229
def replace_inner_type(tp: Any, new_type: Any, replace_container_type: bool = False) -> Any:

0 commit comments

Comments
 (0)