@@ -277,7 +277,7 @@ def if_statement(
277277 branch = self .helper ("branch" , condition )
278278 self .statement ("with" , branch , "as" , "_bolt_condition" , lineno = lineno )
279279 with self .block ():
280- self .statement (f "if _bolt_condition" )
280+ self .statement ("if" , " _bolt_condition" )
281281 with self .block ():
282282 yield
283283 self .condition_inverse = inverse
@@ -304,7 +304,7 @@ def dup(self, target: str, *, lineno: Any = None) -> str:
304304 dup = self .make_variable ()
305305 value = self .helper ("get_dup" , target )
306306 self .statement (f"{ dup } = { value } " , lineno = lineno )
307- self .statement (f "if { dup } is not None" )
307+ self .statement ("if" , f" { dup } is not None" )
308308 with self .block ():
309309 self .statement (f"{ target } = { dup } ()" )
310310 return dup
@@ -314,13 +314,13 @@ def rebind(self, target: str, op: str, value: str, *, lineno: Any = None):
314314 rebind = self .helper ("get_rebind" , target )
315315 self .statement (f"_bolt_rebind = { rebind } " , lineno = lineno )
316316 self .statement (f"{ target } { op } { value } " )
317- self .statement (f "if _bolt_rebind is not None" )
317+ self .statement ("if" , " _bolt_rebind is not None" )
318318 with self .block ():
319319 self .statement (f"{ target } = _bolt_rebind({ target } )" )
320320
321321 def rebind_dup (self , target : str , dup : str , value : str , * , lineno : Any = None ):
322322 """Emit __rebind__() if target was __dup__()."""
323- self .statement (f "if { dup } is not None" )
323+ self .statement ("if" , f" { dup } is not None" )
324324 with self .block ():
325325 self .rebind (target , "=" , value , lineno = lineno )
326326 self .statement ("else" )
@@ -345,7 +345,9 @@ class WithStatementFusion:
345345 @classmethod
346346 def finalize (cls , acc : Accumulator ):
347347 with_statement_fusion = cls ()
348- acc .statements = [with_statement_fusion .fuse (statement , acc ) for statement in acc .statements ]
348+ acc .statements = [
349+ with_statement_fusion .fuse (statement , acc ) for statement in acc .statements
350+ ]
349351
350352 def convert (self , statement : CodegenStatement , exit_stack : str ) -> CodegenStatement :
351353 code = (f"{ exit_stack } .enter_context({ statement .code [1 ]} )" ,)
@@ -356,29 +358,31 @@ def convert(self, statement: CodegenStatement, exit_stack: str) -> CodegenStatem
356358 def fuse (self , statement : CodegenStatement , acc : Accumulator ) -> CodegenStatement :
357359 children = [self .fuse (child , acc ) for child in statement .children ]
358360
359- if statement .code [0 ] == "with" and children [ - 1 ]. code [ 0 ] = = "with" :
360- nested_statement = children . pop ( )
361+ if statement .code [0 ] ! = "with" :
362+ return replace ( statement , children = children )
361363
362- if nested_statement .code [1 ] == acc .helper ("exit_stack" ):
363- exit_stack = nested_statement .code [3 ]
364- code = nested_statement .code
365- else :
366- exit_stack = f"_bolt_fused_with_statement{ self .counter } "
367- self .counter += 1
368- code = ("with" , acc .helper ("exit_stack" ), "as" , exit_stack )
369- children .append (self .convert (nested_statement , exit_stack ))
370-
371- return replace (
372- statement ,
373- code = code ,
374- children = [
375- self .convert (statement , exit_stack ),
376- * children ,
377- * nested_statement .children ,
378- ],
379- )
364+ nested_children = children
365+ while nested_children [- 1 ].code [0 ] == "if" :
366+ nested_children = nested_children [- 1 ].children
367+
368+ if nested_children [- 1 ].code [0 ] != "with" :
369+ return replace (statement , children = children )
370+
371+ nested_statement = nested_children .pop ()
380372
381- return replace (statement , children = children )
373+ if nested_statement .code [1 ] == acc .helper ("exit_stack" ):
374+ exit_stack = nested_statement .code [3 ]
375+ code = nested_statement .code
376+ else :
377+ exit_stack = f"_bolt_fused_with_statement{ self .counter } "
378+ self .counter += 1
379+ code = ("with" , acc .helper ("exit_stack" ), "as" , exit_stack )
380+ nested_children .append (self .convert (nested_statement , exit_stack ))
381+
382+ children .insert (0 , self .convert (statement , exit_stack ))
383+ nested_children .extend (nested_statement .children )
384+
385+ return replace (statement , code = code , children = children )
382386
383387
384388@dataclass
@@ -711,7 +715,7 @@ def memo(
711715 acc .header [storage ] = "None"
712716 if not acc .root_scope :
713717 acc .statement (f"global { storage } " , lineno = node )
714- acc .statement (f "if { storage } is None" )
718+ acc .statement ("if" , f" { storage } is None" )
715719 with acc .block ():
716720 acc .statement (
717721 f"{ storage } = _bolt_runtime.memo.registry[__file__][{ acc .make_ref (node )} , { file_index } ]"
@@ -723,7 +727,7 @@ def memo(
723727 invocation = f"_bolt_memo_invocation_{ node .persistent_id .hex } "
724728 acc .statement (f"{ invocation } = { storage } [({ path } , { ' ' .join (keys )} )]" )
725729
726- acc .statement (f "if { invocation } .cached" )
730+ acc .statement ("if" , f" { invocation } .cached" )
727731 with acc .block ():
728732 acc .statement (f"_bolt_runtime.memo.restore(_bolt_runtime, { invocation } )" )
729733 if cached_identifiers :
@@ -798,7 +802,7 @@ def function(
798802
799803 for arg in signature .arguments :
800804 if isinstance (arg , AstFunctionSignatureArgument ) and arg .default :
801- acc .statement (f "if { arg .name } is { acc .missing ()} " )
805+ acc .statement ("if" , f" { arg .name } is { acc .missing ()} " )
802806 with acc .block ():
803807 value = yield from visit_single (arg .default , required = True )
804808 acc .statement (f"{ arg .name } = { value } " )
@@ -1023,7 +1027,9 @@ def while_statement(
10231027 with acc .block ():
10241028 acc .statement ("_bolt_runtime.commands.extend(_bolt_condition_commands)" )
10251029
1026- acc .statement ("if not _bolt_loop_overridden" , lineno = node .arguments [0 ])
1030+ acc .statement (
1031+ "if" , "not _bolt_loop_overridden" , lineno = node .arguments [0 ]
1032+ )
10271033 with acc .block ():
10281034 acc .statement (f"{ condition } = bool({ condition } )" )
10291035
0 commit comments