Skip to content

Commit 6dcf6a7

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent fad8ffb commit 6dcf6a7

File tree

2 files changed

+151
-30
lines changed

2 files changed

+151
-30
lines changed

Diff for: pyflakes/checker.py

+85-30
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,11 @@ class Binding:
226226
the node that this binding was last used.
227227
"""
228228

229-
def __init__(self, name, source):
229+
def __init__(self, name, source, runtime=True):
230230
self.name = name
231231
self.source = source
232232
self.used = False
233+
self.runtime = runtime
233234

234235
def __str__(self):
235236
return self.name
@@ -260,8 +261,8 @@ def redefines(self, other):
260261
class Builtin(Definition):
261262
"""A definition created for all Python builtins."""
262263

263-
def __init__(self, name):
264-
super().__init__(name, None)
264+
def __init__(self, name, runtime=True):
265+
super().__init__(name, None, runtime=runtime)
265266

266267
def __repr__(self):
267268
return '<{} object {!r} at 0x{:x}>'.format(
@@ -305,10 +306,10 @@ class Importation(Definition):
305306
@type fullName: C{str}
306307
"""
307308

308-
def __init__(self, name, source, full_name=None):
309+
def __init__(self, name, source, full_name=None, runtime=True):
309310
self.fullName = full_name or name
310311
self.redefined = []
311-
super().__init__(name, source)
312+
super().__init__(name, source, runtime=runtime)
312313

313314
def redefines(self, other):
314315
if isinstance(other, SubmoduleImportation):
@@ -353,11 +354,11 @@ class SubmoduleImportation(Importation):
353354
name is also the same, to avoid false positives.
354355
"""
355356

356-
def __init__(self, name, source):
357+
def __init__(self, name, source, runtime=True):
357358
# A dot should only appear in the name when it is a submodule import
358359
assert '.' in name and (not source or isinstance(source, ast.Import))
359360
package_name = name.split('.')[0]
360-
super().__init__(package_name, source)
361+
super().__init__(package_name, source, runtime=runtime)
361362
self.fullName = name
362363

363364
def redefines(self, other):
@@ -375,7 +376,8 @@ def source_statement(self):
375376

376377
class ImportationFrom(Importation):
377378

378-
def __init__(self, name, source, module, real_name=None):
379+
def __init__(
380+
self, name, source, module, real_name=None, runtime=True):
379381
self.module = module
380382
self.real_name = real_name or name
381383

@@ -384,7 +386,7 @@ def __init__(self, name, source, module, real_name=None):
384386
else:
385387
full_name = module + '.' + self.real_name
386388

387-
super().__init__(name, source, full_name)
389+
super().__init__(name, source, full_name, runtime=runtime)
388390

389391
def __str__(self):
390392
"""Return import full name with alias."""
@@ -404,8 +406,8 @@ def source_statement(self):
404406
class StarImportation(Importation):
405407
"""A binding created by a 'from x import *' statement."""
406408

407-
def __init__(self, name, source):
408-
super().__init__('*', source)
409+
def __init__(self, name, source, runtime=True):
410+
super().__init__('*', source, runtime=runtime)
409411
# Each star importation needs a unique name, and
410412
# may not be the module name otherwise it will be deemed imported
411413
self.name = name + '.*'
@@ -494,7 +496,7 @@ class ExportBinding(Binding):
494496
C{__all__} will not have an unused import warning reported for them.
495497
"""
496498

497-
def __init__(self, name, source, scope):
499+
def __init__(self, name, source, scope, runtime=True):
498500
if '__all__' in scope and isinstance(source, ast.AugAssign):
499501
self.names = list(scope['__all__'].names)
500502
else:
@@ -525,7 +527,7 @@ def _add_to_names(container):
525527
# If not list concatenation
526528
else:
527529
break
528-
super().__init__(name, source)
530+
super().__init__(name, source, runtime=runtime)
529531

530532

531533
class Scope(dict):
@@ -732,6 +734,7 @@ class Checker:
732734
nodeDepth = 0
733735
offset = None
734736
_in_annotation = AnnotationState.NONE
737+
_in_type_check_guard = False
735738

736739
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
737740
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
@@ -1000,9 +1003,11 @@ def addBinding(self, node, value):
10001003
# then assume the rebound name is used as a global or within a loop
10011004
value.used = self.scope[value.name].used
10021005

