Skip to content

Commit b59bcae

Browse files
Br1an67Viicos
andauthored
Allow dynamic models created with create_model() to be used as annotations in the Mypy plugin (pydantic#12879)
Co-authored-by: Victorien <65306057+Viicos@users.noreply.github.com>
1 parent 0b7b855 commit b59bcae

4 files changed

Lines changed: 60 additions & 0 deletions

File tree

pydantic/mypy.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from mypy.plugin import (
4848
CheckerPluginInterface,
4949
ClassDefContext,
50+
DynamicClassDefContext,
5051
MethodContext,
5152
Plugin,
5253
ReportConfigContext,
@@ -82,6 +83,7 @@
8283
CONFIGFILE_KEY = 'pydantic-mypy'
8384
METADATA_KEY = 'pydantic-mypy-metadata'
8485
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
86+
CREATE_MODEL_FULLNAME = 'pydantic.main.create_model'
8587
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
8688
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
8789
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
@@ -150,6 +152,12 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
150152
return from_attributes_callback
151153
return None
152154

155+
def get_dynamic_class_hook(self, fullname: str) -> Callable[[DynamicClassDefContext], None] | None:
156+
"""Recognize `create_model()` calls as dynamic BaseModel subclasses."""
157+
if fullname == CREATE_MODEL_FULLNAME:
158+
return self._pydantic_create_model_callback
159+
return None
160+
153161
def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]:
154162
"""Return all plugin config data.
155163
@@ -174,6 +182,33 @@ def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> Non
174182
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
175183
info_metaclass.type.dataclass_transform_spec = None
176184

185+
def _pydantic_create_model_callback(self, ctx: DynamicClassDefContext) -> None:
186+
"""Make variables assigned from `create_model()` usable as types by mypy."""
187+
# Determine the base class from __base__ argument if provided
188+
base_fullname = BASEMODEL_FULLNAME
189+
for arg_name, arg_expr in zip(ctx.call.arg_names, ctx.call.args):
190+
if arg_name == '__base__' and isinstance(arg_expr, RefExpr) and arg_expr.node is not None:
191+
if isinstance(arg_expr.node, TypeInfo):
192+
base_fullname = arg_expr.node.fullname
193+
elif isinstance(arg_expr.node, Var) and isinstance(arg_expr.node.type, Instance):
194+
base_fullname = arg_expr.node.type.type.fullname
195+
196+
base_sym = ctx.api.lookup_fully_qualified_or_none(base_fullname)
197+
if base_sym is None or not isinstance(base_sym.node, TypeInfo):
198+
# Fall back to BaseModel
199+
base_sym = ctx.api.lookup_fully_qualified_or_none(BASEMODEL_FULLNAME)
200+
if base_sym is None or not isinstance(base_sym.node, TypeInfo):
201+
return
202+
203+
base_info = base_sym.node
204+
base_instance = fill_typevars(base_info)
205+
assert isinstance(base_instance, Instance)
206+
207+
info = ctx.api.basic_new_typeinfo(ctx.name, base_instance, ctx.call.line)
208+
info.metaclass_type = base_info.metaclass_type
209+
210+
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(MDEF, info))
211+
177212

178213
class PydanticPluginConfig:
179214
"""A Pydantic mypy plugin config holder.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pydantic import BaseModel, create_model
2+
3+
4+
class Model(BaseModel):
5+
a: int
6+
7+
8+
SubModel = create_model('SubModel', __base__=Model)
9+
10+
11+
class Main(BaseModel):
12+
sub: SubModel
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pydantic import BaseModel, create_model
2+
3+
4+
class Model(BaseModel):
5+
a: int
6+
7+
8+
SubModel = create_model('SubModel', __base__=Model)
9+
10+
11+
class Main(BaseModel):
12+
sub: SubModel

tests/mypy/test_mypy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def build_cases(
109109
('mypy-plugin.ini', 'root_models.py'),
110110
('mypy-plugin.ini', 'plugin_strict_fields.py'),
111111
('mypy-plugin.ini', 'final_with_default.py'),
112+
('mypy-plugin.ini', 'create_model_var.py'),
112113
('mypy-plugin-strict-no-any.ini', 'dataclass_no_any.py'),
113114
('mypy-plugin-very-strict.ini', 'metaclass_args.py'),
114115
('pyproject-plugin-no-strict-optional.toml', 'no_strict_optional.py'),

0 commit comments

Comments
 (0)