40
40
generate_module ,
41
41
generate_pass ,
42
42
generate_pydantic_field ,
43
+ model_has_forward_refs ,
43
44
)
44
45
from ..exceptions import NotSupported , ParsingError
45
46
from ..plugins .manager import PluginManager
@@ -158,7 +159,7 @@ def generate(self) -> ast.Module:
158
159
model_rebuild_calls = [
159
160
generate_expr (generate_method_call (class_def .name , MODEL_REBUILD_METHOD ))
160
161
for class_def in self ._class_defs
161
- if self . include_model_rebuild (class_def )
162
+ if model_has_forward_refs (class_def )
162
163
]
163
164
164
165
module_body = (
@@ -174,11 +175,6 @@ def generate(self) -> ast.Module:
174
175
)
175
176
return module
176
177
177
- def include_model_rebuild (self , class_def : ast .ClassDef ) -> bool :
178
- visitor = ClassDefNamesVisitor ()
179
- visitor .visit (class_def )
180
- return visitor .found_name_with_quote
181
-
182
178
def get_imports (self ) -> List [ast .ImportFrom ]:
183
179
return self ._imports
184
180
@@ -576,19 +572,3 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode:
576
572
copied_node = deepcopy (node )
577
573
visit (copied_node , RemoveMixinVisitor ())
578
574
return copied_node
579
-
580
-
581
- class ClassDefNamesVisitor (ast .NodeVisitor ):
582
- def __init__ (self ):
583
- self .found_name_with_quote = False
584
-
585
- def visit_Name (self , node ): # pylint: disable=C0103
586
- if '"' in node .id :
587
- self .found_name_with_quote = True
588
- self .generic_visit (node )
589
-
590
- def visit_Subscript (self , node ): # pylint: disable=C0103
591
- if isinstance (node .value , ast .Name ) and node .value .id == "Literal" :
592
- return
593
-
594
- self .generic_visit (node )
0 commit comments