1003-
# don't treat annotations as assignments if there is an existing value
1004-
# in scope
1005-
if value.name not in self.scope or not isinstance(value, Annotation):
1006+
# always allow the first assignment or if not already a runtime value,
1007+
# but do not shadow an existing assignment with an annotation or non
1008+
# runtime value.
1009+
if (not existing or not existing.runtime or (
1010+
not isinstance(value, Annotation) and value.runtime)):
10061011
cur_scope_pos = -1
10071012
# As per PEP 572, use scope in which outermost generator is defined
10081013
while (
@@ -1073,12 +1078,18 @@ def handleNodeLoad(self, node, parent):
10731078
self.report(messages.InvalidPrintSyntax, node)
10741079

10751080
try:
1076-
scope[name].used = (self.scope, node)
1081+
n = scope[name]
1082+
if (not n.runtime and not (
1083+
self._in_type_check_guard
1084+
or self._in_annotation)):
1085+
self.report(messages.UndefinedName, node, name)
1086+
return
1087+
1088+
n.used = (self.scope, node)
10771089

10781090
# if the name of SubImportation is same as
10791091
# alias of other Importation and the alias
10801092
# is used, SubImportation also should be marked as used.
1081-
n = scope[name]
10821093
if isinstance(n, Importation) and n._has_alias():
10831094
try:
10841095
scope[n.fullName].used = (self.scope, node)
@@ -1143,12 +1154,13 @@ def handleNodeStore(self, node):
11431154
break
11441155

11451156
parent_stmt = self.getParent(node)
1157+
runtime = not self._in_type_check_guard
11461158
if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None:
11471159
binding = Annotation(name, node)
11481160
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
11491161
parent_stmt != node._pyflakes_parent and
11501162
not self.isLiteralTupleUnpacking(parent_stmt)):
1151-
binding = Binding(name, node)
1163+
binding = Binding(name, node, runtime=runtime)
11521164
elif (
11531165
name == '__all__' and
11541166
isinstance(self.scope, ModuleScope) and
@@ -1157,11 +1169,12 @@ def handleNodeStore(self, node):
11571169
(ast.Assign, ast.AugAssign, ast.AnnAssign)
11581170
)
11591171
):
1160-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1172+
binding = ExportBinding(
1173+
name, node._pyflakes_parent, self.scope, runtime=runtime)
11611174
elif isinstance(parent_stmt, ast.NamedExpr):
1162-
binding = NamedExprAssignment(name, node)
1175+
binding = NamedExprAssignment(name, node, runtime=runtime)
11631176
else:
1164-
binding = Assignment(name, node)
1177+
binding = Assignment(name, node, runtime=runtime)
11651178
self.addBinding(node, binding)
11661179

11671180
def handleNodeDelete(self, node):
@@ -1805,7 +1818,39 @@ def DICT(self, node):
18051818
def IF(self, node):
18061819
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
18071820
self.report(messages.IfTuple, node)
1808-
self.handleChildren(node)
1821+
1822+
self.handleNode(node.test, node)
1823+
1824+
# check if the body/orelse should be handled specially because it is
1825+
# a if TYPE_CHECKING guard.
1826+
test = node.test
1827+
reverse = False
1828+
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
1829+
test = test.operand
1830+
reverse = True
1831+
1832+
type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
1833+
orig = self._in_type_check_guard
1834+
1835+
# normalize body and orelse to a list
1836+
body, orelse = (
1837+
i if isinstance(i, list) else [i]
1838+
for i in (node.body, node.orelse))
1839+
1840+
# set the guard and handle the body
1841+
if type_checking and not reverse:
1842+
self._in_type_check_guard = True
1843+
1844+
for n in body:
1845+
self.handleNode(n, node)
1846+
1847+
# set the guard and handle the orelse
1848+
if type_checking:
1849+
self._in_type_check_guard = True if reverse else orig
1850+
1851+
for n in orelse:
1852+
self.handleNode(n, node)
1853+
self._in_type_check_guard = orig
18091854

18101855
IFEXP = IF
18111856

@@ -1920,7 +1965,10 @@ def FUNCTIONDEF(self, node):
19201965
with self._type_param_scope(node):
19211966
self.LAMBDA(node)
19221967

