3434@dataclasses .dataclass
3535class Config :
3636 line_length : int
37+ unsafe : bool = False
3738
3839
3940def name_lineno_coloffset_iterable (
@@ -199,6 +200,7 @@ def related_vars_are_unused(
199200
200201def visit_function_def (
201202 node : ast .FunctionDef ,
203+ config : Config ,
202204) -> list [tuple [Token , Token ]]:
203205 names = set ()
204206 assignments : set [Token ] = set ()
@@ -210,7 +212,7 @@ def visit_function_def(
210212 related_vars : dict [str , list [Token ]] = {}
211213 in_body_vars : dict [Token , set [Token ]] = {}
212214
213- for _node in node .body :
215+ for _node in ast . walk ( node ) if config . unsafe else node .body :
214216 if isinstance (_node , ast .Assign ):
215217 process_assign (_node , assignments , related_vars )
216218 elif isinstance (_node , ast .If ):
@@ -272,7 +274,7 @@ def auto_walrus(
272274 walruses = []
273275 for node in ast .walk (tree ):
274276 if isinstance (node , ast .FunctionDef ):
275- walruses .extend (visit_function_def (node ))
277+ walruses .extend (visit_function_def (node , config ))
276278 lines_to_remove = []
277279 walruses = sorted (walruses , key = lambda x : (- x [1 ][1 ], - x [1 ][2 ]))
278280
@@ -367,6 +369,11 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
367369 required = False ,
368370 default = r"^$" ,
369371 )
372+ parser .add_argument (
373+ "--unsafe" ,
374+ action = "store_true" ,
375+ help = "Also process if statements inside other blocks (like for loops)" ,
376+ )
370377 # black formatter's default
371378 parser .add_argument ("--line-length" , type = int , default = 88 )
372379 args = parser .parse_args (argv )
@@ -398,7 +405,10 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
398405 content = fd .read ()
399406 except UnicodeDecodeError :
400407 continue
401- new_content = auto_walrus (content , args .line_length )
408+ new_content = auto_walrus (
409+ content ,
410+ Config (line_length = args .line_length , unsafe = args .unsafe ),
411+ )
402412 if new_content is not None and content != new_content :
403413 sys .stdout .write (f"Rewriting { filepath } \n " )
404414 with open (filepath , "w" , encoding = "utf-8" ) as fd :
0 commit comments