Skip to content

Commit d269872

Browse files
authored
Add codemod to convert typing.Union to | (#1270)
* add union to or codemod * lint * early return
1 parent 230f177 commit d269872

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# pyre-strict
7+
8+
import libcst as cst
9+
from libcst.codemod import VisitorBasedCodemodCommand
10+
from libcst.codemod.visitors import RemoveImportsVisitor
11+
from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource
12+
13+
14+
class ConvertUnionToOrCommand(VisitorBasedCodemodCommand):
15+
DESCRIPTION: str = "Convert `Union[A, B]` to `A | B` in Python 3.10+"
16+
17+
METADATA_DEPENDENCIES = (QualifiedNameProvider,)
18+
19+
def leave_Subscript(
20+
self, original_node: cst.Subscript, updated_node: cst.Subscript
21+
) -> cst.BaseExpression:
22+
"""
23+
Given a subscript, check if it's a Union - if so, either flatten the members
24+
into a nested BitOr (if multiple members) or unwrap the type (if only one member).
25+
"""
26+
if not QualifiedNameProvider.has_name(
27+
self,
28+
original_node,
29+
QualifiedName(name="typing.Union", source=QualifiedNameSource.IMPORT),
30+
):
31+
return updated_node
32+
types = [
33+
cst.ensure_type(
34+
cst.ensure_type(s, cst.SubscriptElement).slice, cst.Index
35+
).value
36+
for s in updated_node.slice
37+
]
38+
if len(types) == 1:
39+
return types[0]
40+
else:
41+
replacement = cst.BinaryOperation(
42+
left=types[0], right=types[1], operator=cst.BitOr()
43+
)
44+
for type_ in types[2:]:
45+
replacement = cst.BinaryOperation(
46+
left=replacement, right=type_, operator=cst.BitOr()
47+
)
48+
return replacement
49+
50+
def leave_Module(
51+
self, original_node: cst.Module, updated_node: cst.Module
52+
) -> cst.Module:
53+
RemoveImportsVisitor.remove_unused_import(
54+
self.context, module="typing", obj="Union"
55+
)
56+
return updated_node
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# pyre-strict
7+
8+
from libcst.codemod import CodemodTest
9+
from libcst.codemod.commands.convert_union_to_or import ConvertUnionToOrCommand
10+
11+
12+
class TestConvertUnionToOrCommand(CodemodTest):
13+
TRANSFORM = ConvertUnionToOrCommand
14+
15+
def test_simple_union(self) -> None:
16+
before = """
17+
from typing import Union
18+
x: Union[int, str]
19+
"""
20+
after = """
21+
x: int | str
22+
"""
23+
self.assertCodemod(before, after)
24+
25+
def test_nested_union(self) -> None:
26+
before = """
27+
from typing import Union
28+
x: Union[int, Union[str, float]]
29+
"""
30+
after = """
31+
x: int | str | float
32+
"""
33+
self.assertCodemod(before, after)
34+
35+
def test_single_type_union(self) -> None:
36+
before = """
37+
from typing import Union
38+
x: Union[int]
39+
"""
40+
after = """
41+
x: int
42+
"""
43+
self.assertCodemod(before, after)
44+
45+
def test_union_with_alias(self) -> None:
46+
before = """
47+
import typing as t
48+
x: t.Union[int, str]
49+
"""
50+
after = """
51+
import typing as t
52+
x: int | str
53+
"""
54+
self.assertCodemod(before, after)
55+
56+
def test_union_with_unused_import(self) -> None:
57+
before = """
58+
from typing import Union, List
59+
x: Union[int, str]
60+
"""
61+
after = """
62+
from typing import List
63+
x: int | str
64+
"""
65+
self.assertCodemod(before, after)
66+
67+
def test_union_no_import(self) -> None:
68+
before = """
69+
x: Union[int, str]
70+
"""
71+
after = """
72+
x: Union[int, str]
73+
"""
74+
self.assertCodemod(before, after)
75+
76+
def test_union_in_function(self) -> None:
77+
before = """
78+
from typing import Union
79+
def foo(x: Union[int, str]) -> Union[float, None]:
80+
...
81+
"""
82+
after = """
83+
def foo(x: int | str) -> float | None:
84+
...
85+
"""
86+
self.assertCodemod(before, after)

0 commit comments

Comments
 (0)