Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[another approach] Prevent crash with Unpack of a fixed tuple in PEP695 type alias #18452

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 62 additions & 56 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4041,44 +4041,11 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
eager=eager,
python_3_12_type_alias=pep_695,
)
if isinstance(s.rvalue, (IndexExpr, CallExpr, OpExpr)) and (
not isinstance(rvalue, OpExpr)
or (self.options.python_version >= (3, 10) or self.is_stub_file)
):
# Note: CallExpr is for "void = type(None)" and OpExpr is for "X | Y" union syntax.
if not isinstance(s.rvalue.analyzed, TypeAliasExpr):
# Any existing node will be updated in-place below.
s.rvalue.analyzed = TypeAliasExpr(alias_node)
s.rvalue.analyzed.line = s.line
# we use the column from resulting target, to get better location for errors
s.rvalue.analyzed.column = res.column
elif isinstance(s.rvalue, RefExpr):
s.rvalue.is_alias_rvalue = True
self._link_type_alias_to_rvalue(s.rvalue, alias_node, s.line)

if existing:
# An alias gets updated.
updated = False
if isinstance(existing.node, TypeAlias):
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
existing.node.target = res
existing.node.alias_tvars = alias_tvars
existing.node.no_args = no_args
updated = True
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = alias_node
updated = True
if updated:
if self.final_iteration:
self.cannot_resolve_name(lvalue.name, "name", s)
return True
else:
# We need to defer so that this change can get propagated to base classes.
self.defer(s, force_progress=True)
if not self._update_type_alias(existing, alias_node, lvalue.name, s):
return True
else:
self.add_symbol(lvalue.name, alias_node, s)
if isinstance(rvalue, RefExpr) and isinstance(rvalue.node, TypeAlias):
Expand All @@ -4094,6 +4061,58 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
self.note("Use variable annotation syntax to define protocol members", s)
return True

def _link_type_alias_to_rvalue(
self, rvalue: Expression, alias_node: TypeAlias, line: int
) -> None:
if isinstance(rvalue, (IndexExpr, CallExpr)) or (
isinstance(rvalue, OpExpr)
and (self.options.python_version >= (3, 10) or self.is_stub_file)
):
# Note: CallExpr is for "void = type(None)" and OpExpr is for "X | Y" union syntax.
if not isinstance(rvalue.analyzed, TypeAliasExpr):
# Any existing node will be updated in-place below.
rvalue.analyzed = TypeAliasExpr(alias_node)
rvalue.analyzed.line = line
# we use the column from resulting target, to get better location for errors
rvalue.analyzed.column = alias_node.target.column
elif isinstance(rvalue, RefExpr):
rvalue.is_alias_rvalue = True

def _update_type_alias(
self,
existing: SymbolTableNode,
new: TypeAlias,
name: str,
stmt: AssignmentStmt | TypeAliasStmt,
) -> bool:
"""Store updated type alias information.

Returns `False` to indicate early exit (attempt to defer during final iteration).
"""
updated = False
if isinstance(existing.node, TypeAlias):
if existing.node.target != new.target:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
existing.node.target = new.target
existing.node.alias_tvars = new.alias_tvars
existing.node.no_args = new.no_args
updated = True
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = new
updated = True
if updated:
if self.final_iteration:
self.cannot_resolve_name(name, "name", stmt)
return False
else:
# We need to defer so that this change can get propagated to base classes.
self.defer(stmt, force_progress=True)
return True

def check_type_alias_type_call(self, rvalue: Expression, *, name: str) -> TypeGuard[CallExpr]:
if not isinstance(rvalue, CallExpr):
return False
Expand Down Expand Up @@ -5548,31 +5567,18 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
)
s.alias_node = alias_node

alias_ret = s.value.body.body[0]
assert isinstance(alias_ret, ReturnStmt)
assert alias_ret.expr is not None
self._link_type_alias_to_rvalue(alias_ret.expr, alias_node, s.line)

if (
existing
and isinstance(existing.node, (PlaceholderNode, TypeAlias))
and existing.node.line == s.line
):
updated = False
if isinstance(existing.node, TypeAlias):
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
existing.node.target = res
existing.node.alias_tvars = alias_tvars
updated = True
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = alias_node
updated = True

if updated:
if self.final_iteration:
self.cannot_resolve_name(s.name.name, "name", s)
return
else:
# We need to defer so that this change can get propagated to base classes.
self.defer(s, force_progress=True)
if not self._update_type_alias(existing, alias_node, s.name.name, s):
return
else:
self.add_symbol(s.name.name, alias_node, s)

Expand Down
45 changes: 44 additions & 1 deletion test-data/unit/check-python312.test
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type MyInt2 = int

def h(x: MyInt2) -> MyInt2:
return reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695Class]
class MyGen[T]:
Expand Down Expand Up @@ -49,6 +51,7 @@ def func1[T: int](x: T) -> T: ...
def func2[**P](x: Callable[P, int]) -> Callable[P, str]: ...
def func3[*Ts](x: tuple[*Ts]) -> tuple[int, *Ts]: ...
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasType]
from typing import Callable, TypeAliasType, TypeVar, TypeVarTuple
Expand Down Expand Up @@ -603,6 +606,7 @@ type A4 = int | str
a4: A4
reveal_type(a4) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/type.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasNotValidAsBaseClass]
from typing import TypeAlias
Expand Down Expand Up @@ -635,7 +639,8 @@ class Good5(B3): pass
[file m.py]
type A1 = str
type A2[T] = list[T]
[typing fixtures/typing-medium.pyi]
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasWithUnusedTypeParams]
type A[T] = int
Expand All @@ -649,6 +654,8 @@ a: A[int]
reveal_type(a) # N: Revealed type is "__main__.C[builtins.int]"

