|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from typing import Iterable, List, cast |
| 5 | +from typing import Iterable, List, Optional, cast |
6 | 6 | from typing_extensions import Final, Literal
|
7 | 7 |
|
8 | 8 | import mypy.plugin # To avoid circular imports.
|
|
43 | 43 | Var,
|
44 | 44 | is_class_var,
|
45 | 45 | )
|
46 |
| -from mypy.plugin import SemanticAnalyzerPluginInterface |
| 46 | +from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface |
47 | 47 | from mypy.plugins.common import (
|
48 | 48 | _get_argument,
|
49 | 49 | _get_bool_argument,
|
@@ -988,3 +988,27 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
|
988 | 988 | ret_type=inst_type,
|
989 | 989 | name=f"{ctx.default_signature.name} of {inst_type_str}",
|
990 | 990 | )
|
| 991 | + |
| 992 | + |
| 993 | +def _get_cls_from_init(t: Type) -> Optional[TypeInfo]: |
| 994 | + if isinstance(t, CallableType): |
| 995 | + return t.type_object() |
| 996 | + return None |
| 997 | + |
| 998 | + |
| 999 | +def fields_function_callback(ctx: FunctionContext) -> Type: |
| 1000 | + """Provide the proper return value for `attrs.fields`.""" |
| 1001 | + if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]: |
| 1002 | + first_arg_type = ctx.arg_types[0][0] |
| 1003 | + cls = _get_cls_from_init(first_arg_type) |
| 1004 | + if cls is not None: |
| 1005 | + if MAGIC_ATTR_NAME in cls.names: |
| 1006 | + # This is a proper attrs class. |
| 1007 | + ret_type = cls.names[MAGIC_ATTR_NAME].type |
| 1008 | + return ret_type |
| 1009 | + else: |
| 1010 | + ctx.api.fail( |
| 1011 | + f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class', |
| 1012 | + ctx.context, |
| 1013 | + ) |
| 1014 | + return ctx.default_return_type |
0 commit comments