Skip to content

Commit 071c6f2

Browse files
committed
Fix a few issues to address some edge cases
1 parent bd6fc1b commit 071c6f2

File tree

5 files changed

+81
-85
lines changed

5 files changed

+81
-85
lines changed

src/openapi_test_client/libraries/api/api_client_generator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def update_existing_endpoints(target_api_class: type[APIBase] = api_class) -> No
363363
endpoint_spec.get("summary")
364364
or endpoint_spec.get("description")
365365
or "No summary or description is available for this API"
366-
)
366+
).replace('"', '\\"')
367367
is_deprecated_api = endpoint_spec.get("deprecated", False)
368368
is_public_api = endpoint_spec.get("security") == []
369369

src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py

+58-55
Original file line numberDiff line numberDiff line change
@@ -140,66 +140,69 @@ def _parse_parameter_objects(
140140
https://swagger.io/specification/#parameter-object
141141
"""
142142
for param_obj in deepcopy(parameter_objects):
143-
param_name = param_obj["name"]
144-
try:
145-
param_location = param_obj["in"]
146-
param_def = ParamDef.from_param_obj(param_obj)
147-
# In case path parameters are incorrectly documented as required: false, we force make them required as path
148-
# parameters will always be required for our client
149-
is_required = True if param_location == "path" else None
150-
param_type_annotation = param_type_util.resolve_type_annotation(
151-
param_name, param_def, _is_required=is_required
152-
)
153-
if param_location in ["header", "cookies"]:
154-
# We currently don't support these
155-
continue
156-
elif param_location == "path":
157-
if param_name not in [x[0] for x in path_param_fields]:
158-
# Handle duplicates. Some API specs incorrectly document duplicated parameters
159-
path_param_fields.append(DataclassModelField(param_name, param_type_annotation))
160-
elif param_location == "query":
161-
if method.upper() != "GET":
162-
# Annotate query params for non GET endpoints
163-
param_type_annotation = param_type_util.annotate_type(param_type_annotation, "query")
164-
165-
if "schema" in param_obj:
166-
# defined as model. We unpack the model details
167-
schema_obj = param_obj["schema"]
168-
if "items" in schema_obj:
169-
schema_obj = schema_obj["items"]
170-
171-
if "properties" in schema_obj:
172-
properties = schema_obj["properties"]
173-
174-
for k, v in properties.items():
175-
if "name" not in properties[k]:
176-
properties[k]["name"] = k
177-
properties[k]["in"] = param_location
178-
179-
# Replace the param objects and parse it again
180-
parameter_objects.clear()
181-
parameter_objects.extend(properties.values())
182-
_parse_parameter_objects(
183-
method, parameter_objects, path_param_fields, body_or_query_param_fields
184-
)
143+
# NOTE: param_obj will be empty here if its $ref wasn't successfully resolved. We will ignore these
144+
if param_obj:
145+
param_name = ""
146+
try:
147+
param_name = param_obj["name"]
148+
param_location = param_obj["in"]
149+
param_def = ParamDef.from_param_obj(param_obj)
150+
# In case path parameters are incorrectly documented as required: false, we force make them required as
151+
# path parameters will always be required for our client
152+
is_required = True if param_location == "path" else None
153+
param_type_annotation = param_type_util.resolve_type_annotation(
154+
param_name, param_def, _is_required=is_required
155+
)
156+
if param_location in ["header", "cookies"]:
157+
# We currently don't support these
158+
continue
159+
elif param_location == "path":
160+
if param_name not in [x[0] for x in path_param_fields]:
161+
# Handle duplicates. Some API specs incorrectly document duplicated parameters
162+
path_param_fields.append(DataclassModelField(param_name, param_type_annotation))
163+
elif param_location == "query":
164+
if method.upper() != "GET":
165+
# Annotate query params for non GET endpoints
166+
param_type_annotation = param_type_util.annotate_type(param_type_annotation, "query")
167+
168+
if "schema" in param_obj:
169+
# defined as model. We unpack the model details
170+
schema_obj = param_obj["schema"]
171+
if "items" in schema_obj:
172+
schema_obj = schema_obj["items"]
173+
174+
if "properties" in schema_obj:
175+
properties = schema_obj["properties"]
176+
177+
for k, v in properties.items():
178+
if "name" not in properties[k]:
179+
properties[k]["name"] = k
180+
properties[k]["in"] = param_location
181+
182+
# Replace the param objects and parse it again
183+
parameter_objects.clear()
184+
parameter_objects.extend(properties.values())
185+
_parse_parameter_objects(
186+
method, parameter_objects, path_param_fields, body_or_query_param_fields
187+
)
188+
else:
189+
_add_body_or_query_param_field(
190+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
191+
)
192+
185193
else:
186194
_add_body_or_query_param_field(
187195
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
188196
)
189-
190197
else:
191-
_add_body_or_query_param_field(
192-
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
193-
)
194-
else:
195-
raise NotImplementedError(f"Unsupported param 'in': {param_location}")
196-
except Exception:
197-
logger.error(
198-
"Encountered an error while processing a param object in 'parameters':\n"
199-
f"- param name: {param_name}\n"
200-
f"- param object: {param_obj}"
201-
)
202-
raise
198+
raise NotImplementedError(f"Unsupported param 'in': {param_location}")
199+
except Exception:
200+
logger.error(
201+
"Encountered an error while processing a param object in 'parameters':\n"
202+
f"- param name: {param_name}\n"
203+
f"- param object: {param_obj}"
204+
)
205+
raise
203206

204207

205208
def _parse_request_body_object(

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+19-26
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ def generate_model_name(field_name: str, field_type: str | Any) -> str:
9090
:param field_name: Dataclass field name
9191
:param field_type: OpenAPI parameter type or dataclass field type
9292
"""
93-
model_name = generate_class_name(field_name)
93+
model_name = generate_class_name(clean_model_field_name(field_name))
9494
# NOTE: You may need to add a custom blacklist/rules for your app in here to ignore words that inflect library
9595
# doesn't handle well.
9696
# Eg. The word "bps" (Bits per second) is not a plural word, but inflect thinks it is and incorrectly generates
9797
# its singular noun as "bp"
9898
if param_type_util.is_type_of(field_type, list) and (singular_noun := inflect.engine().singular_noun(model_name)):
9999
# Change the plural model name to the singular word
100-
model_name = singular_noun
100+
model_name = cast(str, singular_noun)
101101

102102
# Adjust the model name if it happens to conflict with class names we might import, or with the field_name itself
103103
if model_name in [*get_reserved_model_names(), field_name]:
@@ -320,6 +320,18 @@ def visit(model_name: str) -> None:
320320
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))
321321

