Skip to content

Commit 81f95f8

Browse files
adenzler-nvidia authored and shi-eric committed
Speed up Adjoint.replace_static_expressions [GH-1486]
* Speed up replace_static_expressions

The upstream replace_static_expressions traverses the kernel AST via ast.NodeTransformer, whose generic_visit rebuilds every field of every visited node — most of that work is wasted on kernels where almost no Call is wp.static.

Replace the NodeTransformer with a hand-rolled DFS walker that visits the same nodes upstream did and calls resolve_static_expression on the same Calls, but collects replacements during the walk and applies them afterward, so the surrounding tree is never rebuilt.

Provably equivalent to upstream: every Call still goes through resolve_static_expression, every successfully-resolved static still gets the same replacement node, and loop-variable tracking around `for` bodies matches the original.

mjwarp factory bench (7 runs of 20 iterations, median ms):
- upstream/main: 310.6
- this MR: 283.1
- delta: -27.5 ms (-9%)

Existing test_codegen and test_static suites pass unchanged (76 tests). One new test pins the walker rewrite:
- test_replace_static_expressions_skips_node_transformer: mocks ast.NodeTransformer.visit to raise, confirms the walker no longer routes through NodeTransformer.

Signed-off-by: Alain Denzler <adenzler@nvidia.com>
Approved-by: Nicolas Capens <ncapens@nvidia.com>
Approved-by: Eric Shi <ershi@nvidia.com>

See merge request omniverse/warp!2439
1 parent 71a66cc commit 81f95f8

3 files changed

