diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 787d5cb89b0a..fd8101837427 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -95,7 +95,6 @@ def __init__( always_covariant: bool = False, ignore_promotions: bool = False, # Proper subtype flags - erase_instances: bool = False, keep_erased_types: bool = False, options: Options | None = None, ) -> None: @@ -104,7 +103,6 @@ def __init__( self.ignore_declared_variance = ignore_declared_variance self.always_covariant = always_covariant self.ignore_promotions = ignore_promotions - self.erase_instances = erase_instances self.keep_erased_types = keep_erased_types self.options = options @@ -114,7 +112,7 @@ def check_context(self, proper_subtype: bool) -> None: if proper_subtype: assert not self.ignore_pos_arg_names and not self.ignore_declared_variance else: - assert not self.erase_instances and not self.keep_erased_types + assert not self.keep_erased_types def is_subtype( @@ -191,7 +189,6 @@ def is_proper_subtype( *, subtype_context: SubtypeContext | None = None, ignore_promotions: bool = False, - erase_instances: bool = False, keep_erased_types: bool = False, ) -> bool: """Is left a proper subtype of right? @@ -199,19 +196,16 @@ def is_proper_subtype( For proper subtypes, there's no need to rely on compatibility due to Any types. Every usable type is a proper subtype of itself. - If erase_instances is True, erase left instance *after* mapping it to supertype - (this is useful for runtime isinstance() checks). If keep_erased_types is True, - do not consider ErasedType a subtype of all types (used by type inference against unions). + If keep_erased_types is True, do not consider ErasedType a subtype + of all types (used by type inference against unions). """ if subtype_context is None: subtype_context = SubtypeContext( - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types, + ignore_promotions=ignore_promotions, keep_erased_types=keep_erased_types ) else: assert not any( - {ignore_promotions, erase_instances, keep_erased_types} + {ignore_promotions, keep_erased_types} ), "Don't pass both context and individual flags" if type_state.is_assumed_proper_subtype(left, right): return True @@ -403,7 +397,6 @@ def build_subtype_kind(subtype_context: SubtypeContext, proper_subtype: bool) -> subtype_context.ignore_declared_variance, subtype_context.always_covariant, subtype_context.ignore_promotions, - subtype_context.erase_instances, subtype_context.keep_erased_types, ) @@ -527,10 +520,6 @@ def visit_instance(self, left: Instance) -> bool: ) and not self.subtype_context.ignore_declared_variance: # Map left type to corresponding right instances. t = map_instance_to_supertype(left, right.type) - if self.subtype_context.erase_instances: - erased = erase_type(t) - assert isinstance(erased, Instance) - t = erased nominal = True if right.type.has_type_var_tuple_type: # For variadic instances we simply find the correct type argument mappings, @@ -1929,7 +1918,8 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: ideal result (just t is a valid result). This is used for type inference of runtime type checks such as - isinstance(). Currently, this just removes elements of a union type. + isinstance() or TypeIs. Currently, this just removes elements + of a union type. """ p_t = get_proper_type(t) if isinstance(p_t, UnionType): @@ -1938,46 +1928,66 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: new_items = [ restrict_subtype_away(item, s) for item in p_t.relevant_items() - if (isinstance(get_proper_type(item), AnyType) or not covers_at_runtime(item, s)) + if isinstance(get_proper_type(item), UnionType) or not covers_type(item, s) ] return UnionType.make_union(new_items) elif isinstance(p_t, TypeVarType): return p_t.copy_modified(upper_bound=restrict_subtype_away(p_t.upper_bound, s)) - elif covers_at_runtime(t, s): + elif covers_type(t, s): return UninhabitedType() else: return t -def covers_at_runtime(item: Type, supertype: Type) -> bool: - """Will isinstance(item, supertype) always return True at runtime?""" +def covers_type(item: Type, supertype: Type) -> bool: + """Returns if item is covered by supertype. + + Any types (or fallbacks to any) should never cover or be covered. + + Assumes that item is not a Union type. + + Examples: + int covered by int + List[int] covered by List[Any] + A covered by Union[A, Any] + Any NOT covered by int + int NOT covered by Any + """ item = get_proper_type(item) supertype = get_proper_type(supertype) - # Since runtime type checks will ignore type arguments, erase the types. - supertype = erase_type(supertype) - if is_proper_subtype( - erase_type(item), supertype, ignore_promotions=True, erase_instances=True + assert not isinstance(item, UnionType) + + # Handle possible Any types that should not be covered: + if isinstance(item, AnyType) or isinstance(supertype, AnyType): + return False + elif (isinstance(item, Instance) and item.type.fallback_to_any) or ( + isinstance(supertype, Instance) and supertype.type.fallback_to_any ): - return True - if isinstance(supertype, Instance): + return is_same_type(item, supertype) + + if isinstance(supertype, UnionType): + # Special case that cannot be handled by is_subtype, because it would + # not ignore the Any types: + return any(covers_type(item, t) for t in supertype.relevant_items()) + elif isinstance(supertype, Instance): if supertype.type.is_protocol: # TODO: Implement more robust support for runtime isinstance() checks, see issue #3827. - if is_proper_subtype(item, supertype, ignore_promotions=True): + if is_proper_subtype(item, erase_type(supertype), ignore_promotions=True): return True if isinstance(item, TypedDictType): # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). if supertype.type.fullname == "builtins.dict": return True elif isinstance(item, TypeVarType): - if is_proper_subtype(item.upper_bound, supertype, ignore_promotions=True): + if is_proper_subtype(item.upper_bound, erase_type(supertype), ignore_promotions=True): return True elif isinstance(item, Instance) and supertype.type.fullname == "builtins.int": # "int" covers all native int types if item.type.fullname in MYPYC_NATIVE_INT_NAMES: return True - # TODO: Add more special cases. - return False + + return is_subtype(item, supertype, ignore_promotions=True) def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 8fa1bc1ca1ac..3650f2f47f24 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2161,6 +2161,23 @@ else: reveal_type(z) # N: Revealed type is "Any" [builtins fixtures/isinstance.pyi] +[case testIsinstanceSubclassAny] +from typing import Any, Union +X: Any +class BadParent(X): pass +class GoodParent(object): pass +a: Union[GoodParent, BadParent] +if isinstance(a, BadParent): + reveal_type(a) # N: Revealed type is "__main__.BadParent" +else: + reveal_type(a) # N: Revealed type is "Union[__main__.GoodParent, __main__.BadParent]" +b: Union[int, BadParent] +if isinstance(b, (X, GoodParent)): + reveal_type(b) # N: Revealed type is "Union[Any, __main__.BadParent]" +else: + reveal_type(b) # N: Revealed type is "Union[builtins.int, __main__.BadParent]" +[builtins fixtures/isinstance.pyi] + [case testIsInstanceInitialNoneCheckSkipsImpossibleCasesNoStrictOptional] from typing import Optional, Union diff --git a/test-data/unit/check-typeis.test b/test-data/unit/check-typeis.test index 6b96845504ab..1da1088ff128 100644 --- a/test-data/unit/check-typeis.test +++ b/test-data/unit/check-typeis.test @@ -125,6 +125,17 @@ def main(a: object) -> None: reveal_type(a) # N: Revealed type is "Union[builtins.int, builtins.str]" [builtins fixtures/tuple.pyi] +[case testTypeIsUnionWithTypeParams] +from typing_extensions import TypeIs +from typing import Iterable, List, Union +def is_iterable_int(val: object) -> TypeIs[Iterable[int]]: pass +def main(a: Union[List[int], List[str]]) -> None: + if is_iterable_int(a): + reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]" + else: + reveal_type(a) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/tuple.pyi] + [case testTypeIsNonzeroFloat] from typing_extensions import TypeIs def is_nonzero(a: object) -> TypeIs[float]: pass @@ -155,6 +166,30 @@ class C: def is_float(self, a: object) -> TypeIs[float]: pass [builtins fixtures/tuple.pyi] +[case testTypeIsTypeAny] +from typing_extensions import TypeIs +from typing import Any, Type, Union +class A: ... +def is_class(x: object) -> TypeIs[Type[Any]]: ... +def main(a: Union[A, Type[A]]) -> None: + if is_class(a): + reveal_type(a) # N: Revealed type is "Type[Any]" + else: + reveal_type(a) # N: Revealed type is "__main__.A" +[builtins fixtures/tuple.pyi] + +[case testTypeIsAwaitableAny] +from typing_extensions import TypeIs +from typing import Any, Awaitable, TypeVar, Union +T = TypeVar('T') +def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: pass +def main(a: Union[Awaitable[T], T]) -> None: + if is_awaitable(a): + reveal_type(a) # N: Revealed type is "Union[typing.Awaitable[T`-1], typing.Awaitable[Any]]" + else: + reveal_type(a) # N: Revealed type is "T`-1" +[builtins fixtures/tuple.pyi] + [case testTypeIsCrossModule] import guard from points import Point