322322

323+
@lru_cache
324+
def clean_model_field_name(name: str) -> str:
325+
"""Returns an alias name if the given name is illegal as a model field name"""
326+
name = clean_obj_name(name)
327+
# NOTE: The escaping of kwargs is already is handled in endpoint model
328+
reserved_param_names = ["self", "validate", *get_supported_request_parameters()]
329+
if name in get_reserved_model_names() + reserved_param_names:
330+
# The field name conflicts with one of reserved names
331+
name += "_"
332+
return name
333+
334+
323335
def alias_illegal_model_field_names(location: str, model_fields: list[DataclassModelField]) -> None:
324336
"""Clean illegal model field name and annotate the field type with Alias class
325337
@@ -335,33 +347,14 @@ def make_alias(name: str, param_type: Any) -> str:
335347
):
336348
return name
337349
else:
338-
name = clean_obj_name(name)
339-
# NOTE: The escaping of kwargs is already is handled in endpoint model
340-
reserved_param_names = ["self", "validate", *get_supported_request_parameters()]
341-
if name in get_reserved_model_names() + reserved_param_names:
342-
# The field name conflicts with one of reserved names
343-
name += "_"
344-
350+
name = clean_model_field_name(name)
345351
if param_models := get_param_model(param_type):
346-
# There seems to be an issue with the `Annotated` cache behavior. AttributeError will be thrown on
347-
# importing the model when the following conditions are all met:
348-
# - The model field is annotated with `Annotated`
349-
# - The origin type of `Annotated` is another param model, or union of param models (nested model)
350-
# - The model field name is identical to one of the annotated model names
351-
# eg.
352-
#
353-
# @dataclass
354-
# class NestedModel(ParamModel):
355-
# param: str = Unset
356-
#
357-
# @dataclass
358-
# class Model(ParamModel):
359-
# NestedModel: Annotated[NestedModel, "test"] = Unset
360-
#
352+
# There seems to be some known issues when the field name clashes with the type annotation name.
353+
# We change the field name in this case
354+
# eg. https://docs.pydantic.dev/2.10/errors/usage_errors/#unevaluable-type-annotation
361355
if not isinstance(param_models, list):
362356
param_models = [param_models]
363-
if any(name == (m.__forward_arg__ if isinstance(m, ForwardRef) else m.__name__) for m in param_models):
364-
# This meets the above issue conditions
357+
if any(name == get_param_model_name(m) for m in param_models):
365358
name += "_"
366359
return name
367360

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
INT_PARAM_TYPES = ["integer", "int", "int64", "number"]
3434
BOOL_PARAM_TYPES = ["boolean", "bool"]
3535
LIST_PARAM_TYPES = ["array"]
36-
NULL_PARAM_TYPES = ["null"]
36+
NULL_PARAM_TYPES = ["null", None]
3737

3838

3939
def get_type_annotation_as_str(tp: Any) -> str:
@@ -199,7 +199,7 @@ def resolve(param_type: str | Sequence[str], param_format: str | None = None) ->
199199
# Optional parameter
200200
type_annotation = generate_optional_type(type_annotation)
201201

202-
if num_optional_types := repr(type_annotation).count("Optional"):
202+
if num_optional_types := repr(type_annotation).count("Optional["):
203203
# Sanity check for Optional type. If it is annotated with `Optional`, we want it to appear as the origin type
204204
# only. If this check fails, it means the logic is broke somewhere
205205
if num_optional_types > 1:

src/openapi_test_client/libraries/common/misc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_module_name_by_file_path(file_path: Path) -> str:
5454
assert is_external_project()
5555
# Accessing the package file from an external location
5656
file_path_from_package_dir = _PACKAGE_DIR.name + str(file_path).rsplit(_PACKAGE_DIR.name)[-1]
57-
return file_path_from_package_dir.replace(os.sep, ".").replace(".py", "")
57+
return file_path_from_package_dir.replace(os.sep, ".").removesuffix(".py")
5858

5959

6060
def import_module_from_file_path(file_path: Path) -> ModuleType:

0 commit comments

Comments
 (0)