22
33import argparse
44import ast
5+ import dataclasses
56import os
67import pathlib
78import re
3031)
3132
3233
34+ @dataclasses .dataclass
35+ class Config :
36+ line_length : int
37+ unsafe : bool = False
38+
39+
3340def name_lineno_coloffset_iterable (
3441 tokens : Iterable [Token ],
3542) -> list [tuple [str , int , int ]]:
@@ -193,6 +200,7 @@ def related_vars_are_unused(
193200
194201def visit_function_def (
195202 node : ast .FunctionDef ,
203+ config : Config ,
196204) -> list [tuple [Token , Token ]]:
197205 names = set ()
198206 assignments : set [Token ] = set ()
@@ -204,7 +212,7 @@ def visit_function_def(
204212 related_vars : dict [str , list [Token ]] = {}
205213 in_body_vars : dict [Token , set [Token ]] = {}
206214
207- for _node in node .body :
215+ for _node in ast . walk ( node ) if config . unsafe else node .body :
208216 if isinstance (_node , ast .Assign ):
209217 process_assign (_node , assignments , related_vars )
210218 elif isinstance (_node , ast .If ):
@@ -255,7 +263,7 @@ def visit_function_def(
255263
256264def auto_walrus (
257265 content : str ,
258- line_length : int ,
266+ config : Config ,
259267) -> str | None :
260268 lines = content .splitlines ()
261269 try :
@@ -266,7 +274,7 @@ def auto_walrus(
266274 walruses = []
267275 for node in ast .walk (tree ):
268276 if isinstance (node , ast .FunctionDef ):
269- walruses .extend (visit_function_def (node ))
277+ walruses .extend (visit_function_def (node , config ))
270278 lines_to_remove = []
271279 walruses = sorted (walruses , key = lambda x : (- x [1 ][1 ], - x [1 ][2 ]))
272280
@@ -290,7 +298,7 @@ def auto_walrus(
290298 line_with_walrus = left_bit + replace + right_bit
291299 else :
292300 line_with_walrus = left_bit + "(" + replace + ")" + right_bit
293- if len (line_with_walrus ) > line_length :
301+ if len (line_with_walrus ) > config . line_length :
294302 # don't rewrite if it would split over multiple lines
295303 continue
296304 # replace assignment
@@ -361,18 +369,24 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
361369 required = False ,
362370 default = r"^$" ,
363371 )
372+ parser .add_argument (
373+ "--unsafe" ,
374+ action = "store_true" ,
375+ help = "Also process if statements inside other blocks (like for loops)" ,
376+ )
364377 # black formatter's default
365378 parser .add_argument ("--line-length" , type = int , default = 88 )
366379 args = parser .parse_args (argv )
367380 paths = [pathlib .Path (path ).resolve () for path in args .paths ]
368381
369382 # Update defaults from pyproject.toml if present
370- config = {k .replace ("-" , "_" ): v for k , v in _get_config (paths ).items ()}
371- parser .set_defaults (** config )
383+ defaults = {k .replace ("-" , "_" ): v for k , v in _get_config (paths ).items ()}
384+ parser .set_defaults (** defaults )
372385 args = parser .parse_args (argv )
373386
374387 ret = 0
375388
389+ config = Config (line_length = args .line_length , unsafe = args .unsafe )
376390 for path in paths :
377391 if path .is_file ():
378392 filepaths = iter ((path ,))
@@ -392,10 +406,7 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
392406 content = fd .read ()
393407 except UnicodeDecodeError :
394408 continue
395- new_content = auto_walrus (
396- content ,
397- line_length = args .line_length ,
398- )
409+ new_content = auto_walrus (content , config )
399410 if new_content is not None and content != new_content :
400411 sys .stdout .write (f"Rewriting { filepath } \n " )
401412 with open (filepath , "w" , encoding = "utf-8" ) as fd :
0 commit comments