Skip to content

Commit d5e968e

Browse files
committed
tmp maybe_inplace
1 parent 7cf5683 commit d5e968e

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

vllm/compilation/passes/ir/inplace_raising.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class VllmIRInplaceRaisingPass(VllmInductorPass):
2828
The maybe_inplace overloads have the same signature as the default overload
2929
so the pass simply replaces the called overload.
3030
That makes the graph properly functional.
31+
32+
This pass operates pre-AOTAutograd,
33+
so it must handle non-normalized and non-functional IR.
3134
"""
3235

3336
def __init__(self, vllm_config: VllmConfig) -> None:
@@ -56,12 +59,40 @@ def __call__(self, graph: fx.Graph) -> None:
5659
# must have maybe_inplace overload and allow_inplace
5760
assert ir_op.allow_inplace and ir_op.maybe_inplace is not None
5861

62+
# Check that activation inputs are not used after this op
63+
for arg_idx in ir_op.activation_indices:
64+
arg = node.args[arg_idx]
65+
assert isinstance(arg, fx.Node), "Activation inputs must be fx.Node"
66+
for user in arg.users:
67+
if user is not node:
68+
# TODO only check topologically?
69+
logger.warning(
70+
"Node %s (input to %s) has another use", arg, node
71+
)
72+
# TODO raise error, this is undefined behavior, which should not be allowed.
73+
# Users can just use the default overload if they want to keep activation inputs untouched.
74+
75+
if arg.op == "placeholder":
76+
# This node represents a graph input, and maybe_inplace might modify it,
77+
# meaning the user does not care about it.
78+
# Mark it dirty so downstream passes know it can be modified without affecting correctness.
79+
# TODO should we store this in node.meta instead?
80+
arg.meta["custom"] = {
81+
"is_consumed": True,
82+
**arg.meta.get("custom", {}),
83+
}
84+
logger.debug(
85+
"vLLM IR op %s has an activation input that is a graph input",
86+
ir_op.name,
87+
)
88+
5989
# Same signature, just replace the overload that's called.
6090
node.target = ir_op.torch_op
91+
node.meta["custom"] = {"maybe_inplace": True, **node.meta.get("custom", {})}
6192
self.raised_ops[ir_op.name] += 1
6293

6394
count = sum(self.raised_ops.values())
6495
ops = ",".join(self.raised_ops.keys())
6596
logger.debug(
66-
"VllmIRLoweringPass raised %d vLLM IR nodes for op(s) %s", count, ops
97+
"VllmIRInplaceRaisingPass raised %d vLLM IR nodes for op(s) %s", count, ops
6798
)

vllm/compilation/passes/ir/lowering_pass.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import defaultdict
44
from collections.abc import Iterable
55

6+
import torch
67
from torch import fx
78
from torch._inductor.pattern_matcher import (
89
CallFunctionVarArgs,
@@ -17,6 +18,7 @@
1718
from vllm.logger import init_logger
1819
from vllm.logging_utils import lazy
1920

21+
from ..fx_utils import is_func
2022
from ..vllm_inductor_pass import VllmInductorPass
2123

2224
logger = init_logger(__name__)
@@ -138,3 +140,35 @@ def print_count(counts: dict[str, int]) -> str:
138140
if failed_nodes or failed_ops:
139141
logger.warning("Failed to lower vLLM IR ops: %s", ",".join(failed_ops))
140142
logger.warning("Full node list: %s", failed_nodes)
143+
144+
145+
class CloneCleanupPass(VllmInductorPass):
    """
    Remove ``aten.clone`` nodes that are no longer needed after vLLM IR
    lowering.

    NOTE(review): the actual graph mutation is currently disabled behind
    ``_REMOVAL_ENABLED`` (the original code used a bare ``continue  # TODO``
    that left the removal statements unreachable). Until the gate is flipped,
    this pass only logs removal candidates and annotated nodes; the final
    debug summary will therefore report 0 removals.
    """

    # Explicit gate replacing the stray `continue  # TODO`: flipping this to
    # True enables the removal path below. TODO: enable once it is confirmed
    # that every clone candidate is safe to elide.
    _REMOVAL_ENABLED = False

    def __init__(self, vllm_config: VllmConfig) -> None:
        super().__init__(vllm_config)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        count = 0
        # Iterate over a snapshot: erasing nodes while iterating the live
        # graph.nodes view is unsafe once removal is enabled.
        for node in list(graph.nodes):
            # Debug aid: surface nodes that upstream passes annotated via
            # node.meta["custom"] (e.g. the inplace-raising pass).
            if "custom" in node.meta:
                logger.info(
                    "Node %s with meta['custom']=%s, users: %s",
                    node,
                    node.meta["custom"],
                    list(node.users),
                )

            if not is_func(node, torch.ops.aten.clone.default):
                continue

            logger.info("Node %s is a clone node, removing it", node)
            if not self._REMOVAL_ENABLED:
                continue  # TODO: enable removal (see _REMOVAL_ENABLED)
            # Rewire every user of the clone to the clone's input, then
            # erase the now-unused clone node.
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
            count += 1

        logger.debug("CloneCleanupPass removed %d clone nodes", count)

vllm/compilation/passes/pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from vllm.platforms import current_platform
1515
from vllm.utils.system_utils import set_env_var
1616

17-
from .ir.lowering_pass import VllmIRLoweringPass
17+
from .ir.lowering_pass import CloneCleanupPass, VllmIRLoweringPass
1818
from .vllm_inductor_pass import VllmInductorPass
1919

2020
if rocm_aiter_ops.is_enabled():
@@ -109,6 +109,8 @@ def __call__(self, graph: fx.Graph) -> None:
109109
# DCE handles mutating ops correctly as well.
110110
self.ir_lowering(graph)
111111
VllmInductorPass.dump_prefix += 1
112+
self.clone_cleanup(graph)
113+
VllmInductorPass.dump_prefix += 1
112114

113115
# clean up after lowering again
114116
self.post_cleanup(graph)
@@ -161,6 +163,7 @@ def configure(self, config: VllmConfig) -> None:
161163
self.passes += [QKNormRoPEFusionPass(config)]
162164

163165
self.ir_lowering = VllmIRLoweringPass(config)
166+
self.clone_cleanup = CloneCleanupPass(config)
164167
self.post_cleanup = PostCleanupPass(config)
165168
self.fix_functionalization = FixFunctionalizationPass(config)
166169

@@ -182,6 +185,7 @@ def uuid(self) -> str:
182185

183186
passes.append(self.post_cleanup.uuid())
184187
passes.append(self.ir_lowering.uuid())
188+
passes.append(self.clone_cleanup.uuid())
185189
passes.append(self.post_cleanup.uuid())
186190
passes.append(self.fix_functionalization.uuid())
187191

vllm/compilation/passes/utility/noop_elimination.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ def __call__(self, graph: torch.fx.Graph) -> None:
6969
count = 0
7070
# Remove no-op reshapes/views:
7171
for node in graph.nodes:
72+
if "custom" in node.meta:
73+
logger.info(
74+
"Node %s with meta['custom']=%s, users: %s",
75+
node,
76+
node.meta["custom"],
77+
list(node.users),
78+
)
79+
7280
if is_func(node, torch.ops.aten.reshape.default):
7381
# Case 1: rewrite reshape chains to reshapes on the base tensor
7482
input = node.args[0]

0 commit comments

Comments
 (0)