Skip to content

Commit b2df313

Browse files
authored
Merge pull request #484 from ag2ai/master
Improve code generation with fixes for recursion, aliasing, and modular references
2 parents 5ef88e7 + 748151b commit b2df313

File tree

32 files changed

+1097
-138
lines changed

32 files changed

+1097
-138
lines changed

.github/workflows/docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- name: Setup Python
1515
uses: actions/setup-python@v1
1616
with:
17-
python-version: '3.8'
17+
python-version: '3.9'
1818
architecture: 'x64'
1919

2020
- uses: actions/cache@v1

.github/workflows/publish.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ jobs:
2020
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
2121
restore-keys: |
2222
${{ runner.os }}-pip-
23-
- name: Set up Python 3.8
23+
- name: Set up Python 3.9
2424
uses: actions/setup-python@v1
2525
with:
26-
python-version: 3.8
26+
python-version: 3.9
2727
- name: Build publish
2828
run: |
2929
python -m pip install poetry poetry-dynamic-versioning

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
strategy:
1111
fail-fast: false
1212
matrix:
13-
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
13+
python-version: ["3.9", "3.10", "3.11", "3.12"]
1414
os: [ubuntu-latest, windows-latest, macos-latest]
1515

1616
steps:

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,6 @@ fabric.properties
176176

177177
.idea
178178

179+
.vscode/
180+
179181
version.py

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Options:
4040
--specify-tags Use along with --generate-routers to generate specific routers from given list of tags.
4141
-c, --custom-visitors PATH - A custom visitor that adds variables to the template.
4242
-d, --output-model-type Specify a Pydantic base model to use (see [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator); default is `pydantic.BaseModel`).
43-
-p, --python-version Specify a Python version to target (default is `3.8`).
43+
-p, --python-version Specify a Python version to target (default is `3.9`).
4444
--install-completion Install completion for the current shell.
4545
--show-completion Show completion for the current shell, to copy it
4646
or customize the installation.

fastapi_code_generator/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .patches import patch_parse
2+
3+
patch_parse()

fastapi_code_generator/__main__.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
BUILTIN_VISITOR_DIR = Path(__file__).parent / "visitors"
3030

31-
MODEL_PATH: Path = Path("models.py")
31+
MODEL_PATH: Path = Path("models")
3232

3333

