Skip to content

Improve code generation with fixes for recursion, aliasing, and modular references #484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 16, 2025
Merged
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: '3.8'
python-version: '3.9'
architecture: 'x64'

- uses: actions/cache@v1
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Set up Python 3.8
- name: Set up Python 3.9
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.9
- name: Build publish
run: |
python -m pip install poetry poetry-dynamic-versioning
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,6 @@ fabric.properties

.idea

.vscode/

version.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Options:
--specify-tags Use along with --generate-routers to generate specific routers from given list of tags.
-c, --custom-visitors PATH - A custom visitor that adds variables to the template.
-d, --output-model-type Specify a Pydantic base model to use (see [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator); default is `pydantic.BaseModel`).
-p, --python-version Specify a Python version to target (default is `3.8`).
-p, --python-version Specify a Python version to target (default is `3.9`).
--install-completion Install completion for the current shell.
--show-completion Show completion for the current shell, to copy it
or customize the installation.
Expand Down
3 changes: 3 additions & 0 deletions fastapi_code_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .patches import patch_parse

patch_parse()
21 changes: 10 additions & 11 deletions fastapi_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

BUILTIN_VISITOR_DIR = Path(__file__).parent / "visitors"

MODEL_PATH: Path = Path("models.py")
MODEL_PATH: Path = Path("models")


