@@ -4327,71 +4327,114 @@ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
43274327 # try to replace wp.static() expressions by their evaluated value if the
43284328 # expression can be evaluated
43294329 def replace_static_expressions (adj ):
4330- class StaticExpressionReplacer (ast .NodeTransformer ):
4331- def __init__ (self ):
4332- # Track loop variable names from enclosing for loops. This prevents
4333- # wp.static() from capturing a global variable that shadows a loop variable.
4334- # Uses a counter (not a set) to handle nested loops that reuse the same variable name.
4335- self .loop_vars = {}
4336-
4337- def visit_For (self , node ):
4338- # Track loop variable while visiting loop body (simple names only;
4339- # tuple unpacking like `for x, y in ...` is rare in Warp kernels)
4340- var_name = node .target .id if isinstance (node .target , ast .Name ) else None
4341- if var_name :
4342- self .loop_vars [var_name ] = self .loop_vars .get (var_name , 0 ) + 1
4343- result = self .generic_visit (node )
4344- if var_name :
4345- self .loop_vars [var_name ] -= 1
4346- if self .loop_vars [var_name ] == 0 :
4347- del self .loop_vars [var_name ]
4348- return result
4349-
4350- def visit_Call (self , node ):
4351- func , _ = adj .resolve_static_expression (node .func , eval_types = False )
4352- if adj .is_static_expression (func ):
4353- # If the static expression references an enclosing loop variable,
4354- # defer evaluation to codegen time when the loop constant is available
4355- expr_node = node .args [0 ] if node .args else (node .keywords [0 ].value if node .keywords else None )
4356- if expr_node :
4357- referenced = {n .id for n in ast .walk (expr_node ) if isinstance (n , ast .Name )}
4358- if referenced & self .loop_vars .keys ():
4359- adj .has_unresolved_static_expressions = True
4360- return self .generic_visit (node )
4361-
4362- try :
4363- # the static expression will execute as long as the static expression is valid and
4364- # only depends on global or captured variables
4365- obj , code = adj .evaluate_static_expression (node )
4366- if code is not None :
4367- adj .resolved_static_expressions [code ] = obj
4368- if isinstance (obj , warp ._src .context .Function ):
4369- name_node = ast .Name ("__warp_func__" )
4370- # we add a pointer to the Warp function here so that we can refer to it later at
4371- # codegen time (note that the function key itself is not sufficient to uniquely
4372- # identify the function, as the function may be redefined between the current time
4373- # of wp.static() declaration and the time of codegen during module building)
4374- name_node .warp_func = obj
4375- return ast .copy_location (name_node , node )
4376- else :
4377- return ast .copy_location (ast .Constant (value = obj ), node )
4378- except Exception :
4379- # Ignoring failing static expressions should generally not be an issue because only
4380- # one of these cases should be possible:
4381- # 1) the static expression itself is invalid code, in which case the module cannot be
4382- # built all,
4383- # 2) the static expression contains a reference to a local (even if constant) variable
4384- # (and is therefore not executable and raises this exception), in which
4385- # case changing the constant, or the code affecting this constant, would lead to
4386- # a different module hash anyway.
4387- # In any case, we mark this Adjoint to have unresolvable static expressions.
4388- # This will trigger a code generation step even if the module hash is unchanged.
4330+ # ``visit_For`` and ``visit_Call`` below are the upstream
4331+ # ``ast.NodeTransformer`` subclass's methods lifted into closures —
4332+ # bodies are unchanged except for the trailing ``self.generic_visit(node)``,
4333+ # which becomes ``_walk_children(node)`` in ``visit_For`` and ``None`` in
4334+ # ``visit_Call`` (where ``None`` means "no replacement, recurse normally").
4335+ # ``_walk_children`` replaces ``generic_visit``: same DFS over
4336+ # ``node._fields``, but dispatching Calls/Fors inline by class identity
4337+ # (no ``'visit_' + cls.__name__`` + ``getattr``) and mutating list
4338+ # fields in place only when a replacement actually occurred.
4339+ # Replacements are collected as ``(container, key, new_node)`` and
4340+ # applied after the walk so the walk sees an unmutated tree.
4341+ loop_vars = {} # was: self.loop_vars
4342+ replacements = [] # (container, key, new_node); applied after the walk
4343+
4344+ def _walk_children (node ):
4345+ for field_name in node ._fields :
4346+ value = getattr (node , field_name , None )
4347+ if value is None :
4348+ continue
4349+ if type (value ) is list :
4350+ for i , child in enumerate (value ):
4351+ if not isinstance (child , ast .AST ):
4352+ continue
4353+ cls = type (child )
4354+ if cls is ast .Call :
4355+ result = visit_Call (child )
4356+ if result is not None :
4357+ replacements .append ((value , i , result ))
4358+ continue
4359+ elif cls is ast .For :
4360+ visit_For (child )
4361+ continue
4362+ _walk_children (child )
4363+ elif isinstance (value , ast .AST ):
4364+ cls = type (value )
4365+ if cls is ast .Call :
4366+ result = visit_Call (value )
4367+ if result is not None :
4368+ replacements .append ((node , field_name , result ))
4369+ continue
4370+ elif cls is ast .For :
4371+ visit_For (value )
4372+ continue
4373+ _walk_children (value )
4374+
4375+ def visit_For (node ):
4376+ # Track loop variable while visiting loop body (simple names only;
4377+ # tuple unpacking like `for x, y in ...` is rare in Warp kernels)
4378+ var_name = node .target .id if isinstance (node .target , ast .Name ) else None
4379+ if var_name :
4380+ loop_vars [var_name ] = loop_vars .get (var_name , 0 ) + 1
4381+ _walk_children (node ) # was: self.generic_visit(node)
4382+ if var_name :
4383+ loop_vars [var_name ] -= 1
4384+ if loop_vars [var_name ] == 0 :
4385+ del loop_vars [var_name ]
4386+
4387+ def visit_Call (node ):
4388+ func , _ = adj .resolve_static_expression (node .func , eval_types = False )
4389+ if adj .is_static_expression (func ):
4390+ # If the static expression references an enclosing loop variable,
4391+ # defer evaluation to codegen time when the loop constant is available
4392+ expr_node = node .args [0 ] if node .args else (node .keywords [0 ].value if node .keywords else None )
4393+ if expr_node :
4394+ referenced = {n .id for n in ast .walk (expr_node ) if isinstance (n , ast .Name )}
4395+ if referenced & loop_vars .keys ():
43894396 adj .has_unresolved_static_expressions = True
4390- pass
4391-
4392- return self .generic_visit (node )
4397+ return None # was: return self.generic_visit(node)
43934398
4394- adj .tree = StaticExpressionReplacer ().visit (adj .tree )
4399+ try :
4400+ # the static expression will execute as long as the static expression is valid and
4401+ # only depends on global or captured variables
4402+ obj , code = adj .evaluate_static_expression (node )
4403+ if code is not None :
4404+ adj .resolved_static_expressions [code ] = obj
4405+ if isinstance (obj , warp ._src .context .Function ):
4406+ name_node = ast .Name ("__warp_func__" )
4407+ # we add a pointer to the Warp function here so that we can refer to it later at
4408+ # codegen time (note that the function key itself is not sufficient to uniquely
4409+ # identify the function, as the function may be redefined between the current time
4410+ # of wp.static() declaration and the time of codegen during module building)
4411+ name_node .warp_func = obj
4412+ return ast .copy_location (name_node , node )
4413+ else :
4414+ return ast .copy_location (ast .Constant (value = obj ), node )
4415+ except Exception :
4416+ # Ignoring failing static expressions should generally not be an issue because only
4417+ # one of these cases should be possible:
4418+ # 1) the static expression itself is invalid code, in which case the module cannot be
4419+ # built all,
4420+ # 2) the static expression contains a reference to a local (even if constant) variable
4421+ # (and is therefore not executable and raises this exception), in which
4422+ # case changing the constant, or the code affecting this constant, would lead to
4423+ # a different module hash anyway.
4424+ # In any case, we mark this Adjoint to have unresolvable static expressions.
4425+ # This will trigger a code generation step even if the module hash is unchanged.
4426+ adj .has_unresolved_static_expressions = True
4427+
4428+ return None # was: return self.generic_visit(node)
4429+
4430+ # Walk the tree, then apply replacements in one pass. ``adj.tree`` is
4431+ # always a Module, so we go straight into ``_walk_children``.
4432+ _walk_children (adj .tree )
4433+ for container , key , new_node in replacements :
4434+ if isinstance (container , list ):
4435+ container [key ] = new_node
4436+ else :
4437+ setattr (container , key , new_node )
43954438
43964439 # Evaluates a static expression that does not depend on runtime values
43974440 # if eval_types is True, try resolving the path using evaluated type information as well
0 commit comments