Skip to content

Commit bdfbe60

Browse files
committed
Add --unsafe
Refs #84
1 parent 8e3b19e commit bdfbe60

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

auto_walrus.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
@dataclasses.dataclass
3535
class Config:
3636
line_length: int
37+
unsafe: bool = False
3738

3839

3940
def name_lineno_coloffset_iterable(
@@ -199,6 +200,7 @@ def related_vars_are_unused(
199200

200201
def visit_function_def(
201202
node: ast.FunctionDef,
203+
config: Config,
202204
) -> list[tuple[Token, Token]]:
203205
names = set()
204206
assignments: set[Token] = set()
@@ -210,7 +212,7 @@ def visit_function_def(
210212
related_vars: dict[str, list[Token]] = {}
211213
in_body_vars: dict[Token, set[Token]] = {}
212214

213-
for _node in node.body:
215+
for _node in ast.walk(node) if config.unsafe else node.body:
214216
if isinstance(_node, ast.Assign):
215217
process_assign(_node, assignments, related_vars)
216218
elif isinstance(_node, ast.If):
@@ -272,7 +274,7 @@ def auto_walrus(
272274
walruses = []
273275
for node in ast.walk(tree):
274276
if isinstance(node, ast.FunctionDef):
275-
walruses.extend(visit_function_def(node))
277+
walruses.extend(visit_function_def(node, config))
276278
lines_to_remove = []
277279
walruses = sorted(walruses, key=lambda x: (-x[1][1], -x[1][2]))
278280

@@ -367,6 +369,11 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
367369
required=False,
368370
default=r"^$",
369371
)
372+
parser.add_argument(
373+
"--unsafe",
374+
action="store_true",
375+
help="Also process if statements inside other blocks (like for loops)",
376+
)
370377
# black formatter's default
371378
parser.add_argument("--line-length", type=int, default=88)
372379
args = parser.parse_args(argv)
@@ -398,7 +405,10 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
398405
content = fd.read()
399406
except UnicodeDecodeError:
400407
continue
401-
new_content = auto_walrus(content, args.line_length)
408+
new_content = auto_walrus(
409+
content,
410+
Config(line_length=args.line_length, unsafe=args.unsafe),
411+
)
402412
if new_content is not None and content != new_content:
403413
sys.stdout.write(f"Rewriting {filepath}\n")
404414
with open(filepath, "w", encoding="utf-8") as fd:

tests/main_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99

10+
from auto_walrus import Config
1011
from auto_walrus import auto_walrus
1112
from auto_walrus import main
1213

@@ -129,6 +130,20 @@ def test_noop(src: str) -> None:
129130
assert ret is None
130131

131132

133+
@pytest.mark.parametrize(
134+
("src", "expected"),
135+
[
136+
(
137+
'def foo(data):\n if True:\n foo = data.get("blah")\n if foo:\n return foo\n return data',
138+
'def foo(data):\n if True:\n if (foo := data.get("blah")):\n return foo\n return data',
139+
),
140+
],
141+
)
142+
def test_rewrite_unsafe(src: str, expected: str) -> None:
143+
ret = auto_walrus(src, Config(line_length=88, unsafe=True))
144+
assert ret == expected
145+
146+
132147
ProjectDirT = Tuple[pathlib.Path, List[pathlib.Path]]
133148

134149
SRC_ORIG = "def foo():\n a = 0\n if a:\n print(a)\n"

0 commit comments

Comments
 (0)