diff --git a/libcst/codemod/commands/fix_variadic_callable.py b/libcst/codemod/commands/fix_variadic_callable.py new file mode 100644 index 00000000..85cb0aa0 --- /dev/null +++ b/libcst/codemod/commands/fix_variadic_callable.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import VisitorBasedCodemodCommand +from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource + + +class FixVariadicCallableCommmand(VisitorBasedCodemodCommand): + DESCRIPTION: str = ( + "Fix incorrect variadic callable type annotations from `Callable[[...], T]` to `Callable[..., T]``" + ) + + METADATA_DEPENDENCIES = (QualifiedNameProvider,) + + def leave_Subscript( + self, original_node: cst.Subscript, updated_node: cst.Subscript + ) -> cst.BaseExpression: + if QualifiedNameProvider.has_name( + self, + original_node, + QualifiedName(name="typing.Callable", source=QualifiedNameSource.IMPORT), + ): + node_matches = len(updated_node.slice) == 2 and m.matches( + updated_node.slice[0], + m.SubscriptElement( + slice=m.Index(value=m.List(elements=[m.Element(m.Ellipsis())])) + ), + ) + + if node_matches: + slices = list(updated_node.slice) + slices[0] = cst.SubscriptElement(cst.Index(cst.Ellipsis())) + return updated_node.with_changes(slice=slices) + return updated_node diff --git a/libcst/codemod/commands/tests/test_fix_variadic_callable.py b/libcst/codemod/commands/tests/test_fix_variadic_callable.py new file mode 100644 index 00000000..848f0c98 --- /dev/null +++ b/libcst/codemod/commands/tests/test_fix_variadic_callable.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict + +from libcst.codemod import CodemodTest +from libcst.codemod.commands.fix_variadic_callable import FixVariadicCallableCommmand + + +class TestFixVariadicCallableCommmand(CodemodTest): + TRANSFORM = FixVariadicCallableCommmand + + def test_callable_typing(self) -> None: + before = """ + from typing import Callable + x: Callable[[...], int] = ... + """ + after = """ + from typing import Callable + x: Callable[..., int] = ... + """ + self.assertCodemod(before, after) + + def test_callable_typing_alias(self) -> None: + before = """ + import typing as t + x: t.Callable[[...], int] = ... + """ + after = """ + import typing as t + x: t.Callable[..., int] = ... + """ + self.assertCodemod(before, after) + + def test_callable_import_alias(self) -> None: + before = """ + from typing import Callable as C + x: C[[...], int] = ... + """ + after = """ + from typing import Callable as C + x: C[..., int] = ... + """ + self.assertCodemod(before, after) + + def test_callable_with_optional(self) -> None: + before = """ + from typing import Callable + def foo(bar: Optional[Callable[[...], int]]) -> Callable[[...], int]: + ... + """ + after = """ + from typing import Callable + def foo(bar: Optional[Callable[..., int]]) -> Callable[..., int]: + ... + """ + self.assertCodemod(before, after) + + def test_callable_with_arguments(self) -> None: + before = """ + from typing import Callable + x: Callable[[int], int] + """ + after = """ + from typing import Callable + x: Callable[[int], int] + """ + self.assertCodemod(before, after) + + def test_callable_with_variadic_arguments(self) -> None: + before = """ + from typing import Callable + x: Callable[[int, int, ...], int] + """ + after = """ + from typing import Callable + x: Callable[[int, int, ...], int] + """ + self.assertCodemod(before, after) + + def test_callable_no_arguments(self) -> None: + before = """ + from typing import Callable + x: Callable + """ + after = """ + from typing import Callable + x: Callable + """ + self.assertCodemod(before, after)