1923-
self.addBinding(node, FunctionDefinition(node.name, node))
1968+
self.addBinding(
1969+
node,
1970+
FunctionDefinition(
1971+
node.name, node, runtime=not self._in_type_check_guard))
19241972
# doctest does not process doctest within a doctest,
19251973
# or in nested functions.
19261974
if (self.withDoctest and
@@ -2005,7 +2053,10 @@ def CLASSDEF(self, node):
20052053
for stmt in node.body:
20062054
self.handleNode(stmt, node)
20072055

2008-
self.addBinding(node, ClassDefinition(node.name, node))
2056+
self.addBinding(
2057+
node,
2058+
ClassDefinition(
2059+
node.name, node, runtime=not self._in_type_check_guard))
20092060

20102061
def AUGASSIGN(self, node):
20112062
self.handleNodeLoad(node.target, node)
@@ -2038,12 +2089,15 @@ def TUPLE(self, node):
20382089
LIST = TUPLE
20392090

20402091
def IMPORT(self, node):
2092+
runtime = not self._in_type_check_guard
20412093
for alias in node.names:
20422094
if '.' in alias.name and not alias.asname:
2043-
importation = SubmoduleImportation(alias.name, node)
2095+
importation = SubmoduleImportation(
2096+
alias.name, node, runtime=runtime)
20442097
else:
20452098
name = alias.asname or alias.name
2046-
importation = Importation(name, node, alias.name)
2099+
importation = Importation(
2100+
name, node, alias.name, runtime=runtime)
20472101
self.addBinding(node, importation)
20482102

20492103
def IMPORTFROM(self, node):
@@ -2055,6 +2109,7 @@ def IMPORTFROM(self, node):
20552109

20562110
module = ('.' * node.level) + (node.module or '')
20572111

2112+
runtime = not self._in_type_check_guard
20582113
for alias in node.names:
20592114
name = alias.asname or alias.name
20602115
if node.module == '__future__':
@@ -2072,10 +2127,10 @@ def IMPORTFROM(self, node):
20722127

20732128
self.scope.importStarred = True
20742129
self.report(messages.ImportStarUsed, node, module)
2075-
importation = StarImportation(module, node)
2130+
importation = StarImportation(module, node, runtime=runtime)
20762131
else:
2077-
importation = ImportationFrom(name, node,
2078-
module, alias.name)
2132+
importation = ImportationFrom(
2133+
name, node, module, alias.name, runtime=runtime)
20792134
self.addBinding(node, importation)
20802135

20812136
def TRY(self, node):

Diff for: pyflakes/test/test_type_annotations.py

+66
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,55 @@ def f() -> T:
645645
pass
646646
""")
647647

648+
def test_typing_guard_import(self):
649+
# T is imported for runtime use
650+
self.flakes("""
651+
from typing import TYPE_CHECKING
652+
653+
if TYPE_CHECKING:
654+
from t import T
655+
656+
def f(x) -> T:
657+
from t import T
658+
659+
assert isinstance(x, T)
660+
return x
661+
""")
662+
# T is defined at runtime in one side of the if/else block
663+
self.flakes("""
664+
from typing import TYPE_CHECKING, Union
665+
666+
if TYPE_CHECKING:
667+
from t import T
668+
else:
669+
T = object
670+
671+
if not TYPE_CHECKING:
672+
U = object
673+
else:
674+
from t import U
675+
676+
def f(x) -> Union[T, U]:
677+
assert isinstance(x, (T, U))
678+
return x
679+
""")
680+
681+
def test_typing_guard_import_runtime_error(self):
682+
# T and U are not bound for runtime use
683+
self.flakes("""
684+
from typing import TYPE_CHECKING, Union
685+
686+
if TYPE_CHECKING:
687+
from t import T
688+
689+
class U:
690+
pass
691+
692+
def f(x) -> Union[T, U]:
693+
assert isinstance(x, (T, U))
694+
return x
695+
""", m.UndefinedName, m.UndefinedName)
696+
648697
def test_typing_guard_for_protocol(self):
649698
self.flakes("""
650699
from typing import TYPE_CHECKING
@@ -659,6 +708,23 @@ def f() -> int:
659708
pass
660709
""")
661710

711+
def test_typing_guard_with_elif_branch(self):
712+
# This test will not raise an error even though Protocol is not
713+
# defined outside TYPE_CHECKING because Pyflakes does not do case
714+
# analysis.
715+
self.flakes("""
716+
from typing import TYPE_CHECKING
717+
if TYPE_CHECKING:
718+
from typing import Protocol
719+
elif False:
720+
Protocol = object
721+
else:
722+
pass
723+
class C(Protocol):
724+
def f(): # type: () -> int
725+
pass
726+
""")
727+
662728
def test_typednames_correct_forward_ref(self):
663729
self.flakes("""
664730
from typing import TypedDict, List, NamedTuple

0 commit comments

Comments
 (0)