From 0320b79d43e777fbe3d978aea7233ff4264e4953 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Thu, 13 Mar 2025 12:45:24 +0300 Subject: [PATCH 1/5] Correctly infer `x: SomeEnum; x.name` as union of literal names --- mypy/plugins/enums.py | 18 +++++++++++++++--- test-data/unit/check-enum.test | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 8b7c5df6f51f..931b6aa5fcf8 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -54,14 +54,26 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: This plugin assumes that the provided context is an attribute access matching one of the strings found in 'ENUM_NAME_ACCESS'. """ + # This might be `SomeEnum.Field.name` case: enum_field_name = _extract_underlying_field_name(ctx.type) - if enum_field_name is None: - return ctx.default_attr_type - else: + if enum_field_name is not None: str_type = ctx.api.named_generic_type("builtins.str", []) literal_type = LiteralType(enum_field_name, fallback=str_type) return str_type.copy_modified(last_known_value=literal_type) + # Or `field: SomeEnum = SomeEnum.field; field.name` case: + if not isinstance(ctx.type, Instance) or not ctx.type.type.is_enum: + return ctx.default_attr_type + enum_names = ctx.type.type.enum_members + if enum_names: + str_type = ctx.api.named_generic_type("builtins.str", []) + return make_simplified_union([ + LiteralType(enum_name, fallback=str_type) + for enum_name in enum_names + ]) + + return ctx.default_attr_type + _T = TypeVar("_T") diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index a3abf53e29ac..6a0599cb2ee2 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -94,6 +94,38 @@ reveal_type(Truth.true.name) # N: Revealed type is "Literal['true']?" reveal_type(Truth.false.value) # N: Revealed type is "Literal[False]?" [builtins fixtures/bool.pyi] +[case testEnumNameValueOnType] +from enum import Enum, Flag, IntEnum, member, nonmember +class Colors(Enum): + red = 1 + blue = 2 + green = 3 + white = nonmember(0) +x: Colors = Colors.red +reveal_type(x.name) # N: Revealed type is "Union[Literal['red'], Literal['blue'], Literal['green']]" + +class Numbers(IntEnum): + one = 1 + two = 2 + three = 3 + _magic_ = member(4) +y: Numbers = Numbers.one +reveal_type(y.name) # N: Revealed type is "Union[Literal['one'], Literal['two'], Literal['three'], Literal['_magic_']]" + +class Flags(Flag): + en = "en" + us = "us" + @member + def all(self) -> str: + return "all" +z: Flags = Flags.en +reveal_type(z.name) # N: Revealed type is "Union[Literal['en'], Literal['us'], Literal['all']]" + +class Empty(Enum): ... +e: Empty +reveal_type(e.name) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + [case testEnumValueExtended] from enum import Enum class Truth(Enum): From 5ace7b9eed2cef147189c3e581614c1e29fda6df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Mar 2025 09:48:41 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/plugins/enums.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 931b6aa5fcf8..b8a6c107e486 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -67,10 +67,9 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: enum_names = ctx.type.type.enum_members if enum_names: str_type = ctx.api.named_generic_type("builtins.str", []) - return make_simplified_union([ - LiteralType(enum_name, fallback=str_type) - for enum_name in enum_names - ]) + return make_simplified_union( + [LiteralType(enum_name, fallback=str_type) for enum_name in enum_names] + ) return ctx.default_attr_type From efc6a0fa3f8fede35442f232337628a571eb5d9c Mon Sep 17 00:00:00 2001 From: sobolevn Date: Thu, 13 Mar 2025 13:42:05 +0300 Subject: [PATCH 3/5] Fix tests --- test-data/unit/pythoneval.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 0e0e2b1f344d..4648afe46521 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1710,7 +1710,7 @@ reveal_type(E.X.name) reveal_type(e.foo) reveal_type(E.Y.foo) [out] -_testEnumNameWorkCorrectlyOn311.py:11: note: Revealed type is "builtins.str" +_testEnumNameWorkCorrectlyOn311.py:11: note: Revealed type is "Union[Literal['X'], Literal['Y']]" _testEnumNameWorkCorrectlyOn311.py:12: note: Revealed type is "Union[Literal[1]?, Literal[2]?]" _testEnumNameWorkCorrectlyOn311.py:13: note: Revealed type is "Literal['X']?" _testEnumNameWorkCorrectlyOn311.py:14: note: Revealed type is "builtins.int" From 97bdaffe054e1acdec09b4a9f04fef67399c4819 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Fri, 14 Mar 2025 13:56:57 +0300 Subject: [PATCH 4/5] Address review --- mypy/plugins/enums.py | 33 +++++++++++++++++++++++++++++---- test-data/unit/check-enum.test | 12 ++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index b8a6c107e486..de58e5de63b7 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -29,6 +29,8 @@ Type, get_proper_type, is_named_instance, + UnionType, + LiteralType, ) ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | { @@ -61,10 +63,9 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: literal_type = LiteralType(enum_field_name, fallback=str_type) return str_type.copy_modified(last_known_value=literal_type) - # Or `field: SomeEnum = SomeEnum.field; field.name` case: - if not isinstance(ctx.type, Instance) or not ctx.type.type.is_enum: - return ctx.default_attr_type - enum_names = ctx.type.type.enum_members + # Or `field: SomeEnum = SomeEnum.field; field.name` case, + # Or `field: Literal[Some.A, Some.B]; field.name` case: + enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union(ctx.type) if enum_names: str_type = ctx.api.named_generic_type("builtins.str", []) return make_simplified_union( @@ -296,3 +297,27 @@ def _extract_underlying_field_name(typ: Type) -> str | None: # as a string. assert isinstance(underlying_literal.value, str) return underlying_literal.value + + +def _extract_enum_names_from_type(typ: ProperType) -> list[str] | None: + if not isinstance(typ, Instance) or not typ.type.is_enum: + return None + return typ.type.enum_members + + +def _extract_enum_names_from_literal_union(typ: ProperType) -> list[str] | None: + if not isinstance(typ, UnionType): + return None + + names = [] + for item in typ.relevant_items(): + pitem = get_proper_type(item) + if isinstance(pitem, Instance) and pitem.last_known_value and pitem.type.is_enum: + assert isinstance(pitem.last_known_value.value, str) + names.append(pitem.last_known_value.value) + elif isinstance(pitem, LiteralType): + assert isinstance(pitem.value, str) + names.append(pitem.value) + else: + return None + return names diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 6a0599cb2ee2..dc0dbf4cb536 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -126,6 +126,18 @@ e: Empty reveal_type(e.name) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] +[case testEnumNameValueOnUnionOfLiteral] +from enum import Enum +from typing import Literal +class Colors(Enum): + red = 1 + blue = 2 + green = 3 + +color: Literal[Colors.red, Colors.blue] +reveal_type(color.name) # N: Revealed type is "Union[Literal['red'], Literal['blue']]" +[builtins fixtures/tuple.pyi] + [case testEnumValueExtended] from enum import Enum class Truth(Enum): From 920e8a52ba42229a8185c650491856de04948d32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:58:57 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/plugins/enums.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index de58e5de63b7..269843ac69f8 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -27,10 +27,9 @@ LiteralType, ProperType, Type, + UnionType, get_proper_type, is_named_instance, - UnionType, - LiteralType, ) ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | { @@ -65,7 +64,9 @@ def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type: # Or `field: SomeEnum = SomeEnum.field; field.name` case, # Or `field: Literal[Some.A, Some.B]; field.name` case: - enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union(ctx.type) + enum_names = _extract_enum_names_from_type(ctx.type) or _extract_enum_names_from_literal_union( + ctx.type + ) if enum_names: str_type = ctx.api.named_generic_type("builtins.str", []) return make_simplified_union(