3434@dataclasses .dataclass
3535class Config :
3636 line_length : int
37+ unsafe : bool = False
3738
3839
3940def name_lineno_coloffset_iterable (
@@ -199,6 +200,7 @@ def related_vars_are_unused(
199200
200201def visit_function_def (
201202 node : ast .FunctionDef ,
203+ config : Config ,
202204) -> list [tuple [Token , Token ]]:
203205 names = set ()
204206 assignments : set [Token ] = set ()
@@ -210,7 +212,7 @@ def visit_function_def(
210212 related_vars : dict [str , list [Token ]] = {}
211213 in_body_vars : dict [Token , set [Token ]] = {}
212214
213- for _node in node .body :
215+ for _node in ast . walk ( node ) if config . unsafe else node .body :
214216 if isinstance (_node , ast .Assign ):
215217 process_assign (_node , assignments , related_vars )
216218 elif isinstance (_node , ast .If ):
@@ -276,7 +278,7 @@ def auto_walrus(
276278 walruses = []
277279 for node in ast .walk (tree ):
278280 if isinstance (node , ast .FunctionDef ):
279- walruses .extend (visit_function_def (node ))
281+ walruses .extend (visit_function_def (node , config ))
280282 lines_to_remove = []
281283 walruses = sorted (walruses , key = lambda x : (- x [1 ][1 ], - x [1 ][2 ]))
282284
@@ -371,6 +373,11 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
371373 required = False ,
372374 default = r"^$" ,
373375 )
376+ parser .add_argument (
377+ "--unsafe" ,
378+ action = "store_true" ,
379+ help = "Also process if statements inside other blocks (like for loops)" ,
380+ )
374381 # black formatter's default
375382 parser .add_argument ("--line-length" , type = int , default = 88 )
376383 args = parser .parse_args (argv )
@@ -402,7 +409,10 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
402409 content = fd .read ()
403410 except UnicodeDecodeError :
404411 continue
405- new_content = auto_walrus (content , args .line_length )
412+ new_content = auto_walrus (
413+ content ,
414+ Config (line_length = args .line_length , unsafe = args .unsafe ),
415+ )
406416 if new_content is not None and content != new_content :
407417 sys .stdout .write (f"Rewriting { filepath } \n " )
408418 with open (filepath , "w" , encoding = "utf-8" ) as fd :
0 commit comments