3434@dataclasses .dataclass
3535class Config :
3636 line_length : int
37+ unsafe : bool = False
3738
3839
3940def name_lineno_coloffset_iterable (
@@ -199,6 +200,7 @@ def related_vars_are_unused(
199200
200201def visit_function_def (
201202 node : ast .FunctionDef ,
203+ config : Config ,
202204) -> list [tuple [Token , Token ]]:
203205 names = set ()
204206 assignments : set [Token ] = set ()
@@ -210,7 +212,7 @@ def visit_function_def(
210212 related_vars : dict [str , list [Token ]] = {}
211213 in_body_vars : dict [Token , set [Token ]] = {}
212214
213- for _node in node .body :
215+ for _node in ast . walk ( node ) if config . unsafe else node .body :
214216 if isinstance (_node , ast .Assign ):
215217 process_assign (_node , assignments , related_vars )
216218 elif isinstance (_node , ast .If ):
@@ -272,7 +274,7 @@ def auto_walrus(
272274 walruses = []
273275 for node in ast .walk (tree ):
274276 if isinstance (node , ast .FunctionDef ):
275- walruses .extend (visit_function_def (node ))
277+ walruses .extend (visit_function_def (node , config ))
276278 lines_to_remove = []
277279 walruses = sorted (walruses , key = lambda x : (- x [1 ][1 ], - x [1 ][2 ]))
278280
@@ -367,6 +369,11 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
367369 required = False ,
368370 default = r"^$" ,
369371 )
372+ parser .add_argument (
373+ "--unsafe" ,
374+ action = "store_true" ,
375+ help = "Also process if statements inside other blocks (like for loops)" ,
376+ )
370377 # black formatter's default
371378 parser .add_argument ("--line-length" , type = int , default = 88 )
372379 args = parser .parse_args (argv )
@@ -379,7 +386,7 @@ def main(argv: Sequence[str] | None = None) -> int: # pragma: no cover
379386
380387 ret = 0
381388
382- config = Config (line_length = args .line_length )
389+ config = Config (line_length = args .line_length , unsafe = args . unsafe )
383390 for path in paths :
384391 if path .is_file ():
385392 filepaths = iter ((path ,))
0 commit comments