diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 8b7c5df6f51f..269843ac69f8 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -27,6 +27,7 @@ LiteralType, ProperType, Type, + UnionType, get_proper_type, is_named_instance, ) @@ -54,14 +55,26 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: This plugin assumes that the provided context is an attribute access matching one of the strings found in 'ENUM_NAME_ACCESS'. """ + # This might be `SomeEnum.Field.name` case: enum_field_name = _extract_underlying_field_name(ctx.type) - if enum_field_name is None: - return ctx.default_attr_type - else: + if enum_field_name is not None: str_type = ctx.api.named_generic_type("builtins.str", []) literal_type = LiteralType(enum_field_name, fallback=str_type) return str_type.copy_modified(last_known_value=literal_type) + # Or `field: SomeEnum = SomeEnum.field; field.name` case, + # Or `field: Literal[Some.A, Some.B]; field.name` case: + enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union( + ctx.type + ) + if enum_names: + str_type = ctx.api.named_generic_type("builtins.str", []) + return make_simplified_union( + [LiteralType(enum_name, fallback=str_type) for enum_name in enum_names] + ) + + return ctx.default_attr_type + _T = TypeVar("_T") @@ -285,3 +298,27 @@ def _extract_underlying_field_name(typ: Type) -> str | None: # as a string. assert isinstance(underlying_literal.value, str) return underlying_literal.value + + +def _extract_enum_names_from_type(typ: ProperType) -> list[str] | None: + if not isinstance(typ, Instance) or not typ.type.is_enum: + return None + return typ.type.enum_members + + +def _extract_enum_names_from_literal_union(typ: ProperType) -> list[str] | None: + if not isinstance(typ, UnionType): + return None + + names = [] + for item in typ.relevant_items(): + pitem = get_proper_type(item) + if isinstance(pitem, Instance) and pitem.last_known_value and pitem.type.is_enum: + assert isinstance(pitem.last_known_value.value, str) + names.append(pitem.last_known_value.value) + elif isinstance(pitem, LiteralType): + assert isinstance(pitem.value, str) + names.append(pitem.value) + else: + return None + return names diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index a3abf53e29ac..dc0dbf4cb536 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -94,6 +94,50 @@ reveal_type(Truth.true.name) # N: Revealed type is "Literal['true']?" reveal_type(Truth.false.value) # N: Revealed type is "Literal[False]?" [builtins fixtures/bool.pyi] +[case testEnumNameValueOnType] +from enum import Enum, Flag, IntEnum, member, nonmember +class Colors(Enum): + red = 1 + blue = 2 + green = 3 + white = nonmember(0) +x: Colors = Colors.red +reveal_type(x.name) # N: Revealed type is "Union[Literal['red'], Literal['blue'], Literal['green']]" + +class Numbers(IntEnum): + one = 1 + two = 2 + three = 3 + _magic_ = member(4) +y: Numbers = Numbers.one +reveal_type(y.name) # N: Revealed type is "Union[Literal['one'], Literal['two'], Literal['three'], Literal['_magic_']]" + +class Flags(Flag): + en = "en" + us = "us" + @member + def all(self) -> str: + return "all" +z: Flags = Flags.en +reveal_type(z.name) # N: Revealed type is "Union[Literal['en'], Literal['us'], Literal['all']]" + +class Empty(Enum): ... +e: Empty +reveal_type(e.name) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + +[case testEnumNameValueOnUnionOfLiteral] +from enum import Enum +from typing import Literal +class Colors(Enum): + red = 1 + blue = 2 + green = 3 + +color: Literal[Colors.red, Colors.blue] +reveal_type(color.name) # N: Revealed type is "Union[Literal['red'], Literal['blue']]" +[builtins fixtures/tuple.pyi] + [case testEnumValueExtended] from enum import Enum class Truth(Enum): diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 0e0e2b1f344d..4648afe46521 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1710,7 +1710,7 @@ reveal_type(E.X.name) reveal_type(e.foo) reveal_type(E.Y.foo) [out] -_testEnumNameWorkCorrectlyOn311.py:11: note: Revealed type is "builtins.str" +_testEnumNameWorkCorrectlyOn311.py:11: note: Revealed type is "Union[Literal['X'], Literal['Y']]" _testEnumNameWorkCorrectlyOn311.py:12: note: Revealed type is "Union[Literal[1]?, Literal[2]?]" _testEnumNameWorkCorrectlyOn311.py:13: note: Revealed type is "Literal['X']?" _testEnumNameWorkCorrectlyOn311.py:14: note: Revealed type is "builtins.int"