Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stubgen] Improve dataclass init signatures #18430

Merged
merged 1 commit into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
# Default field specifiers for dataclasses
DATACLASS_FIELD_SPECIFIERS: Final = ("dataclasses.Field", "dataclasses.field")


SELF_TVAR_NAME: Final = "_DT"
Expand All @@ -87,7 +89,7 @@
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
field_specifiers=DATACLASS_FIELD_SPECIFIERS,
)
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"
Expand Down
33 changes: 25 additions & 8 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
ImportFrom,
IndexExpr,
IntExpr,
LambdaExpr,
ListExpr,
MemberExpr,
MypyFile,
Expand All @@ -113,6 +114,7 @@
Var,
)
from mypy.options import Options as MypyOptions
from mypy.plugins.dataclasses import DATACLASS_FIELD_SPECIFIERS
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
from mypy.stubdoc import ArgSig, FunctionSig
Expand Down Expand Up @@ -342,11 +344,12 @@ def visit_index_expr(self, node: IndexExpr) -> str:
base = node.base.accept(self)
index = node.index.accept(self)
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
index = index[1:-1]
index = index[1:-1].rstrip(",")
return f"{base}[{index}]"

def visit_tuple_expr(self, node: TupleExpr) -> str:
return f"({', '.join(n.accept(self) for n in node.items)})"
suffix = "," if len(node.items) == 1 else ""
return f"({', '.join(n.accept(self) for n in node.items)}{suffix})"

def visit_list_expr(self, node: ListExpr) -> str:
return f"[{', '.join(n.accept(self) for n in node.items)}]"
Expand All @@ -368,6 +371,10 @@ def visit_op_expr(self, o: OpExpr) -> str:
def visit_star_expr(self, o: StarExpr) -> str:
return f"*{o.expr.accept(self)}"

def visit_lambda_expr(self, o: LambdaExpr) -> str:
# TODO: Required for among other things dataclass.field default_factory
return self.stubgen.add_name("_typeshed.Incomplete")


def find_defined_names(file: MypyFile) -> set[str]:
finder = DefinitionFinder()
Expand Down Expand Up @@ -482,6 +489,7 @@ def __init__(
self.method_names: set[str] = set()
self.processing_enum = False
self.processing_dataclass = False
self.dataclass_field_specifier: tuple[str, ...] = ()

@property
def _current_class(self) -> ClassDef | None:
Expand Down Expand Up @@ -636,8 +644,8 @@ def visit_func_def(self, o: FuncDef) -> None:
is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if is_dataclass_generated and o.name != "__init__":
# Skip methods generated by the @dataclass decorator (except for __init__)
if is_dataclass_generated:
# Skip methods generated by the @dataclass decorator
return
if (
self.is_private_name(o.name, o.fullname)
Expand Down Expand Up @@ -793,8 +801,9 @@ def visit_class_def(self, o: ClassDef) -> None:
self.add(f"{self._indent}{docstring}\n")
n = len(self._output)
self._vars.append([])
if self.analyzed and find_dataclass_transform_spec(o):
if self.analyzed and (spec := find_dataclass_transform_spec(o)):
self.processing_dataclass = True
self.dataclass_field_specifier = spec.field_specifiers
super().visit_class_def(o)
self.dedent()
self._vars.pop()
Expand All @@ -809,6 +818,7 @@ def visit_class_def(self, o: ClassDef) -> None:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self.dataclass_field_specifier = ()
self._class_stack.pop(-1)
self.processing_enum = False

Expand Down Expand Up @@ -879,8 +889,9 @@ def is_dataclass_transform(self, expr: Expression) -> bool:
expr = expr.callee
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
return True
if find_dataclass_transform_spec(expr) is not None:
if (spec := find_dataclass_transform_spec(expr)) is not None:
self.processing_dataclass = True
self.dataclass_field_specifier = spec.field_specifiers
return True
return False

Expand Down Expand Up @@ -1259,8 +1270,14 @@ def get_assign_initializer(self, rvalue: Expression) -> str:
and not isinstance(rvalue, TempNode)
):
return " = ..."
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
if self.processing_dataclass:
if isinstance(rvalue, CallExpr):
fullname = self.get_fullname(rvalue.callee)
if fullname in (self.dataclass_field_specifier or DATACLASS_FIELD_SPECIFIERS):
p = AliasPrinter(self)
return f" = {rvalue.accept(p)}"
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important

# By default, no initializer is required:
Expand Down
78 changes: 53 additions & 25 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3101,15 +3101,14 @@ import attrs

@attrs.define
class C:
x = attrs.field()
x: int = attrs.field()

[out]
import attrs

@attrs.define
class C:
x = ...
def __init__(self, x) -> None: ...
x: int = attrs.field()

[case testNamedTupleInClass]
from collections import namedtuple
Expand Down Expand Up @@ -4050,8 +4049,9 @@ def i(x=..., y=..., z=...) -> None: ...
[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, InitVar, KW_ONLY
from dataclasses import dataclass, field, Field, InitVar, KW_ONLY
from dataclasses import dataclass as dc
from datetime import datetime
from typing import ClassVar

@dataclasses.dataclass
Expand All @@ -4066,6 +4066,10 @@ class X:
h: int = 1
i: InitVar[str]
j: InitVar = 100
# Lambda not supported yet -> marked as Incomplete instead
k: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
)
non_field = None

@dcs.dataclass
Expand All @@ -4083,7 +4087,8 @@ class V: ...
[out]
import dataclasses
import dataclasses as dcs
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
from _typeshed import Incomplete
from dataclasses import Field, InitVar, KW_ONLY, dataclass, dataclass as dc, field
from typing import ClassVar

@dataclasses.dataclass
Expand All @@ -4092,12 +4097,13 @@ class X:
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
k: str = Field(default_factory=Incomplete)
non_field = ...

@dcs.dataclass
Expand All @@ -4110,8 +4116,9 @@ class W: ...
class V: ...

[case testDataclass_semanal]
from dataclasses import InitVar, dataclass, field
from dataclasses import Field, InitVar, dataclass, field
from typing import ClassVar
from datetime import datetime

@dataclass
class X:
Expand All @@ -4125,13 +4132,18 @@ class X:
h: int = 1
i: InitVar = 100
j: list[int] = field(default_factory=list)
# Lambda not supported yet -> marked as Incomplete instead
k: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
)
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, dataclass
from _typeshed import Incomplete
from dataclasses import Field, InitVar, dataclass, field
from typing import ClassVar

