Skip to content

Commit 776452f

Browse files
authored
Add codemod to fix variadic callable annotations (#1269)
* add fix variadic callable codemod * format
1 parent d269872 commit 776452f

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# pyre-strict
7+
8+
import libcst as cst
9+
import libcst.matchers as m
10+
from libcst.codemod import VisitorBasedCodemodCommand
11+
from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource
12+
13+
14+
class FixVariadicCallableCommmand(VisitorBasedCodemodCommand):
15+
DESCRIPTION: str = (
16+
"Fix incorrect variadic callable type annotations from `Callable[[...], T]` to `Callable[..., T]``"
17+
)
18+
19+
METADATA_DEPENDENCIES = (QualifiedNameProvider,)
20+
21+
def leave_Subscript(
22+
self, original_node: cst.Subscript, updated_node: cst.Subscript
23+
) -> cst.BaseExpression:
24+
if QualifiedNameProvider.has_name(
25+
self,
26+
original_node,
27+
QualifiedName(name="typing.Callable", source=QualifiedNameSource.IMPORT),
28+
):
29+
node_matches = len(updated_node.slice) == 2 and m.matches(
30+
updated_node.slice[0],
31+
m.SubscriptElement(
32+
slice=m.Index(value=m.List(elements=[m.Element(m.Ellipsis())]))
33+
),
34+
)
35+
36+
if node_matches:
37+
slices = list(updated_node.slice)
38+
slices[0] = cst.SubscriptElement(cst.Index(cst.Ellipsis()))
39+
return updated_node.with_changes(slice=slices)
40+
return updated_node
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# pyre-strict
7+
8+
from libcst.codemod import CodemodTest
9+
from libcst.codemod.commands.fix_variadic_callable import FixVariadicCallableCommmand
10+
11+
12+
class TestFixVariadicCallableCommmand(CodemodTest):
13+
TRANSFORM = FixVariadicCallableCommmand
14+
15+
def test_callable_typing(self) -> None:
16+
before = """
17+
from typing import Callable
18+
x: Callable[[...], int] = ...
19+
"""
20+
after = """
21+
from typing import Callable
22+
x: Callable[..., int] = ...
23+
"""
24+
self.assertCodemod(before, after)
25+
26+
def test_callable_typing_alias(self) -> None:
27+
before = """
28+
import typing as t
29+
x: t.Callable[[...], int] = ...
30+
"""
31+
after = """
32+
import typing as t
33+
x: t.Callable[..., int] = ...
34+
"""
35+
self.assertCodemod(before, after)
36+
37+
def test_callable_import_alias(self) -> None:
38+
before = """
39+
from typing import Callable as C
40+
x: C[[...], int] = ...
41+
"""
42+
after = """
43+
from typing import Callable as C
44+
x: C[..., int] = ...
45+
"""
46+
self.assertCodemod(before, after)
47+
48+
def test_callable_with_optional(self) -> None:
49+
before = """
50+
from typing import Callable
51+
def foo(bar: Optional[Callable[[...], int]]) -> Callable[[...], int]:
52+
...
53+
"""
54+
after = """
55+
from typing import Callable
56+
def foo(bar: Optional[Callable[..., int]]) -> Callable[..., int]:
57+
...
58+
"""
59+
self.assertCodemod(before, after)
60+
61+
def test_callable_with_arguments(self) -> None:
62+
before = """
63+
from typing import Callable
64+
x: Callable[[int], int]
65+
"""
66+
after = """
67+
from typing import Callable
68+
x: Callable[[int], int]
69+
"""
70+
self.assertCodemod(before, after)
71+
72+
def test_callable_with_variadic_arguments(self) -> None:
73+
before = """
74+
from typing import Callable
75+
x: Callable[[int, int, ...], int]
76+
"""
77+
after = """
78+
from typing import Callable
79+
x: Callable[[int, int, ...], int]
80+
"""
81+
self.assertCodemod(before, after)
82+
83+
def test_callable_no_arguments(self) -> None:
84+
before = """
85+
from typing import Callable
86+
x: Callable
87+
"""
88+
after = """
89+
from typing import Callable
90+
x: Callable
91+
"""
92+
self.assertCodemod(before, after)

0 commit comments

Comments
 (0)