Skip to content

Commit f3f69e6

Browse files
committed
Add support for attrs.fields
1 parent e21ddbf commit f3f69e6

File tree

5 files changed

+60
-3
lines changed

5 files changed

+60
-3
lines changed

mypy/plugins/attrs.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterable, List, cast
5+
from typing import Iterable, List, Optional, cast
66
from typing_extensions import Final, Literal
77

88
import mypy.plugin # To avoid circular imports.
@@ -43,7 +43,7 @@
4343
Var,
4444
is_class_var,
4545
)
46-
from mypy.plugin import SemanticAnalyzerPluginInterface
46+
from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface
4747
from mypy.plugins.common import (
4848
_get_argument,
4949
_get_bool_argument,
@@ -988,3 +988,27 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
988988
ret_type=inst_type,
989989
name=f"{ctx.default_signature.name} of {inst_type_str}",
990990
)
991+
992+
993+
def _get_cls_from_init(t: Type) -> Optional[TypeInfo]:
994+
if isinstance(t, CallableType):
995+
return t.type_object()
996+
return None
997+
998+
999+
def fields_function_callback(ctx: FunctionContext) -> Type:
1000+
"""Provide the proper return value for `attrs.fields`."""
1001+
if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]:
1002+
first_arg_type = ctx.arg_types[0][0]
1003+
cls = _get_cls_from_init(first_arg_type)
1004+
if cls is not None:
1005+
if MAGIC_ATTR_NAME in cls.names:
1006+
# This is a proper attrs class.
1007+
ret_type = cls.names[MAGIC_ATTR_NAME].type
1008+
return ret_type
1009+
else:
1010+
ctx.api.fail(
1011+
f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class',
1012+
ctx.context,
1013+
)
1014+
return ctx.default_return_type

mypy/plugins/default.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ class DefaultPlugin(Plugin):
3939
"""Type checker plugin that is enabled by default."""
4040

4141
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
42-
from mypy.plugins import ctypes, singledispatch
42+
from mypy.plugins import attrs, ctypes, singledispatch
4343

4444
if fullname == "ctypes.Array":
4545
return ctypes.array_constructor_callback
4646
elif fullname == "functools.singledispatch":
4747
return singledispatch.create_singledispatch_function_callback
48+
elif fullname in ("attr.fields", "attrs.fields"):
49+
return attrs.fields_function_callback
4850
return None
4951

5052
def get_function_signature_hook(

test-data/unit/check-attr.test

+27
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,33 @@ takes_attrs_cls(A(1, "")) # E: Argument 1 to "takes_attrs_cls" has incompatible
15481548
takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompatible type "Type[A]"; expected "AttrsInstance" # N: ClassVar protocol member AttrsInstance.__attrs_attrs__ can never be matched by a class object
15491549
[builtins fixtures/attr.pyi]
15501550

1551+
[case testAttrsFields]
1552+
import attr
1553+
from attrs import fields
1554+
1555+
@attr.define
1556+
class A:
1557+
b: int
1558+
c: str
1559+
1560+
reveal_type(fields(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561+
reveal_type(fields(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562+
reveal_type(fields(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563+
fields(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1564+
1565+
[builtins fixtures/attr.pyi]
1566+
1567+
[case testNonattrsFields]
1568+
from attrs import fields
1569+
1570+
class A:
1571+
b: int
1572+
c: str
1573+
1574+
fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
1575+
1576+
[builtins fixtures/attr.pyi]
1577+
15511578
[case testAttrsInitMethodAlwaysGenerates]
15521579
from typing import Tuple
15531580
import attr

test-data/unit/lib-stub/attr/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,5 @@ def field(
247247

248248
def evolve(inst: _T, **changes: Any) -> _T: ...
249249
def assoc(inst: _T, **changes: Any) -> _T: ...
250+
251+
def fields(cls: _C) -> Any: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,5 @@ def field(
129129

130130
def evolve(inst: _T, **changes: Any) -> _T: ...
131131
def assoc(inst: _T, **changes: Any) -> _T: ...
132+
133+
def fields(cls: _C) -> Any: ...

0 commit comments

Comments
 (0)