Skip to content

Commit efa4898

Browse files
authored
Merge pull request #85 from akx/unsafe-mode
Add --unsafe
2 parents 2eec88b + 52ebbdc commit efa4898

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

auto_walrus.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import ast
5+
import dataclasses
56
import os
67
import pathlib
78
import re
@@ -30,6 +31,12 @@
3031
)
3132

3233

34+
@dataclasses.dataclass
35+
class Config:
36+
line_length: int
37+
unsafe: bool = False
38+
39+
3340
def name_lineno_coloffset_iterable(
3441
tokens: Iterable[Token],
3542
) -> list[tuple[str, int, int]]:
@@ -193,6 +200,7 @@ def related_vars_are_unused(
193200

194201
def visit_function_def(
195202
node: ast.FunctionDef,
203+
config: Config,
196204
) -> list[tuple[Token, Token]]:
197205
names = set()
198206
assignments: set[Token] = set()
@@ -204,7 +212,7 @@ def visit_function_def(
204212
related_vars: dict[str, list[Token]] = {}
205213
in_body_vars: dict[Token, set[Token]] = {}
206214

207-
for _node in node.body:
215+
for _node in ast.walk(node) if config.unsafe else node.body:
208216
if isinstance(_node, ast.Assign):
209217
process_assign(_node, assignments, related_vars)
210218
elif isinstance(_node, ast.If):
@@ -255,7 +263,7 @@ def visit_function_def(
255263

256264
def auto_walrus(
257265
content: str,
258-
line_length: int,
266+
config: Config,
259267
) -> str | None:
260268
lines = content.splitlines()
261269
try:
@@ -266,7 +274,7 @@ def auto_walrus(
266274
walruses = []
267275
for node in ast.walk(tree):
268276
if isinstance(node, ast.FunctionDef):
269-
walruses.extend(visit_function_def(node))
277+
walruses.extend(visit_function_def(node, config))
270278
lines_to_remove = []
271279
walruses = sorted(walruses, key=lambda x: (-x[1][1], -x[1][2]))
272280

@@ -290,7 +298,7 @@ def auto_walrus(
290298
line_with_walrus = left_bit + replace + right_bit
291299
else:
292300
line_with_walrus = left_bit + "(" + replace + ")" + right_bit
293-
if len(line_with_walrus) > line_length:
301+
if len(line_with_walrus) > config.line_length:
294302
# don't rewrite if it would split over multiple lines
295303
continue
296304
# replace assignment
@@ -361,18 +369,24 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
361369
required=False,
362370
default=r"^$",
363371
)
372+
parser.add_argument(
373+
"--unsafe",
374+
action="store_true",
375+
help="Also process if statements inside other blocks (like for loops)",
376+
)
364377
# black formatter's default
365378
parser.add_argument("--line-length", type=int, default=88)
366379
args = parser.parse_args(argv)
367380
paths = [pathlib.Path(path).resolve() for path in args.paths]
368381

369382
# Update defaults from pyproject.toml if present
370-
config = {k.replace("-", "_"): v for k, v in _get_config(paths).items()}
371-
parser.set_defaults(**config)
383+
defaults = {k.replace("-", "_"): v for k, v in _get_config(paths).items()}
384+
parser.set_defaults(**defaults)
372385
args = parser.parse_args(argv)
373386

374387
ret = 0
375388

389+
config = Config(line_length=args.line_length, unsafe=args.unsafe)
376390
for path in paths:
377391
if path.is_file():
378392
filepaths = iter((path,))
@@ -392,10 +406,7 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
392406
content = fd.read()
393407
except UnicodeDecodeError:
394408
continue
395-
new_content = auto_walrus(
396-
content,
397-
line_length=args.line_length,
398-
)
409+
new_content = auto_walrus(content, config)
399410
if new_content is not None and content != new_content:
400411
sys.stdout.write(f"Rewriting {filepath}\n")
401412
with open(filepath, "w", encoding="utf-8") as fd:

tests/main_test.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99

10+
from auto_walrus import Config
1011
from auto_walrus import auto_walrus
1112
from auto_walrus import main
1213

@@ -89,7 +90,7 @@
8990
],
9091
)
9192
def test_rewrite(src: str, expected: str) -> None:
92-
ret = auto_walrus(src, 88)
93+
ret = auto_walrus(src, Config(line_length=88))
9394
assert ret == expected
9495

9596

@@ -125,10 +126,24 @@ def test_rewrite(src: str, expected: str) -> None:
125126
],
126127
)
127128
def test_noop(src: str) -> None:
128-
ret = auto_walrus(src, 40)
129+
ret = auto_walrus(src, Config(line_length=40))
129130
assert ret is None
130131

131132

133+
@pytest.mark.parametrize(
134+
("src", "expected"),
135+
[
136+
(
137+
'def foo(data):\n if True:\n foo = data.get("blah")\n if foo:\n return foo\n return data',
138+
'def foo(data):\n if True:\n if (foo := data.get("blah")):\n return foo\n return data',
139+
),
140+
],
141+
)
142+
def test_rewrite_unsafe(src: str, expected: str) -> None:
143+
ret = auto_walrus(src, Config(line_length=88, unsafe=True))
144+
assert ret == expected
145+
146+
132147
ProjectDirT = Tuple[pathlib.Path, List[pathlib.Path]]
133148

134149
SRC_ORIG = "def foo():\n a = 0\n if a:\n print(a)\n"

0 commit comments

Comments
 (0)