class C[T]: pass
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasForwardReference2]
type X = C
Expand All @@ -670,6 +677,8 @@ reveal_type(a) # N: Revealed type is "__main__.C[__main__.D]"

class C[T]: pass
class D: pass
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasForwardReference4]
type A = C
Expand All @@ -692,6 +701,8 @@ c: C[str]
reveal_type(a) # N: Revealed type is "builtins.str"
reveal_type(b) # N: Revealed type is "__main__.C[builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.C[builtins.str]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasWithUndefineName]
type A[T] = XXX # E: Name "XXX" is not defined
Expand All @@ -707,10 +718,13 @@ type B = int + str # E: Invalid type alias: expression is not a valid type
b: B
reveal_type(b) # N: Revealed type is "Any"
[builtins fixtures/type.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeAliasBoundForwardReference]
type B[T: Foo] = list[T]
class Foo: pass
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695UpperBound]
class D:
Expand Down Expand Up @@ -777,6 +791,8 @@ reveal_type(a) # N: Revealed type is "__main__.C[__main__.D[__main__.X]]"
reveal_type(b) # N: Revealed type is "__main__.C[__main__.E[__main__.X]]"

c: C[D[int]] # E: Type argument "D[int]" of "C" must be a subtype of "D[X]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695UpperBoundForwardReference4]
def f[T: D](a: T) -> T:
Expand Down Expand Up @@ -935,6 +951,7 @@ type C[*Ts] = tuple[*Ts, int]
a: C[str, None]
reveal_type(a) # N: Revealed type is "Tuple[builtins.str, None, builtins.int]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695IncrementalFunction]
import a
Expand Down Expand Up @@ -1036,6 +1053,7 @@ class Foo[T]: pass
type B[T] = Foo[T]

[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]
[out2]
tmp/a.py:3: note: Revealed type is "builtins.str"
tmp/a.py:5: note: Revealed type is "b.Foo[builtins.int]"
Expand Down Expand Up @@ -1249,6 +1267,7 @@ type B[T] = C[T] | list[B[T]]
b: B[int]
reveal_type(b) # N: Revealed type is "Union[__main__.C[builtins.int], builtins.list[...]]"
[builtins fixtures/type.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695BadRecursiveTypeAlias]
type A = A # E: Cannot resolve name "A" (possible cyclic definition)
Expand Down Expand Up @@ -1277,6 +1296,7 @@ f(C[C[str]]())
f(1) # E: Argument 1 to "f" has incompatible type "int"; expected "A"
f(C[int]()) # E: Argument 1 to "f" has incompatible type "C[int]"; expected "A"
[builtins fixtures/isinstance.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695InvalidGenericOrProtocolBaseClass]
from typing import Generic, Protocol, TypeVar
Expand Down Expand Up @@ -1579,6 +1599,8 @@ else:
x: T # E: Name "T" is not defined
a: A[int]
reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695UndefinedNameInAnnotation]
def f[T](x: foobar, y: T) -> T: ... # E: Name "foobar" is not defined
Expand All @@ -1592,6 +1614,8 @@ reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]"
type B[T: (int,)] = list[T] # E: Type variable must have at least two constrained types
b: B[str]
reveal_type(b) # N: Revealed type is "builtins.list[builtins.str]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695UsingTypeVariableInOwnBoundOrConstraint]
type A[T: list[T]] = str # E: Name "T" is not defined
Expand All @@ -1611,6 +1635,8 @@ a: A
reveal_type(a) # N: Revealed type is "builtins.list[Any]"
b: B
reveal_type(b) # N: Revealed type is "Any"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695GenericNamedTuple]
from typing import NamedTuple
Expand Down Expand Up @@ -1795,6 +1821,7 @@ reveal_type(y) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(z) # N: Revealed type is "builtins.int"
reveal_type(zz) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695NestedGenericClass1]
class C[T]:
Expand Down Expand Up @@ -1972,3 +1999,19 @@ class D:
class G[Q]:
def g(self, x: Q): ...
d: G[str]

[case testTypeAliasNormalization]
from collections.abc import Callable
from typing import Unpack
from typing_extensions import TypeAlias

type RK_function_args = tuple[float, int]
type RK_functionBIS = Callable[[Unpack[RK_function_args], int], int]

def ff(a: float, b: int, c: int) -> int:
return 2

bis: RK_functionBIS = ff
res: int = bis(1.0, 2, 3)
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]
2 changes: 2 additions & 0 deletions test-data/unit/check-python313.test
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ class C[T = None]: ...
def f[T = list[int]]() -> None: ...
def g[**P = [int, str]]() -> None: ...
type A[T, S = int, U = str] = list[T]
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695TypeParameterDefaultBasic]
from typing import Callable
Expand Down
2 changes: 2 additions & 0 deletions test-data/unit/diff.test
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,8 @@ __main__.A
__main__.B
__main__.C
__main__.D
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695GenericFunction]
# flags: --python-version=3.12
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Final = 0
TypedDict = 0
NoReturn = 0
NewType = 0
TypeAlias = 0
Self = 0
Unpack = 0
Callable: _SpecialForm
Expand Down
Loading