Skip to content

Commit bf9111a

Browse files
committed
Handle empty schema in requestBody.
Update logging around the generate/update script execution
1 parent 513f65a commit bf9111a

File tree

7 files changed

+62
-44
lines changed

7 files changed

+62
-44
lines changed

src/openapi_test_client/libraries/api/api_client_generator.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,10 @@ def generate_api_class(
168168
base_class = _get_base_api_class(api_client)
169169
api_dir = app_client_dir / API_CLASS_DIR_NAME
170170
api_class_file_path = api_dir / f"{class_file_name.lower()}.py"
171-
logger.warning(
171+
logger.info(
172172
f"Generating a new API class file:\n"
173173
f"- API class name: {class_name} (tag={tag})\n"
174-
f"- file path: {api_class_file_path}"
174+
f"- file path: {api_class_file_path}",
175175
)
176176
if api_class_file_path.exists():
177177
raise RuntimeError(f"{api_class_file_path} exists")
@@ -301,13 +301,13 @@ def update_endpoint_functions(
301301
original_model_code = format_code(model_file_path.read_text(), remove_unused_imports=False)
302302
else:
303303
original_model_code = ""
304-
method = path = func_name = None
304+
method = path = func_name = ""
305305
defined_endpoints = []
306306
param_models = []
307307

308308
def update_existing_endpoints(target_api_class: type[APIBase] = api_class) -> None:
309309
"""Updated existing endpoint functions"""
310-
nonlocal modified_api_cls_code
310+
nonlocal modified_api_cls_code, method, path, func_name
311311
new_code = current_code = modified_api_cls_code
312312
api_spec_tags = set()
313313

@@ -351,7 +351,7 @@ def update_existing_endpoints(target_api_class: type[APIBase] = api_class) -> No
351351
if endpoint_function.endpoint.is_documented:
352352
err = f"{TAB}Not found: {method.upper()} {path} ({func_name})"
353353
print(color(err, color_code=ColorCodes.RED)) # noqa: T201
354-
else:
354+
elif verbose:
355355
msg = f"{TAB}Skipped undocumented endpoint: {method.upper()} {path} ({func_name})"
356356
print(msg) # noqa: T201
357357
continue
@@ -522,12 +522,13 @@ def update_missing_endpoints() -> None:
522522
original_api_cls_code = original_model_code = ""
523523
if not update_param_models_only:
524524
if api_cls_updated := (original_api_cls_code != modified_api_cls_code):
525-
if not is_new_api_class:
526-
msg = f"{TAB}Update{' required' if dry_run else 'd'}: {api_cls_file_path}"
527-
print(color(msg, color_code=ColorCodes.YELLOW)) # noqa:T201
528-
529525
if show_diff:
530526
# Print diff
527+
if is_new_api_class:
528+
logger.info(f"Generated API class code: {api_cls_file_path}")
529+
else:
530+
msg = f"{TAB}Update{' required' if dry_run else 'd'}: {api_cls_file_path}"
531+
print(color(msg, color_code=ColorCodes.YELLOW)) # noqa:T201
531532
diff_code(
532533
original_api_cls_code,
533534
modified_api_cls_code,
@@ -540,6 +541,8 @@ def update_missing_endpoints() -> None:
540541
api_cls_file_path.write_text(modified_api_cls_code)
541542

542543
if param_models:
544+
if verbose:
545+
logger.info(f"Checking param models for {api_class.__name__}...")
543546
modified_model_code = (
544547
f"from dataclasses import dataclass\n\nfrom {ParamModel.__module__} import {ParamModel.__name__}\n\n"
545548
)
@@ -551,9 +554,12 @@ def update_missing_endpoints() -> None:
551554

552555
if model_updated := (original_model_code != modified_model_code):
553556
# Print diff
554-
msg = f"{TAB}Update{' required' if dry_run else 'd'} (models): {model_file_path}"
555-
print(color(msg, color_code=ColorCodes.YELLOW)) # noqa:T201
556557
if show_diff:
558+
if is_new_api_class:
559+
logger.info(f"Generated param models: {model_file_path}")
560+
else:
561+
msg = f"{TAB}Update{' required' if dry_run else 'd'} (models): {model_file_path}"
562+
print(color(msg, color_code=ColorCodes.YELLOW)) # noqa:T201
557563
diff_code(
558564
original_model_code,
559565
modified_model_code,
@@ -565,12 +571,14 @@ def update_missing_endpoints() -> None:
565571
model_file_path.parent.mkdir(parents=True, exist_ok=True)
566572
model_file_path.write_text(modified_model_code)
567573
except Exception as e:
568-
# This should not happen
574+
# An error during code generation
569575
tb = traceback.format_exc()
570-
err = f"Failed to update {api_cls_file_path}:"
571-
if all([func_name, method, path]):
572-
err += f" {func_name} ({method} {path})"
573-
err += f"\n{tb})\n"
576+
err = f"Failed to update code:\n - File {api_cls_file_path}"
577+
if func_name:
578+
err += f"\n - Function: {func_name}"
579+
if method and path:
580+
err += f"\n - Endpoint: {method.upper()} {path}"
581+
err += f"\n - Error details:\n{tb}\n"
574582
print(color(err, color_code=ColorCodes.RED)) # noqa :T201
575583
if api_cls_updated and not dry_run:
576584
# revert back to original code
@@ -590,7 +598,7 @@ def generate_api_client(temp_api_client: OpenAPIClient, show_generated_code: boo
590598
:param temp_api_client: Temporary API client
591599
:param show_generated_code: Show generated client code
592600
"""
593-
logger.warning(f"Generating a new API client for {temp_api_client.app_name}")
601+
logger.info(f"Generating a new API client for {temp_api_client.app_name}")
594602
assert _is_temp_client(temp_api_client)
595603
app_name = temp_api_client.app_name
596604
base_api_class = generate_base_api_class(temp_api_client)

src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,21 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
4040
if parameter_objects := operation_obj.get("parameters"):
4141
_parse_parameter_objects(method, parameter_objects, path_param_fields, body_or_query_param_fields)
4242
if request_body := operation_obj.get("requestBody"):
43-
content_type = _parse_request_body_object(request_body, body_or_query_param_fields)
43+
try:
44+
content_type = _parse_request_body_object(request_body, body_or_query_param_fields)
45+
except Exception:
46+
logger.warning(f"Unable to parse the following requestBody obj:\n{request_body}")
47+
raise
48+
49+
expected_path_params = re.findall(r"{([^}]+)}", path)
50+
documented_path_params = [x[0] for x in path_param_fields]
51+
if undocumented_path_params := [x for x in expected_path_params if x not in documented_path_params]:
52+
logger.warning(f"{method.upper()} {path}: Found undocumented path parameters: {undocumented_path_params}")
53+
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
54+
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
55+
# body/query params if any
56+
path_param_fields = [DataclassModelField(x, str) for x in expected_path_params]
57+
body_or_query_param_fields = [x for x in body_or_query_param_fields if x[0] not in expected_path_params]
4458
else:
4559
# Generate model fields from the function signature
4660
sig = inspect.signature(endpoint_func._original_func)
@@ -54,19 +68,6 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
5468
# keyword arguments (body/query parameters)
5569
_add_body_or_query_param_field(body_or_query_param_fields, name, param_obj.annotation)
5670

57-
if hasattr(endpoint_func, "endpoint"):
58-
method = endpoint_func.endpoint.method
59-
path = endpoint_func.endpoint.path
60-
expected_path_params = re.findall(r"{([^}]+)}", path)
61-
documented_path_params = [x[0] for x in path_param_fields]
62-
if undocumented_path_params := [x for x in expected_path_params if x not in documented_path_params]:
63-
logger.warning(f"{method.upper()} {path}: Found undocumented path parameters: {undocumented_path_params}")
64-
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
65-
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
66-
# body/query params if any
67-
path_param_fields = [DataclassModelField(x, str) for x in expected_path_params]
68-
body_or_query_param_fields = [x for x in body_or_query_param_fields if x[0] not in expected_path_params]
69-
7071
# Address the case where a path param name conflicts with body/query param name
7172
for i, (field_name, field_type, _) in enumerate(path_param_fields):
7273
if field_name in [x[0] for x in body_or_query_param_fields]:
@@ -75,8 +76,9 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
7576
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
7677
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
7778
# When calling an endpoint function, the actual name will be automatically resolved in the payload/query parameters
78-
param_model_util.alias_illegal_model_field_names(model_name, path_param_fields)
79-
param_model_util.alias_illegal_model_field_names(model_name, body_or_query_param_fields)
79+
endpoint = f"{endpoint_func.method.upper()} {endpoint_func.path}"
80+
param_model_util.alias_illegal_model_field_names(endpoint, path_param_fields)
81+
param_model_util.alias_illegal_model_field_names(endpoint, body_or_query_param_fields)
8082

8183
fields = path_param_fields + body_or_query_param_fields
8284
return cast(
@@ -211,7 +213,7 @@ def _parse_request_body_object(
211213
# TODO: Support multiple content types
212214
content_type = next(iter(contents))
213215

214-
def parse_schema_obj(obj: dict[str, Any]) -> list[dict[str, Any]] | None:
216+
def parse_schema_obj(obj: dict[str, Any]) -> list[dict[str, Any] | list[dict[str, Any]]] | None:
215217
# This part has some variations, and sometimes not consistent
216218
if not (properties := obj.get("properties", {})):
217219
schema_type = obj.get("type")
@@ -235,8 +237,12 @@ def parse_schema_obj(obj: dict[str, Any]) -> list[dict[str, Any]] | None:
235237
):
236238
# The API directly takes data that is not form-encoded (eg. send tar binary data)
237239
properties = {"data": schema_obj}
240+
elif obj == {}:
241+
# Empty schema
242+
properties = {}
238243
else:
239-
raise NotImplementedError(f"Unsupported request body:\n{json.dumps(obj, indent=4, default=str)}")
244+
# An example from actual OpenAPI spec: {"content": {"application/json": {"schema": {"type": 'string"}}
245+
raise NotImplementedError(f"Unsupported schema obj:\n{json.dumps(obj, indent=4, default=str)}")
240246

241247
for param_name in properties:
242248
param_obj = properties[param_name]

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,10 @@ def visit(model_name: str) -> None:
313313
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))
314314

315315

316-
def alias_illegal_model_field_names(model_name: str, model_fields: list[DataclassModelField]) -> None:
316+
def alias_illegal_model_field_names(location: str, model_fields: list[DataclassModelField]) -> None:
317317
"""Clean illegal model field name and annotate the field type with Alias class
318318
319-
:param model_name: Model name
319+
:param location: Location where the field is seen. This is used for logging purpose
320320
:param model_fields: fields value to be passed to make_dataclass()
321321
"""
322322

@@ -362,7 +362,9 @@ def make_alias(name: str, param_type: Any) -> str:
362362
for i, model_field in enumerate(model_fields):
363363
if (alias_name := make_alias(model_field.name, model_field.type)) != model_field.name:
364364
if isinstance(model_field.default, Field) and model_field.default.metadata:
365-
logger.warning(f"Converted parameter name '{model_field.name}' to '{alias_name}'")
365+
logger.warning(
366+
f"[{location}]: The parameter name '{model_field.name}' was aliased to '{alias_name}'"
367+
)
366368
new_fields = (
367369
alias_name,
368370
param_type_util.annotate_type(model_field.type, Alias(model_field.name)),

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ def resolve(param_type: str, param_format: str | None = None) -> Any:
164164
type_annotation = Literal[*enum]
165165
elif isinstance(param_def, ParamDef.UnknownType):
166166
logger.warning(
167-
f"Unable to locate a parameter type for parameter '{param_name}'. Type '{Any}' will be applied.\n"
168-
f"Unknown parameter object: {param_def.param_obj}"
167+
f"Param '{param_name}': Unable to locate a parameter type in the following parameter object. "
168+
f"Type '{Any}' will be applied:\n"
169+
f"{param_def.param_obj}"
169170
)
170171
type_annotation = Any
171172
else:

src/openapi_test_client/libraries/common/code.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def diff_code(code1: str, code2: str, fromfile: str = "before", tofile: str = "a
3939
else:
4040
color_code = ColorCodes.DEFAULT
4141
print(TAB + color(line.rstrip(), color_code=color_code)) # noqa: T201
42+
print() # noqa: T201
4243

4344

4445
def run_ruff(code: str, remove_unused_imports: bool = True) -> str:

src/openapi_test_client/scripts/generate_client.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,12 @@ def _log_errors(action: str, failed_results: list[tuple[str, Exception]]) -> Non
438438
tb = tb.tb_next
439439
error_details.append(
440440
f"API class: {api_class_name}\n"
441-
f"Error: {type(e).__name__}: {e}\n"
442-
f"File: {tb.tb_frame.f_code.co_filename} (lineno={tb.tb_lineno})"
441+
f"File: {tb.tb_frame.f_code.co_filename} (lineno={tb.tb_lineno})\n"
442+
f"Error: {type(e).__name__}: {e}"
443443
)
444444
err = f"Failed to {action} code for the following API class(es). Please fix the issue and rerun the script."
445445
logger.error(err + "\n" + list_items(error_details))
446-
if os.environ["PYTEST_CURRENT_TEST"]:
446+
if os.environ.get("PYTEST_CURRENT_TEST"):
447447
raise Exception(f"{err}\n{list_items(error_details)}")
448448

449449

tests/integration/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def external_dir(request: SubRequest, random_app_name: str) -> Generator[Path |
128128

129129
@pytest.fixture
130130
def temp_app_client(
131-
temp_dir: Path, mocker: MockerFixture, demo_app_openapi_spec_url: str, should_steram_cmd_output: bool
131+
temp_dir: Path, mocker: MockerFixture, demo_app_openapi_spec_url: str
132132
) -> Generator[OpenAPIClient, Any, None]:
133133
"""Temporary demo app API client that will be generated for a test"""
134134
app_name = f"demo_app_{random.choice(range(1, 1000))}"

0 commit comments

Comments
 (0)