Skip to content

Commit 391ed85

Browse files
Tincheikonst
andauthored
Add support for attrs.fields (#15021)
This add support for `attrs.fields`, which is a nicer way of accessing the attrs magic attribute. Co-authored-by: Ilya Priven <[email protected]>
1 parent d98cc5a commit 391ed85

File tree

5 files changed

+109
-3
lines changed

5 files changed

+109
-3
lines changed

mypy/plugins/attrs.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
TupleType,
7070
Type,
7171
TypeOfAny,
72+
TypeType,
7273
TypeVarType,
7374
UninhabitedType,
7475
UnionType,
@@ -935,7 +936,7 @@ def add_method(
935936

936937
def _get_attrs_init_type(typ: Instance) -> CallableType | None:
937938
"""
938-
If `typ` refers to an attrs class, gets the type of its initializer method.
939+
If `typ` refers to an attrs class, get the type of its initializer method.
939940
"""
940941
magic_attr = typ.type.get(MAGIC_ATTR_NAME)
941942
if magic_attr is None or not magic_attr.plugin_generated:
@@ -1009,7 +1010,7 @@ def _get_expanded_attr_types(
10091010

10101011
def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
10111012
"""
1012-
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
1013+
"Meet" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
10131014
"""
10141015
field_to_types = defaultdict(list)
10151016
for fields in types:
@@ -1026,7 +1027,7 @@ def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
10261027

10271028
def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
10281029
"""
1029-
Generates a signature for the 'attr.evolve' function that's specific to the call site
1030+
Generate a signature for the 'attr.evolve' function that's specific to the call site
10301031
and dependent on the type of the first argument.
10311032
"""
10321033
if len(ctx.args) != 2:
@@ -1060,3 +1061,48 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10601061
fallback=ctx.default_signature.fallback,
10611062
name=f"{ctx.default_signature.name} of {inst_type_str}",
10621063
)
1064+
1065+
1066+
def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
1067+
"""Provide the signature for `attrs.fields`."""
1068+
if not ctx.args or len(ctx.args) != 1 or not ctx.args[0] or not ctx.args[0][0]:
1069+
return ctx.default_signature
1070+
1071+
# <hack>
1072+
assert isinstance(ctx.api, TypeChecker)
1073+
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
1074+
# </hack>
1075+
proper_type = get_proper_type(inst_type)
1076+
1077+
# fields(Any) -> Any, fields(type[Any]) -> Any
1078+
if (
1079+
isinstance(proper_type, AnyType)
1080+
or isinstance(proper_type, TypeType)
1081+
and isinstance(proper_type.item, AnyType)
1082+
):
1083+
return ctx.default_signature
1084+
1085+
cls = None
1086+
arg_types = ctx.default_signature.arg_types
1087+
1088+
if isinstance(proper_type, TypeVarType):
1089+
inner = get_proper_type(proper_type.upper_bound)
1090+
if isinstance(inner, Instance):
1091+
# We need to work arg_types to compensate for the attrs stubs.
1092+
arg_types = [inst_type]
1093+
cls = inner.type
1094+
elif isinstance(proper_type, CallableType):
1095+
cls = proper_type.type_object()
1096+
1097+
if cls is not None and MAGIC_ATTR_NAME in cls.names:
1098+
# This is a proper attrs class.
1099+
ret_type = cls.names[MAGIC_ATTR_NAME].type
1100+
assert ret_type is not None
1101+
return ctx.default_signature.copy_modified(arg_types=arg_types, ret_type=ret_type)
1102+
1103+
ctx.api.fail(
1104+
f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type, ctx.api.options)}"; expected an attrs class',
1105+
ctx.context,
1106+
)
1107+
1108+
return ctx.default_signature

mypy/plugins/default.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
4545
return ctypes.array_constructor_callback
4646
elif fullname == "functools.singledispatch":
4747
return singledispatch.create_singledispatch_function_callback
48+
4849
return None
4950

5051
def get_function_signature_hook(
@@ -54,6 +55,8 @@ def get_function_signature_hook(
5455

5556
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
5657
return attrs.evolve_function_sig_callback
58+
elif fullname in ("attr.fields", "attrs.fields"):
59+
return attrs.fields_function_sig_callback
5760
return None
5861

5962
def get_method_signature_hook(

test-data/unit/check-plugin-attrs.test

+53
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,59 @@ takes_attrs_cls(A(1, "")) # E: Argument 1 to "takes_attrs_cls" has incompatible
15531553
takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompatible type "Type[A]"; expected "AttrsInstance" # N: ClassVar protocol member AttrsInstance.__attrs_attrs__ can never be matched by a class object
15541554
[builtins fixtures/plugin_attrs.pyi]
15551555

1556+
[case testAttrsFields]
1557+
import attr
1558+
from attrs import fields as f # Common usage.
1559+
1560+
@attr.define
1561+
class A:
1562+
b: int
1563+
c: str
1564+
1565+
reveal_type(f(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1566+
reveal_type(f(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1567+
reveal_type(f(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1568+
f(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1569+
1570+
[builtins fixtures/plugin_attrs.pyi]
1571+
1572+
[case testAttrsGenericFields]
1573+
from typing import TypeVar
1574+
1575+
import attr
1576+
from attrs import fields
1577+
1578+
@attr.define
1579+
class A:
1580+
b: int
1581+
c: str
1582+
1583+
TA = TypeVar('TA', bound=A)
1584+
1585+
def f(t: TA) -> None:
1586+
reveal_type(fields(t)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1587+
reveal_type(fields(t)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1588+
reveal_type(fields(t).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1589+
fields(t).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1590+
1591+
1592+
[builtins fixtures/plugin_attrs.pyi]
1593+
1594+
[case testNonattrsFields]
1595+
from typing import Any, cast, Type
1596+
from attrs import fields
1597+
1598+
class A:
1599+
b: int
1600+
c: str
1601+
1602+
fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
1603+
fields(None) # E: Argument 1 to "fields" has incompatible type "None"; expected an attrs class
1604+
fields(cast(Any, 42))
1605+
fields(cast(Type[Any], 43))
1606+
1607+
[builtins fixtures/plugin_attrs.pyi]
1608+
15561609
[case testAttrsInitMethodAlwaysGenerates]
15571610
from typing import Tuple
15581611
import attr

test-data/unit/lib-stub/attr/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,5 @@ def field(
247247

248248
def evolve(inst: _T, **changes: Any) -> _T: ...
249249
def assoc(inst: _T, **changes: Any) -> _T: ...
250+
251+
def fields(cls: type) -> Any: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,5 @@ def field(
131131

132132
def evolve(inst: _T, **changes: Any) -> _T: ...
133133
def assoc(inst: _T, **changes: Any) -> _T: ...
134+
135+
def fields(cls: type) -> Any: ...

0 commit comments

Comments
 (0)