Skip to content

Commit a32e144

Browse files
authored
Merge pull request #90 from akx/avoid-duplicate-walrus-removal
Fix incorrect rewrite with --unsafe on nested functions
2 parents d1f1f71 + c9fe673 commit a32e144

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

auto_walrus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,12 @@ def auto_walrus(
271271
except SyntaxError: # pragma: no cover
272272
return None
273273

274-
walruses = []
274+
walrus_set: set[tuple[Token, Token]] = set()
275275
for node in ast.walk(tree):
276276
if isinstance(node, ast.FunctionDef):
277-
walruses.extend(visit_function_def(node, config))
277+
walrus_set.update(visit_function_def(node, config))
278278
lines_to_remove = []
279-
walruses = sorted(walruses, key=lambda x: (-x[1][1], -x[1][2]))
279+
walruses = sorted(walrus_set, key=lambda x: (-x[1][1], -x[1][2]))
280280

281281
if not walruses:
282282
return None

tests/main_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ def test_noop(src: str) -> None:
137137
'def foo(data):\n if True:\n foo = data.get("blah")\n if foo:\n return foo\n return data',
138138
'def foo(data):\n if True:\n if (foo := data.get("blah")):\n return foo\n return data',
139139
),
140+
# Nested functions - issue #89
141+
(
142+
"def foo():\n"
143+
" def bar():\n"
144+
" def quox():\n"
145+
" conn_time_zone = fetch_rel_time_zone(df.native)\n"
146+
" if conn_time_zone != time_zone:\n"
147+
" print(conn_time_zone)\n",
148+
"def foo():\n"
149+
" def bar():\n"
150+
" def quox():\n"
151+
" if (conn_time_zone := fetch_rel_time_zone(df.native)) != time_zone:\n"
152+
" print(conn_time_zone)\n",
153+
),
140154
],
141155
)
142156
def test_rewrite_unsafe(src: str, expected: str) -> None:

0 commit comments

Comments
 (0)