Lines changed: 164 additions & 63 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@
139139
([GH-1431](https://github.com/NVIDIA/warp/issues/1431)).
140140
- Improve CPU and CUDA module load failure messages with the active device,
141141
block dimension, and module hash.
142+
- Speed up Warp kernel creation by avoiding redundant work in static-expression
143+
rewriting; visible on kernels with many non-`wp.static` calls
144+
([GH-1486](https://github.com/NVIDIA/warp/issues/1486)).
142145

143146
### Fixed
144147

warp/_src/codegen.py

Lines changed: 106 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4327,71 +4327,114 @@ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
43274327
# try to replace wp.static() expressions by their evaluated value if the
43284328
# expression can be evaluated
43294329
def replace_static_expressions(adj):
4330-
class StaticExpressionReplacer(ast.NodeTransformer):
4331-
def __init__(self):
4332-
# Track loop variable names from enclosing for loops. This prevents
4333-
# wp.static() from capturing a global variable that shadows a loop variable.
4334-
# Uses a counter (not a set) to handle nested loops that reuse the same variable name.
4335-
self.loop_vars = {}
4336-
4337-
def visit_For(self, node):
4338-
# Track loop variable while visiting loop body (simple names only;
4339-
# tuple unpacking like `for x, y in ...` is rare in Warp kernels)
4340-
var_name = node.target.id if isinstance(node.target, ast.Name) else None
4341-
if var_name:
4342-
self.loop_vars[var_name] = self.loop_vars.get(var_name, 0) + 1
4343-
result = self.generic_visit(node)
4344-
if var_name:
4345-
self.loop_vars[var_name] -= 1
4346-
if self.loop_vars[var_name] == 0:
4347-
del self.loop_vars[var_name]
4348-
return result
4349-
4350-
def visit_Call(self, node):
4351-
func, _ = adj.resolve_static_expression(node.func, eval_types=False)
4352-
if adj.is_static_expression(func):
4353-
# If the static expression references an enclosing loop variable,
4354-
# defer evaluation to codegen time when the loop constant is available
4355-
expr_node = node.args[0] if node.args else (node.keywords[0].value if node.keywords else None)
4356-
if expr_node:
4357-
referenced = {n.id for n in ast.walk(expr_node) if isinstance(n, ast.Name)}
4358-
if referenced & self.loop_vars.keys():
4359-
adj.has_unresolved_static_expressions = True
4360-
return self.generic_visit(node)
4361-
4362-
try:
4363-
# the static expression will execute as long as the static expression is valid and
4364-
# only depends on global or captured variables
4365-
obj, code = adj.evaluate_static_expression(node)
4366-
if code is not None:
4367-
adj.resolved_static_expressions[code] = obj
4368-
if isinstance(obj, warp._src.context.Function):
4369-
name_node = ast.Name("__warp_func__")
4370-
# we add a pointer to the Warp function here so that we can refer to it later at
4371-
# codegen time (note that the function key itself is not sufficient to uniquely
4372-
# identify the function, as the function may be redefined between the current time
4373-
# of wp.static() declaration and the time of codegen during module building)
4374-
name_node.warp_func = obj
4375-
return ast.copy_location(name_node, node)
4376-
else:
4377-
return ast.copy_location(ast.Constant(value=obj), node)
4378-
except Exception:
4379-
# Ignoring failing static expressions should generally not be an issue because only
4380-
# one of these cases should be possible:
4381-
# 1) the static expression itself is invalid code, in which case the module cannot be
4382-
# built all,
4383-
# 2) the static expression contains a reference to a local (even if constant) variable
4384-
# (and is therefore not executable and raises this exception), in which
4385-
# case changing the constant, or the code affecting this constant, would lead to
4386-
# a different module hash anyway.
4387-
# In any case, we mark this Adjoint to have unresolvable static expressions.
4388-
# This will trigger a code generation step even if the module hash is unchanged.
4330+
# ``visit_For`` and ``visit_Call`` below are the upstream
4331+
# ``ast.NodeTransformer`` subclass's methods lifted into closures —
4332+
# bodies are unchanged except for the trailing ``self.generic_visit(node)``,
4333+
# which becomes ``_walk_children(node)`` in ``visit_For`` and ``None`` in
4334+
# ``visit_Call`` (where ``None`` means "no replacement, recurse normally").
4335+
# ``_walk_children`` replaces ``generic_visit``: same DFS over
4336+
# ``node._fields``, but dispatching Calls/Fors inline by class identity
4337+
# (no ``'visit_' + cls.__name__`` + ``getattr``) and mutating list
4338+
# fields in place only when a replacement actually occurred.
4339+
# Replacements are collected as ``(container, key, new_node)`` and
4340+
# applied after the walk so the walk sees an unmutated tree.
4341+
loop_vars = {} # was: self.loop_vars
4342+
replacements = [] # (container, key, new_node); applied after the walk
4343+
4344+
def _walk_children(node):
4345+
for field_name in node._fields:
4346+
value = getattr(node, field_name, None)
4347+
if value is None:
4348+
continue
4349+
if type(value) is list:
4350+
for i, child in enumerate(value):
4351+
if not isinstance(child, ast.AST):
4352+
continue
4353+
cls = type(child)
4354+
if cls is ast.Call:
4355+
result = visit_Call(child)
4356+
if result is not None:
4357+
replacements.append((value, i, result))
4358+
continue
4359+
elif cls is ast.For:
4360+
visit_For(child)
4361+
continue
4362+
_walk_children(child)
4363+
elif isinstance(value, ast.AST):
4364+
cls = type(value)
4365+
if cls is ast.Call:
4366+
result = visit_Call(value)
4367+
if result is not None:
4368+
replacements.append((node, field_name, result))
4369+
continue
4370+
elif cls is ast.For:
4371+
visit_For(value)
4372+
continue
4373+
_walk_children(value)
4374+
4375+
def visit_For(node):
4376+
# Track loop variable while visiting loop body (simple names only;
4377+
# tuple unpacking like `for x, y in ...` is rare in Warp kernels)
4378+
var_name = node.target.id if isinstance(node.target, ast.Name) else None
4379+
if var_name:
4380+
loop_vars[var_name] = loop_vars.get(var_name, 0) + 1
4381+
_walk_children(node) # was: self.generic_visit(node)
4382+
if var_name:
4383+
loop_vars[var_name] -= 1
4384+
if loop_vars[var_name] == 0:
4385+
del loop_vars[var_name]
4386+
4387+
def visit_Call(node):
4388+
func, _ = adj.resolve_static_expression(node.func, eval_types=False)
4389+
if adj.is_static_expression(func):
4390+
# If the static expression references an enclosing loop variable,
4391+
# defer evaluation to codegen time when the loop constant is available
4392+
expr_node = node.args[0] if node.args else (node.keywords[0].value if node.keywords else None)
4393+
if expr_node:
4394+
referenced = {n.id for n in ast.walk(expr_node) if isinstance(n, ast.Name)}
4395+
if referenced & loop_vars.keys():
43894396
adj.has_unresolved_static_expressions = True
4390-
pass
4391-
4392-
return self.generic_visit(node)
4397+
return None # was: return self.generic_visit(node)
43934398

4394-
adj.tree = StaticExpressionReplacer().visit(adj.tree)
4399+
try:
4400+
# the static expression will execute as long as the static expression is valid and
4401+
# only depends on global or captured variables
4402+
obj, code = adj.evaluate_static_expression(node)
4403+
if code is not None:
4404+
adj.resolved_static_expressions[code] = obj
4405+
if isinstance(obj, warp._src.context.Function):
4406+
name_node = ast.Name("__warp_func__")
4407+
# we add a pointer to the Warp function here so that we can refer to it later at
4408+
# codegen time (note that the function key itself is not sufficient to uniquely
4409+
# identify the function, as the function may be redefined between the current time
4410+
# of wp.static() declaration and the time of codegen during module building)
4411+
name_node.warp_func = obj
4412+
return ast.copy_location(name_node, node)
4413+
else:
4414+
return ast.copy_location(ast.Constant(value=obj), node)
4415+
except Exception:
4416+
# Ignoring failing static expressions should generally not be an issue because only
4417+
# one of these cases should be possible:
4418+
# 1) the static expression itself is invalid code, in which case the module cannot be
4419+
# built at all,
4420+
# 2) the static expression contains a reference to a local (even if constant) variable
4421+
# (and is therefore not executable and raises this exception), in which
4422+
# case changing the constant, or the code affecting this constant, would lead to
4423+
# a different module hash anyway.
4424+
# In any case, we mark this Adjoint to have unresolvable static expressions.
4425+
# This will trigger a code generation step even if the module hash is unchanged.
4426+
adj.has_unresolved_static_expressions = True
4427+
4428+
return None # was: return self.generic_visit(node)
4429+
4430+
# Walk the tree, then apply replacements in one pass. ``adj.tree`` is
4431+
# always a Module, so we go straight into ``_walk_children``.
4432+
_walk_children(adj.tree)
4433+
for container, key, new_node in replacements:
4434+
if isinstance(container, list):
4435+
container[key] = new_node
4436+
else:
4437+
setattr(container, key, new_node)
43954438

43964439
# Evaluates a static expression that does not depend on runtime values
43974440
# if eval_types is True, try resolving the path using evaluated type information as well

warp/tests/test_codegen.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,61 @@ def test_extract_lambda_source_parenthesized_multiline_body(self):
14261426
self.assertIn("q[0] == 0.0", body)
14271427
ast.parse(f"def generated(q, qd):\n return {body}\n")
14281428

1429+
def test_replace_static_expressions_replaces_call_in_ast(self):
1430+
"""The walker actually mutates ``adj.tree``: every resolvable ``wp.static``
1431+
Call gets replaced with an ``ast.Constant`` (or ``ast.Name`` for a
1432+
Function result). This pins the deferred-replacement application step,
1433+
which is the only behavioural difference vs upstream's in-flight
1434+
replacement.
1435+
"""
1436+
from warp._src import codegen # noqa: PLC0415
1437+
1438+
_value_a = 7
1439+
_value_b = 13
1440+
1441+
def _kernel_with_two_statics(out: wp.array(dtype=int)):
1442+
i = wp.tid()
1443+
out[i] = wp.static(_value_a)
1444+
out[i] += wp.static(_value_b)
1445+
1446+
adj = codegen.Adjoint(_kernel_with_two_statics)
1447+
1448+
# No wp.static Calls should remain in the tree after replacement.
1449+
remaining_static_calls = [
1450+
node
1451+
for node in ast.walk(adj.tree)
1452+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == "static"
1453+
]
1454+
self.assertEqual(remaining_static_calls, [])
1455+
1456+
# Both constants should appear as ast.Constant nodes in the tree.
1457+
constants = {node.value for node in ast.walk(adj.tree) if isinstance(node, ast.Constant)}
1458+
self.assertIn(_value_a, constants)
1459+
self.assertIn(_value_b, constants)
1460+
1461+
def test_replace_static_expressions_defers_loop_var_reference(self):
1462+
"""A ``wp.static`` call inside a ``for`` body that references the loop
1463+
variable must be deferred — ``has_unresolved_static_expressions`` set,
1464+
Call left in the AST for codegen-time resolution. This pins the
1465+
loop-variable tracking in ``visit_For`` / ``visit_Call``.
1466+
"""
1467+
from warp._src import codegen # noqa: PLC0415
1468+
1469+
def _kernel_with_loop_var_static(out: wp.array(dtype=int)):
1470+
for i in range(10):
1471+
out[i] = wp.static(i + 1)
1472+
1473+
adj = codegen.Adjoint(_kernel_with_loop_var_static)
1474+
1475+
self.assertTrue(adj.has_unresolved_static_expressions)
1476+
# wp.static(i + 1) should still be a Call in the AST (not eagerly replaced).
1477+
remaining_static_calls = [
1478+
node
1479+
for node in ast.walk(adj.tree)
1480+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == "static"
1481+
]
1482+
self.assertEqual(len(remaining_static_calls), 1)
1483+
14291484

14301485
devices = get_test_devices()
14311486

0 commit comments

Comments
 (0)