Skip to content

Commit a33401e

Browse files
author
MarcoGorelli
committed
put simple node check back
1 parent b679644 commit a33401e

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

auto_walrus.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
SEP_SYMBOLS = frozenset(('(', ')', ',', ':'))
1212
# name, lineno, col_offset, end_lineno, end_col_offset
1313
Token = Tuple[str, int, int, int, int]
14-
14+
SIMPLE_NODE = (ast.Name, ast.Constant)
1515
ENDS_WITH_COMMENT = re.compile(r'#.*$')
1616

1717

@@ -25,6 +25,22 @@ def name_lineno_coloffset(tokens: Token) -> tuple[str, int, int]:
2525
return (tokens[0], tokens[1], tokens[2])
2626

2727

28+
def is_simple_test(node: ast.AST) -> bool:
29+
return (
30+
isinstance(node, SIMPLE_NODE)
31+
or (
32+
isinstance(node, ast.Compare)
33+
and isinstance(node.left, SIMPLE_NODE)
34+
and (
35+
all(
36+
isinstance(_node, SIMPLE_NODE)
37+
for _node in node.comparators
38+
)
39+
)
40+
)
41+
)
42+
43+
2844
def record_name_lineno_coloffset(
2945
node: ast.Name,
3046
end_lineno: int | None = None,
@@ -186,9 +202,10 @@ def visit_function_def(
186202
if isinstance(_node, ast.Assign):
187203
process_assign(_node, assignments, related_vars)
188204
elif isinstance(_node, ast.If):
189-
ifs.update(process_if(_node, in_body_vars))
205+
if is_simple_test(_node.test):
206+
ifs.update(process_if(_node, in_body_vars))
190207
for __node in _node.orelse:
191-
if isinstance(__node, ast.If):
208+
if isinstance(__node, ast.If) and is_simple_test(__node.test):
192209
ifs.update(process_if(__node, in_body_vars))
193210

194211
sorted_names = sorted(names, key=lambda x: (x[1], x[2]))

tests/main_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,20 @@ def test_rewrite(src: str, expected: str) -> None:
179179
' if a:\n'
180180
' print(a)\n'
181181
' a = 2\n',
182-
'n = 10\n'
183-
'if True:\n'
184-
' pass\n'
185-
'elif foo(a := n+1):\n'
186-
' print(n)\n',
187-
'n = 10\n'
188-
'if n > np.sin(foo.bar.quox):\n'
189-
' print(n)\n',
182+
'def foo():\n'
183+
' n = 10\n'
184+
' if True:\n'
185+
' pass\n'
186+
' elif foo(a := n+1):\n'
187+
' print(n)\n',
188+
'def foo():\n'
189+
' n = 10\n'
190+
' if n > np.sin(foo.bar.quox):\n'
191+
' print(n)\n',
192+
'def foo():\n'
193+
' n = 10\n'
194+
' if True or n > 3:\n'
195+
' print(n)\n',
190196
],
191197
)
192198
def test_noop(src: str) -> None:

0 commit comments

Comments
 (0)