Skip to content

Commit 3ae46b4

Browse files
committed
[compile] Invoke split FX graph by codegen.
Summary: This PR reduces inference loop runtime overhead by codegen-ing slightly faster Python code instead of invoking the FX graph directly after compilation. Context: Today VllmBackend returns a callable as a FX GraphModule with multiple submodules with the following code: ``` def forward(self, ...): self.submod_0(...) self.submod_1(...) ... ``` FX graph execution has some overhead due to: 1. getattr() calls to fetch submodules. 2. submodule calls will push multiple levels of CPython stack frame before getting to the real kernels. We address this by introducing a new codegen layer after all compiler passes and right before inference runtime. In this codegen layer we get full customizability over how the graph is executed. Sample generated code: ``` submod_0 = _submods[0](l_input_ids_, s72, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_) getitem = submod_0[0] getitem_1 = submod_0[1] getitem_2 = submod_0[2] getitem_3 = submod_0[3] getitem_4 = submod_0[4] submod_1 = _submods[1](getitem, s72, getitem_1, getitem_2, getitem_3) submod_2 = _submods[2](getitem_3, s72, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_4, l_self_modules_layers_modules_0_modules_mlp_modules_gate_up_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_) getitem_5 = submod_2[0] 
getitem_6 = submod_2[1] getitem_7 = submod_2[2] getitem_8 = submod_2[3] getitem_9 = submod_2[4] submod_3 = _submods[3](getitem_5, s72, getitem_6, getitem_7, getitem_8) ``` This PR will reduce runtime overhead no matter whether VLLM_USE_AOT_COMPILE, VLLM_DISABLE_COMPILE_CACHE or VLLM_USE_MEGA_AOT_ARTIFACT is enabled or disabled. It will always be used in all paths. In terms of caching, this PR will store 2 extra pieces of data on disk: 1. Python execution code. 2. FQN of each submodule. When VLLM_USE_AOT_COMPILE=1, these will be loaded and optionally used depending on whether VLLM_USE_MEGA_AOT_ARTIFACT is enabled. Based on the current change, it's possible to further reduce warm start time by skipping graph module serialization. However, to make the code review easier, we will do it in a separate PR and this PR still helps with the runtime overhead in a self-contained way. Benchmark script: https://github.com/zhxchen17/scripts/blob/main/vllm/overhead_bench.py Test Plan: <TODO Images> Reviewers: Subscribers: Tasks: Tags: Signed-off-by: zhxchen17 <zhxchen17@fb.com>
1 parent 6183cae commit 3ae46b4

File tree

3 files changed

+182
-9
lines changed

3 files changed

+182
-9
lines changed

vllm/compilation/backends.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,23 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
12341234
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
12351235
)
12361236

1237+
from vllm.compilation.codegen import (
1238+
compile_execution_fn,
1239+
generate_execution_code,
1240+
)
1241+
1242+
execution_code, submod_names = generate_execution_code(self.split_gm)
1243+
# Use getattr to get correct callables: __dict__ has PiecewiseBackend
1244+
# instances (from PiecewiseCompileInterpreter), _modules has originals.
1245+
# getattr checks __dict__ first, then falls back to _modules.
1246+
submod_callables = {
1247+
name: getattr(self.split_gm, name)
1248+
for name, _ in self.split_gm.named_children()
1249+
}
1250+
runtime_callable = compile_execution_fn(
1251+
execution_code, submod_callables, submod_names
1252+
)
1253+
12371254
if (
12381255
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
12391256
or not self.compilation_config.cudagraph_copy_inputs
@@ -1242,9 +1259,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
12421259
graph_to_serialize,
12431260
example_inputs,
12441261
self.prefix,
1245-
self.split_gm,
1262+
runtime_callable,
12461263
is_encoder=self.is_encoder,
12471264
vllm_backend=self,
1265+
execution_code=execution_code,
1266+
submod_names=submod_names,
12481267
)
12491268

12501269
# index of tensors that have symbolic shapes (batch size)
@@ -1265,7 +1284,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
12651284
copy_and_call = make_copy_and_call(
12661285
sym_tensor_indices,
12671286
[example_inputs[x].clone() for x in sym_tensor_indices],
1268-
self.split_gm,
1287+
runtime_callable,
12691288
)
12701289

12711290
return VllmSerializableFunction(
@@ -1276,4 +1295,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
12761295
is_encoder=self.is_encoder,
12771296
vllm_backend=self,
12781297
sym_tensor_indices=sym_tensor_indices,
1298+
execution_code=execution_code,
1299+
submod_names=submod_names,
12791300
)