@dataclass
Expand All @@ -4141,13 +4153,13 @@ class X:
c: str = ...
d: ClassVar
e: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
h: int = ...
i: InitVar = ...
j: list[int] = ...
j: list[int] = field(default_factory=list)
k: str = Field(default_factory=Incomplete)
non_field = ...
def __init__(self, a, b, c=..., *, g=..., h=..., i=..., j=...) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...
Expand Down Expand Up @@ -4175,7 +4187,7 @@ class X:
class Y: ...

[out]
from dataclasses import InitVar, KW_ONLY, dataclass
from dataclasses import InitVar, KW_ONLY, dataclass, field
from typing import ClassVar

@dataclass
Expand All @@ -4184,14 +4196,13 @@ class X:
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b=..., *, g=..., h=..., i, j=...) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...
Expand Down Expand Up @@ -4236,15 +4247,13 @@ from dataclasses import dataclass
@dataclass
class X(missing.Base):
a: int
def __init__(self, *generated_args, a, **generated_kwargs) -> None: ...

@dataclass
class Y(missing.Base):
generated_args: str
generated_args_: str
generated_kwargs: float
generated_kwargs_: float
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...

[case testDataclassTransform]
# dataclass_transform detection only works with sementic analysis.
Expand Down Expand Up @@ -4298,6 +4307,7 @@ class Z(metaclass=DCMeta):

[case testDataclassTransformDecorator_semanal]
import typing_extensions
from dataclasses import field

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls):
Expand All @@ -4307,9 +4317,11 @@ def create_model(cls):
class X:
a: int
b: str = "hello"
c: bool = field(default=True)

[out]
import typing_extensions
from dataclasses import field

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls): ...
Expand All @@ -4318,9 +4330,10 @@ def create_model(cls): ...
class X:
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = field(default=True)

[case testDataclassTransformClass_semanal]
from dataclasses import field
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
Expand All @@ -4329,8 +4342,10 @@ class ModelBase: ...
class X(ModelBase):
a: int
b: str = "hello"
c: bool = field(default=True)

[out]
from dataclasses import field
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
Expand All @@ -4339,28 +4354,42 @@ class ModelBase: ...
class X(ModelBase):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = field(default=True)

[case testDataclassTransformMetaclass_semanal]
from dataclasses import field
from typing import Any
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
def custom_field(*, default: bool, kw_only: bool) -> Any: ...

@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = "hello"
c: bool = field(default=True) # should be ignored, not field_specifier here

class Y(X):
d: str = custom_field(default="Hello")

[out]
from typing import Any
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
def custom_field(*, default: bool, kw_only: bool) -> Any: ...

@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = ...

class Y(X):
d: str = custom_field(default='Hello')

[case testAlwaysUsePEP604Union]
import typing
Expand Down Expand Up @@ -4662,4 +4691,3 @@ class DCMeta(type): ...

class DC(metaclass=DCMeta):
x: str
def __init__(self, x) -> None: ...
Loading