Skip to content

Commit f9dcb32

Browse files
committed
Add support for try blocks in stubs
On a local test I saw that kiwisolver/_cext.pyi and rapidfuzz/process.pyi contain try blocks. Not standards-compliant but I'd rather have typeshed-client not crash on this.
1 parent 8f622f5 commit f9dcb32

File tree

4 files changed

+27
-0
lines changed

4 files changed

+27
-0
lines changed

tests/test.py

+9
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,15 @@ def test_dot_import(self) -> None:
191191
names["f"].ast, typeshed_client.ImportedName(path, "f")
192192
)
193193

194+
def test_try(self) -> None:
195+
ctx = get_context((3, 10))
196+
names = get_stub_names("tryexcept", search_context=ctx)
197+
assert names is not None
198+
self.assertEqual(set(names), {"np", "f", "x"})
199+
self.check_nameinfo(names, "np", typeshed_client.ImportedName)
200+
self.check_nameinfo(names, "f", ast.FunctionDef)
201+
self.check_nameinfo(names, "x", ast.AnnAssign)
202+
194203
def check_nameinfo(
195204
self,
196205
names: typeshed_client.NameDict,

tests/typeshed/VERSIONS

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ about: 3.5-
1717
importabout: 3.5-
1818
tupleall: 3.5-
1919
starimportall: 3.5-
20+
tryexcept: 3.5-

tests/typeshed/tryexcept.pyi

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
try:
2+
import numpy as np
3+
4+
def f(x: np.int64) -> np.int64: ...
5+
6+
except ImportError:
7+
pass
8+
finally:
9+
x: int

typeshed_client/parser.py

+8
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,14 @@ def visit_If(self, node: ast.If) -> Iterable[NameInfo]:
252252
for stmt in node.orelse:
253253
yield from self.visit(stmt)
254254

255+
def visit_Try(self, node: ast.Try) -> Iterable[NameInfo]:
256+
# try-except sometimes gets used with conditional imports. We assume
257+
# the try block is always executed.
258+
for stmt in node.body:
259+
yield from self.visit(stmt)
260+
for stmt in node.finalbody:
261+
yield from self.visit(stmt)
262+
255263
def visit_Assert(self, node: ast.Assert) -> Iterable[NameInfo]:
256264
visitor = _LiteralEvalVisitor(self.ctx)
257265
value = visitor.visit(node.test)

0 commit comments

Comments
 (0)