def dynamic_load_module(module_path: Path) -> Any:
Expand Down Expand Up @@ -63,7 +63,7 @@ def main(
DataModelType.PydanticBaseModel.value, "--output-model-type", "-d"
),
python_version: PythonVersion = typer.Option(
PythonVersion.PY_38.value, "--python-version", "-p"
PythonVersion.PY_39.value, "--python-version", "-p"
),
) -> None:
input_name: str = input_file
Expand All @@ -72,10 +72,7 @@ def main(
with open(input_file, encoding=encoding) as f:
input_text = f.read()

if model_file:
model_path = Path(model_file).with_suffix('.py')
else:
model_path = MODEL_PATH
model_path = Path(model_file) if model_file else MODEL_PATH # pragma: no cover

return generate_code(
input_name,
Expand Down Expand Up @@ -119,7 +116,7 @@ def generate_code(
generate_routers: Optional[bool] = None,
specify_tags: Optional[str] = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
python_version: PythonVersion = PythonVersion.PY_38,
python_version: PythonVersion = PythonVersion.PY_39,
) -> None:
if not model_path:
model_path = MODEL_PATH
Expand Down Expand Up @@ -149,14 +146,16 @@ def generate_code(

with chdir(output_dir):
models = parser.parse()
output = output_dir / model_path
if not models:
# if no models (schemas), just generate an empty model file.
modules = {output: ("", input_name)}
modules = {output_dir / model_path.with_suffix('.py'): ("", input_name)}
elif isinstance(models, str):
modules = {output: (models, input_name)}
modules = {output_dir / model_path.with_suffix('.py'): (models, input_name)}
else:
raise Exception('Modular references are not supported in this version')
modules = {
output_dir / model_path / module_name[0]: (model.body, input_name)
for module_name, model in models.items()
}

environment: Environment = Environment(
loader=FileSystemLoader(
Expand Down
123 changes: 98 additions & 25 deletions fastapi_code_generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pathlib
import re
from functools import cached_property
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -42,7 +43,6 @@
ResponseObject,
)
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
from datamodel_code_generator.util import cached_property
from pydantic import BaseModel, ValidationInfo

RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r'^application/.*json$')
Expand Down Expand Up @@ -93,16 +93,43 @@
type_hint: UsefulStr
default: Optional[UsefulStr] = None
default_value: Optional[UsefulStr] = None
field: Union[DataModelField, list[DataModelField], None] = None
required: bool

def __str__(self) -> str:
return self.argument

@cached_property
@property
def argument(self) -> str:
if self.field is None:
type_hint = self.type_hint

Check warning on line 105 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L104-L105

Added lines #L104 - L105 were not covered by tests
else:
type_hint = (

Check warning on line 107 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L107

Added line #L107 was not covered by tests
UsefulStr(self.field.type_hint)
if not isinstance(self.field, list)
else UsefulStr(
f"Union[{', '.join(field.type_hint for field in self.field)}]"
)
)
if self.default is None and self.required:
return f'{self.name}: {type_hint}'
return f'{self.name}: {type_hint} = {self.default}'

Check warning on line 116 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L114-L116

Added lines #L114 - L116 were not covered by tests

@property
def snakecase(self) -> str:
if self.field is None:
type_hint = self.type_hint
else:
type_hint = (
UsefulStr(self.field.type_hint)
if not isinstance(self.field, list)
else UsefulStr(
f"Union[{', '.join(field.type_hint for field in self.field)}]"
)
)
if self.default is None and self.required:
return f'{self.name}: {self.type_hint}'
return f'{self.name}: {self.type_hint} = {self.default}'
return f'{stringcase.snakecase(self.name)}: {type_hint}'
return f'{stringcase.snakecase(self.name)}: {type_hint} = {self.default}'


class Operation(CachedPropertyModel):
Expand All @@ -114,16 +141,39 @@
parameters: List[Dict[str, Any]] = []
responses: Dict[UsefulStr, Any] = {}
deprecated: bool = False
imports: List[Import] = []
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[str]] = []
arguments: str = ''
snake_case_arguments: str = ''
request: Optional[Argument] = None
response: str = ''
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
return_type: str = ''
callbacks: Dict[UsefulStr, List["Operation"]] = {}
arguments_list: List[Argument] = []

@classmethod
def merge_arguments_with_union(cls, arguments: List[Argument]) -> List[Argument]:
grouped_arguments: DefaultDict[str, List[Argument]] = DefaultDict(list)
for argument in arguments:
grouped_arguments[argument.name].append(argument)

merged_arguments = []
for argument_list in grouped_arguments.values():
if len(argument_list) == 1:
merged_arguments.append(argument_list[0])
else:
argument = argument_list[0]
fields = [
item
for arg in argument_list
if arg.field is not None
for item in (
arg.field if isinstance(arg.field, list) else [arg.field]
)
if item is not None
]
argument.field = fields
merged_arguments.append(argument)
return merged_arguments

@cached_property
def type(self) -> UsefulStr:
Expand All @@ -132,6 +182,27 @@
"""
return self.method

@property
def arguments(self) -> str:
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
return ", ".join(argument.argument for argument in sorted_arguments)

Check warning on line 188 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L187-L188

Added lines #L187 - L188 were not covered by tests

@property
def snake_case_arguments(self) -> str:
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
return ", ".join(argument.snakecase for argument in sorted_arguments)

@property
def imports(self) -> Imports:
imports = Imports()
for argument in self.arguments_list:
if isinstance(argument.field, list):
for field in argument.field:
imports.append(field.data_type.import_)

Check warning on line 201 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L200-L201

Added lines #L200 - L201 were not covered by tests
elif argument.field:
imports.append(argument.field.data_type.import_)
return imports

@cached_property
def root_path(self) -> UsefulStr:
paths = self.path.split("/")
Expand All @@ -153,7 +224,7 @@
return stringcase.snakecase(name)


@snooper_to_methods(max_variable_length=None)
@snooper_to_methods()
class OpenAPIParser(OpenAPIModelParser):
def __init__(
self,
Expand All @@ -166,7 +237,7 @@
base_class: Optional[str] = None,
custom_template_dir: Optional[pathlib.Path] = None,
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
target_python_version: PythonVersion = PythonVersion.PY_37,
target_python_version: PythonVersion = PythonVersion.PY_39,
dump_resolve_reference_action: Optional[Callable[[Iterable[str]], str]] = None,
validation: bool = False,
field_constraints: bool = False,
Expand Down Expand Up @@ -314,6 +385,7 @@
default=default, # type: ignore
default_value=schema.default,
required=field.required,
field=field,
)

def get_arguments(self, snake_case: bool, path: List[str]) -> str:
Expand Down Expand Up @@ -347,6 +419,10 @@
or argument.type_hint.startswith('Optional[')
)

# check if there are duplicate argument.name
argument_names = [argument.name for argument in arguments]
if len(argument_names) != len(set(argument_names)):
self.imports_for_fastapi.append(Import(from_='typing', import_="Union"))
return arguments

def parse_request_body(
Expand Down Expand Up @@ -466,10 +542,7 @@
resolved_path = self.model_resolver.resolve_ref(path)
path_name, method = path[-2:]

self._temporary_operation['arguments'] = self.get_arguments(
snake_case=False, path=path
)
self._temporary_operation['snake_case_arguments'] = self.get_arguments(
self._temporary_operation['arguments_list'] = self.get_argument_list(
snake_case=True, path=path
)
main_operation = self._temporary_operation
Expand Down Expand Up @@ -499,11 +572,8 @@
self._temporary_operation = {'_parameters': []}
cb_path = path + ['callbacks', key, route, method]
super().parse_operation(cb_op, cb_path)
self._temporary_operation['arguments'] = self.get_arguments(
snake_case=False, path=cb_path
)
self._temporary_operation['snake_case_arguments'] = (
self.get_arguments(snake_case=True, path=cb_path)
self._temporary_operation['arguments_list'] = (

Check warning on line 575 in fastapi_code_generator/parser.py

View check run for this annotation

Codecov / codecov/patch

fastapi_code_generator/parser.py#L575

Added line #L575 was not covered by tests
self.get_argument_list(snake_case=True, path=cb_path)
)

callbacks[key].append(
Expand All @@ -527,13 +597,16 @@
reference = data_type.reference
import functools

if not (
reference
and (
len(reference.children) == 1
or functools.reduce(lambda a, b: a == b, reference.children)
)
):
try:
if not (
reference
and (
len(reference.children) == 0
or functools.reduce(lambda a, b: a == b, reference.children)
)
):
return data_type
except RecursionError:
return data_type
source = reference.source
if not isinstance(source, CustomRootType):
Expand Down
Loading