vllm/compilation/caching.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def __init__(
184184
vllm_backend: Any | None = None,
185185
sym_tensor_indices: list[int] | None = None,
186186
aot_autograd_config: dict[str, Any] | None = None,
187+
execution_code: str | None = None,
188+
submod_names: list[str] | None = None,
187189
) -> None:
188190
assert isinstance(graph_module, torch.fx.GraphModule)
189191
self.graph_module = graph_module
@@ -194,6 +196,8 @@ def __init__(
194196
self.shape_env = None
195197
self.vllm_backend = vllm_backend
196198
self.sym_tensor_indices = sym_tensor_indices
199+
self.execution_code = execution_code
200+
self.submod_names = submod_names
197201
self._fake_mode: Any | None = None
198202

199203
import torch._functorch.config as functorch_config
@@ -453,7 +457,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
453457

454458
standalone_compile_artifacts.load_all()
455459

456-
submod_names = standalone_compile_artifacts.submodule_names()
460+
piecewise_submod_names = standalone_compile_artifacts.submodule_names()
457461
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
458462

459463
for cache_key in standalone_compile_artifacts.submodule_bytes:
@@ -473,13 +477,13 @@ def reconstruct_serializable_fn_from_mega_artifact(
473477

474478
# spot check that cached submodules exist in the graph structure
475479
graph_children = {name for name, _ in split_gm.named_children()}
476-
missing = set(submod_names) - graph_children
480+
missing = set(piecewise_submod_names) - graph_children
477481
assert not missing, (
478482
f"artifacts reference submodules not in graph: {missing}. "
479483
f"graph has: {sorted(graph_children)}"
480484
)
481485

482-
for i, submod_name in enumerate(submod_names):
486+
for i, submod_name in enumerate(piecewise_submod_names):
483487
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
484488

485489
sym_shape_indices = sym_shape_indices_map[submod_name]
@@ -490,15 +494,15 @@ def reconstruct_serializable_fn_from_mega_artifact(
490494
graph=None, # not needed for cached artifacts
491495
vllm_config=vllm_config,
492496
piecewise_compile_index=i,
493-
total_piecewise_compiles=len(submod_names),
497+
total_piecewise_compiles=len(piecewise_submod_names),
494498
sym_shape_indices=sym_shape_indices,
495499
vllm_backend=vllm_backend,
496500
returns_tuple=returns_tuple,
497501
compiled_runnables=runnables,
498502
)
499503

500504
is_first = i == 0
501-
is_last = i == len(submod_names) - 1
505+
is_last = i == len(piecewise_submod_names) - 1
502506
wrapped_backend = wrap_with_cudagraph_if_needed(
503507
piecewise_backend,
504508
vllm_config,
@@ -513,6 +517,21 @@ def reconstruct_serializable_fn_from_mega_artifact(
513517
submod_name,
514518
)
515519

520+
# Use codegen'd execution code if available, fall back to split_gm
521+
execution_code = state.get("execution_code")
522+
submod_names = state.get("submod_names")
523+
if execution_code is not None and submod_names is not None:
524+
from vllm.compilation.codegen import compile_execution_fn
525+
526+
submod_callables = {
527+
name: getattr(split_gm, name) for name, _ in split_gm.named_children()
528+
}
529+
runtime_callable = compile_execution_fn(
530+
execution_code, submod_callables, submod_names
531+
)
532+
else:
533+
runtime_callable = split_gm
534+
516535
if compilation_config.cudagraph_copy_inputs:
517536
sym_tensor_indices = state["sym_tensor_indices"]
518537
input_buffers = [
@@ -521,9 +540,11 @@ def reconstruct_serializable_fn_from_mega_artifact(
521540
)
522541
for idx in sym_tensor_indices
523542
]
524-
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
543+
optimized_call = make_copy_and_call(
544+
sym_tensor_indices, input_buffers, runtime_callable
545+
)
525546
else:
526-
optimized_call = split_gm
547+
optimized_call = runtime_callable
527548

528549
fn = VllmSerializableFunction(
529550
**state,

vllm/compilation/codegen.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Code generation for split_gm stitching graph execution.
4+
5+
Generates a plain Python function that replaces the FX GraphModule's
6+
interpreter-based execution of the stitching graph, eliminating
7+
nn.Module.__call__ overhead and __getattr__ dispatch.
8+
"""
9+
10+
import operator
11+
from collections.abc import Callable
12+
from functools import partial
13+
from typing import Any
14+
15+
import torch.fx
16+
from torch._dynamo.utils import dynamo_timed
17+
from torch._logging import trace_structured
18+
19+
20+
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
    split_gm: torch.fx.GraphModule,
) -> tuple[str, list[str]]:
    """Generate Python source code from a split_gm's stitching graph.

    Walks ``split_gm.graph.nodes`` and emits a flat function that invokes
    submodules through a ``__vllm_submods__`` list, sidestepping FX
    GraphModule interpretation and per-call attribute/dict lookups.

    Args:
        split_gm: The split graph module produced by split_graph().

    Returns:
        A tuple of (code, submod_names) where code is the Python source
        and submod_names is the ordered list of submodule target names
        corresponding to list indices used in the generated code.
    """
    stmts: list[str] = []
    placeholders: list[str] = []
    submod_names: list[str] = []
    submod_index: dict[str, int] = {}

    for node in split_gm.graph.nodes:
        op = node.op

        if op == "placeholder":
            # Placeholders become the generated function's parameters.
            placeholders.append(node.name)
            continue

        if op == "call_module":
            slot = submod_index.get(node.target)
            if slot is None:
                # First reference to this submodule: assign the next index.
                slot = len(submod_names)
                submod_index[node.target] = slot
                submod_names.append(node.target)
            pieces = [_node_ref(a) for a in node.args]
            pieces.extend(f"{k}={_node_ref(v)}" for k, v in node.kwargs.items())
            call_args = ", ".join(pieces)
            stmts.append(f"    {node.name} = __vllm_submods__[{slot}]({call_args})")
            continue

        if op == "call_function" and node.target is operator.getitem:
            # Tuple unpacking of a submodule's multi-output result.
            seq = _node_ref(node.args[0])
            pos = node.args[1]
            assert isinstance(pos, int)
            stmts.append(f"    {node.name} = {seq}[{pos}]")
            continue

        if op == "output":
            assert len(node.args) == 1
            stmts.append(f"    return {_node_ref(node.args[0])}")
            continue

        # The stitching graph should only contain the node kinds above.
        raise RuntimeError(f"Unsupported node from codegen: {node.format_node()}")

    assert placeholders
    header = f"def execution_fn({', '.join(placeholders)}, *, __vllm_submods__):"
    source = "\n".join([header, *stmts])
    # `import torch` keeps repr'd torch constants (e.g. dtypes) resolvable.
    return f"import torch\n{source}\n", submod_names
78+
79+
80+
@dynamo_timed("vllm.compile_execution_fn")
81+
def compile_execution_fn(
82+
code: str,
83+
submod_callables: dict[str, Callable[..., Any]],
84+
submod_names: list[str],
85+
) -> Callable[..., Any]:
86+
"""Compile execution code and bind submodule callables.
87+
88+
Args:
89+
code: Python source from generate_execution_code().
90+
submod_callables: Mapping of submodule names to their callables.
91+
submod_names: Ordered list of submodule names matching the indices
92+
used in the generated code.
93+
94+
Returns:
95+
A callable that executes the stitching logic.
96+
"""
97+
trace_structured(
98+
"artifact",
99+
metadata_fn=lambda: {
100+
"name": "vllm_execution_code",
101+
"encoding": "string",
102+
},
103+
payload_fn=lambda: code,
104+
)
105+
namespace: dict[str, Any] = {}
106+
exec(code, namespace) # noqa: S102
107+
fn = namespace["execution_fn"]
108+
# Use .forward() directly to avoid nn.Module.__call__ overhead.
109+
submods_list = [
110+
c.forward if isinstance(c, torch.fx.GraphModule) else c
111+
for c in (submod_callables[name] for name in submod_names)
112+
]
113+
return partial(fn, __vllm_submods__=submods_list)
114+
115+
116+
def _node_ref(arg: Any) -> str:
117+
"""Convert an FX node argument to a source code reference recursively."""
118+
if isinstance(arg, torch.fx.Node):
119+
return arg.name
120+
if isinstance(arg, list):
121+
return f"[{', '.join(_node_ref(x) for x in arg)}]"
122+
if isinstance(arg, tuple):
123+
items = ", ".join(_node_ref(x) for x in arg)
124+
return f"({items},)" if len(arg) == 1 else f"({items})"
125+
if isinstance(arg, dict):
126+
return (
127+
"{"
128+
+ ", ".join(f"{_node_ref(k)}: {_node_ref(v)}" for k, v in arg.items())
129+
+ "}"
130+
)
131+
return repr(arg)

0 commit comments

Comments
 (0)