3434
def dynamic_load_module(module_path: Path) -> Any:
@@ -63,7 +63,7 @@ def main(
6363
DataModelType.PydanticBaseModel.value, "--output-model-type", "-d"
6464
),
6565
python_version: PythonVersion = typer.Option(
66-
PythonVersion.PY_38.value, "--python-version", "-p"
66+
PythonVersion.PY_39.value, "--python-version", "-p"
6767
),
6868
) -> None:
6969
input_name: str = input_file
@@ -72,10 +72,7 @@ def main(
7272
with open(input_file, encoding=encoding) as f:
7373
input_text = f.read()
7474

75-
if model_file:
76-
model_path = Path(model_file).with_suffix('.py')
77-
else:
78-
model_path = MODEL_PATH
75+
model_path = Path(model_file) if model_file else MODEL_PATH # pragma: no cover
7976

8077
return generate_code(
8178
input_name,
@@ -119,7 +116,7 @@ def generate_code(
119116
generate_routers: Optional[bool] = None,
120117
specify_tags: Optional[str] = None,
121118
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
122-
python_version: PythonVersion = PythonVersion.PY_38,
119+
python_version: PythonVersion = PythonVersion.PY_39,
123120
) -> None:
124121
if not model_path:
125122
model_path = MODEL_PATH
@@ -149,14 +146,16 @@ def generate_code(
149146

150147
with chdir(output_dir):
151148
models = parser.parse()
152-
output = output_dir / model_path
153149
if not models:
154150
# if no models (schemas), just generate an empty model file.
155-
modules = {output: ("", input_name)}
151+
modules = {output_dir / model_path.with_suffix('.py'): ("", input_name)}
156152
elif isinstance(models, str):
157-
modules = {output: (models, input_name)}
153+
modules = {output_dir / model_path.with_suffix('.py'): (models, input_name)}
158154
else:
159-
raise Exception('Modular references are not supported in this version')
155+
modules = {
156+
output_dir / model_path / module_name[0]: (model.body, input_name)
157+
for module_name, model in models.items()
158+
}
160159

161160
environment: Environment = Environment(
162161
loader=FileSystemLoader(

fastapi_code_generator/parser.py

+98-25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pathlib
44
import re
5+
from functools import cached_property
56
from typing import (
67
Any,
78
Callable,
@@ -42,7 +43,6 @@
4243
ResponseObject,
4344
)
4445
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
45-
from datamodel_code_generator.util import cached_property
4646
from pydantic import BaseModel, ValidationInfo
4747

4848
RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r'^application/.*json$')
@@ -93,16 +93,43 @@ class Argument(CachedPropertyModel):
9393
type_hint: UsefulStr
9494
default: Optional[UsefulStr] = None
9595
default_value: Optional[UsefulStr] = None
96+
field: Union[DataModelField, list[DataModelField], None] = None
9697
required: bool
9798

9899
def __str__(self) -> str:
99100
return self.argument
100101

101-
@cached_property
102+
@property
102103
def argument(self) -> str:
104+
if self.field is None:
105+
type_hint = self.type_hint
106+
else:
107+
type_hint = (
108+
UsefulStr(self.field.type_hint)
109+
if not isinstance(self.field, list)
110+
else UsefulStr(
111+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
112+
)
113+
)
114+
if self.default is None and self.required:
115+
return f'{self.name}: {type_hint}'
116+
return f'{self.name}: {type_hint} = {self.default}'
117+
118+
@property
119+
def snakecase(self) -> str:
120+
if self.field is None:
121+
type_hint = self.type_hint
122+
else:
123+
type_hint = (
124+
UsefulStr(self.field.type_hint)
125+
if not isinstance(self.field, list)
126+
else UsefulStr(
127+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
128+
)
129+
)
103130
if self.default is None and self.required:
104-
return f'{self.name}: {self.type_hint}'
105-
return f'{self.name}: {self.type_hint} = {self.default}'
131+
return f'{stringcase.snakecase(self.name)}: {type_hint}'
132+
return f'{stringcase.snakecase(self.name)}: {type_hint} = {self.default}'
106133

107134

108135
class Operation(CachedPropertyModel):
@@ -114,16 +141,39 @@ class Operation(CachedPropertyModel):
114141
parameters: List[Dict[str, Any]] = []
115142
responses: Dict[UsefulStr, Any] = {}
116143
deprecated: bool = False
117-
imports: List[Import] = []
118144
security: Optional[List[Dict[str, List[str]]]] = None
119145
tags: Optional[List[str]] = []
120-
arguments: str = ''
121-
snake_case_arguments: str = ''
122146
request: Optional[Argument] = None
123147
response: str = ''
124148
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
125149
return_type: str = ''
126150
callbacks: Dict[UsefulStr, List["Operation"]] = {}
151+
arguments_list: List[Argument] = []
152+
153+
@classmethod
154+
def merge_arguments_with_union(cls, arguments: List[Argument]) -> List[Argument]:
155+
grouped_arguments: DefaultDict[str, List[Argument]] = DefaultDict(list)
156+
for argument in arguments:
157+
grouped_arguments[argument.name].append(argument)
158+
159+
merged_arguments = []
160+
for argument_list in grouped_arguments.values():
161+
if len(argument_list) == 1:
162+
merged_arguments.append(argument_list[0])
163+
else:
164+
argument = argument_list[0]
165+
fields = [
166+
item
167+
for arg in argument_list
168+
if arg.field is not None
169+
for item in (
170+
arg.field if isinstance(arg.field, list) else [arg.field]
171+
)
172+
if item is not None
173+
]
174+
argument.field = fields
175+
merged_arguments.append(argument)
176+
return merged_arguments
127177

128178
@cached_property
129179
def type(self) -> UsefulStr:
@@ -132,6 +182,27 @@ def type(self) -> UsefulStr:
132182
"""
133183
return self.method
134184

185+
@property
186+
def arguments(self) -> str:
187+
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
188+
return ", ".join(argument.argument for argument in sorted_arguments)
189+
190+
@property
191+
def snake_case_arguments(self) -> str:
192+
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
193+
return ", ".join(argument.snakecase for argument in sorted_arguments)
194+
195+
@property
196+
def imports(self) -> Imports:
197+
imports = Imports()
198+
for argument in self.arguments_list:
199+
if isinstance(argument.field, list):
200+
for field in argument.field:
201+
imports.append(field.data_type.import_)
202+
elif argument.field:
203+
imports.append(argument.field.data_type.import_)
204+
return imports
205+
135206
@cached_property
136207
def root_path(self) -> UsefulStr:
137208
paths = self.path.split("/")
@@ -153,7 +224,7 @@ def function_name(self) -> str:
153224
return stringcase.snakecase(name)
154225

155226

156-
@snooper_to_methods(max_variable_length=None)
227+
@snooper_to_methods()
157228
class OpenAPIParser(OpenAPIModelParser):
158229
def __init__(
159230
self,
@@ -166,7 +237,7 @@ def __init__(
166237
base_class: Optional[str] = None,
167238
custom_template_dir: Optional[pathlib.Path] = None,
168239
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
169-
target_python_version: PythonVersion = PythonVersion.PY_37,
240+
target_python_version: PythonVersion = PythonVersion.PY_39,
170241
dump_resolve_reference_action: Optional[Callable[[Iterable[str]], str]] = None,
171242
validation: bool = False,
172243
field_constraints: bool = False,
@@ -314,6 +385,7 @@ def get_parameter_type(
314385
default=default, # type: ignore
315386
default_value=schema.default,
316387
required=field.required,
388+
field=field,
317389
)
318390

319391
def get_arguments(self, snake_case: bool, path: List[str]) -> str:
@@ -347,6 +419,10 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347419
or argument.type_hint.startswith('Optional[')
348420
)
349421

422+
# check if there are duplicate argument.name
423+
argument_names = [argument.name for argument in arguments]
424+
if len(argument_names) != len(set(argument_names)):
425+
self.imports_for_fastapi.append(Import(from_='typing', import_="Union"))
350426
return arguments
351427

352428
def parse_request_body(
@@ -466,10 +542,7 @@ def parse_operation(
466542
resolved_path = self.model_resolver.resolve_ref(path)
467543
path_name, method = path[-2:]
468544

469-
self._temporary_operation['arguments'] = self.get_arguments(
470-
snake_case=False, path=path
471-
)
472-
self._temporary_operation['snake_case_arguments'] = self.get_arguments(
545+
self._temporary_operation['arguments_list'] = self.get_argument_list(
473546
snake_case=True, path=path
474547
)
475548
main_operation = self._temporary_operation
@@ -499,11 +572,8 @@ def parse_operation(
499572
self._temporary_operation = {'_parameters': []}
500573
cb_path = path + ['callbacks', key, route, method]
501574
super().parse_operation(cb_op, cb_path)
502-
self._temporary_operation['arguments'] = self.get_arguments(
503-
snake_case=False, path=cb_path
504-
)
505-
self._temporary_operation['snake_case_arguments'] = (
506-
self.get_arguments(snake_case=True, path=cb_path)
575+
self._temporary_operation['arguments_list'] = (
576+
self.get_argument_list(snake_case=True, path=cb_path)
507577
)
508578

509579
callbacks[key].append(
@@ -527,13 +597,16 @@ def _collapse_root_model(self, data_type: DataType) -> DataType:
527597
reference = data_type.reference
528598
import functools
529599

530-
if not (
531-
reference
532-
and (
533-
len(reference.children) == 1
534-
or functools.reduce(lambda a, b: a == b, reference.children)
535-
)
536-
):
600+
try:
601+
if not (
602+
reference
603+
and (
604+
len(reference.children) == 0
605+
or functools.reduce(lambda a, b: a == b, reference.children)
606+
)
607+
):
608+
return data_type
609+
except RecursionError:
537610
return data_type
538611
source = reference.source
539612
if not isinstance(source, CustomRootType):

0 commit comments

Comments
 (0)