[compile] Invoke split FX graph by codegen. #38657
zhxchen17 wants to merge 1 commit into vllm-project:main
Conversation
Summary:
This PR reduces inference loop runtime overhead by codegen-ing slightly faster Python
code instead of invoking the FX graph directly after compilation.
Context:
Today, VllmBackend returns a callable in the form of an FX GraphModule containing multiple submodules, whose forward code looks like this:
```
def forward(self, ...):
    self.submod_0(...)
    self.submod_1(...)
    ...
```
FX graph execution has some overhead due to:
1. getattr() calls to fetch submodules.
2. Each submodule call pushes multiple levels of CPython stack frames before reaching the real kernels (a rough micro-benchmark sketch follows below).
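To make point 2 concrete, here is a standalone micro-benchmark sketch, not taken from this PR, that compares calling a module through nn.Module.__call__ against calling its forward directly; the Noop module and iteration count are arbitrary choices for illustration:
```
import time

import torch


class Noop(torch.nn.Module):
    def forward(self, x):
        return x


m = Noop()
x = torch.empty(1)

# Module.__call__ goes through _call_impl (extra CPython frames, hook checks);
# m.forward skips straight to the user code.
for label, call in (("Module.__call__", m), ("forward directly", m.forward)):
    t0 = time.perf_counter()
    for _ in range(100_000):
        call(x)
    print(f"{label}: {time.perf_counter() - t0:.3f}s")
```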
We address this by introducing a new codegen layer that runs after all compiler passes and right before the inference runtime. This layer gives us full control over how the graph is executed.
Sample generated code:
```
submod_0 = _submods[0](l_input_ids_, s72, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_)
getitem = submod_0[0]
getitem_1 = submod_0[1]
getitem_2 = submod_0[2]
getitem_3 = submod_0[3]
getitem_4 = submod_0[4]
submod_1 = _submods[1](getitem, s72, getitem_1, getitem_2, getitem_3)
submod_2 = _submods[2](getitem_3, s72, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_4, l_self_modules_layers_modules_0_modules_mlp_modules_gate_up_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_)
getitem_5 = submod_2[0]
getitem_6 = submod_2[1]
getitem_7 = submod_2[2]
getitem_8 = submod_2[3]
getitem_9 = submod_2[4]
submod_3 = _submods[3](getitem_5, s72, getitem_6, getitem_7, getitem_8)
```
This PR reduces runtime overhead regardless of whether VLLM_USE_AOT_COMPILE, VLLM_DISABLE_COMPILE_CACHE, or VLLM_USE_MEGA_AOT_ARTIFACT is enabled or disabled; the codegen path is always used.
In terms of caching, this PR stores two extra pieces of data on disk:
1. Python execution code.
2. FQN of each submodule.
When VLLM_USE_AOT_COMPILE=1, these will be loaded and optionally used depending on whether
VLLM_USE_MEGA_ARTIFACT is enabled.
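For illustration only, the cached payload could be pictured as the following minimal sketch; CodegenCacheEntry and its field names are invented here and are not the PR's actual on-disk schema:
```
from dataclasses import dataclass


# Hypothetical shape of the two extra cached artifacts (names assumed):
@dataclass
class CodegenCacheEntry:
    execution_code: str     # generated Python source that stitches the submodules
    submod_fqns: list[str]  # fully qualified name of each submodule, in call order
```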
Based on the current change, it's possible to further reduce warm start time by skipping graph module serialization. To keep this review manageable, however, that will be done in a separate PR; this PR already reduces runtime overhead in a self-contained way.
Benchmark script: https://github.com/zhxchen17/scripts/blob/main/vllm/overhead_bench.py
Test Result:
Benchmarked on an AMD EPYC 9654 96-Core Processor, using llama3-8b, with a config similar to pytorch/pytorch#177655.
Before the change, there was a ~20us gap between each submodule call; after the change, the gap is reduced to ~3us.
Reviewers:
Subscribers:
Tasks:
Tags:
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
Code Review
This pull request introduces a code generation mechanism to replace the interpreter-based execution of torch.fx.GraphModule with a plain Python function, aiming to reduce overhead from nn.Module.__call__ and attribute dispatch. The changes include a new codegen.py module and updates to the backend and caching logic to store and reconstruct the generated execution code. The review feedback highlights several critical robustness issues in the code generator, such as the lack of support for nested containers (lists, tuples, dicts), missing keyword argument handling in module calls, potential name collisions with internal variables, and syntax errors when handling empty parameter lists or unquoted indices.
```
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    return repr(arg)


def _format_output(args: Any) -> str:
    """Format the output node's args as a return expression."""
    if isinstance(args, torch.fx.Node):
        return args.name
    if isinstance(args, (tuple, list)):
        items = ", ".join(_node_ref(a) for a in args)
        if isinstance(args, tuple):
            return f"({items},)" if len(args) == 1 else f"({items})"
        return f"[{items}]"
    return repr(args)
```
The _node_ref and _format_output functions do not handle nested containers (lists, tuples, or dicts) that contain torch.fx.Node objects. Since _node_ref falls back to repr(arg) for non-Node types, it will produce strings like "[<Node ...>]" for a list of nodes, which is invalid Python code and will cause a SyntaxError or NameError when the generated code is executed. Additionally, _format_output lacks support for dictionary return values, which are valid in FX graphs.
Suggested change:
```
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference recursively."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    if isinstance(arg, list):
        return f"[{', '.join(_node_ref(x) for x in arg)}]"
    if isinstance(arg, tuple):
        items = ", ".join(_node_ref(x) for x in arg)
        return f"({items},)" if len(arg) == 1 else f"({items})"
    if isinstance(arg, dict):
        return "{" + ", ".join(f"{_node_ref(k)}: {_node_ref(v)}" for k, v in arg.items()) + "}"
    return repr(arg)


def _format_output(args: Any) -> str:
    """Format the output node's args as a return expression."""
    return _node_ref(args)
```
```
    elif node.op == "call_module":
        target = node.target
        if target not in submod_index:
            submod_index[target] = len(submod_names)
            submod_names.append(target)
        idx = submod_index[target]
        args_str = ", ".join(_node_ref(a) for a in node.args)
        lines.append(f" {node.name} = _submods[{idx}]({args_str})")
    elif node.op == "call_function" and node.target is operator.getitem:
        source = _node_ref(node.args[0])
        index = node.args[1]
        lines.append(f" {node.name} = {source}[{index}]")
    elif node.op == "output":
        ret = _format_output(node.args[0])
        lines.append(f" return {ret}")

params = ", ".join(param_names)
header = f"def execution_fn({params}, *, _submods):"
return "\n".join([header] + lines) + "\n", submod_names
```
The code generation logic has several robustness issues:
- Missing kwargs: `call_module` nodes may have keyword arguments, which are currently ignored.
- Unquoted getitem index: the index in `operator.getitem` is inserted directly into the f-string. If the index is a string (e.g., for dictionary unpacking), it will result in a `NameError` at runtime because it won't be quoted. Using `_node_ref(index)` (which uses `repr()`) fixes this.
- Name collisions: the parameter name `_submods` could collide with placeholder names in the graph. Using a more unique name like `__vllm_submods__` is safer.
- Empty params syntax: if the graph has no placeholders, the generated header `def execution_fn(, *, _submods):` would be a syntax error.
- Missing imports: the generated code may reference `torch` objects (like dtypes or devices) via `repr()`, but `torch` is not imported in the generated function's scope.
Suggested change:
```
    elif node.op == "call_module":
        target = node.target
        if target not in submod_index:
            submod_index[target] = len(submod_names)
            submod_names.append(target)
        idx = submod_index[target]
        args_str = ", ".join(_node_ref(a) for a in node.args)
        kwargs_str = ", ".join(f"{k}={_node_ref(v)}" for k, v in node.kwargs.items())
        all_args = ", ".join(filter(None, [args_str, kwargs_str]))
        lines.append(f" {node.name} = __vllm_submods__[{idx}]({all_args})")
    elif node.op == "call_function" and node.target is operator.getitem:
        source = _node_ref(node.args[0])
        index = _node_ref(node.args[1])
        lines.append(f" {node.name} = {source}[{index}]")
    elif node.op == "output":
        ret = _format_output(node.args[0])
        lines.append(f" return {ret}")

params_str = f"{', '.join(param_names)}, " if param_names else ""
header = f"def execution_fn({params_str}*, __vllm_submods__):"
return "import torch\n" + "\n".join([header] + lines) + "\n", submod_names
```
A separate snippet binds the generated function to the flattened submodule list:
```
submods_list = [
    c.forward if isinstance(c, torch.fx.GraphModule) else c
    for c in (submod_callables[name] for name in submod_names)
]
return partial(fn, _submods=submods_list)
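```
Putting the snippets together, here is a minimal sketch of how the generated source might be materialized into the final callable; the variable names `code` and `namespace` are assumptions, and `_submods` matches the original (pre-rename) keyword:
```
from functools import partial

namespace: dict = {}
# `code` is the generated source from above ("def execution_fn(...)...").
exec(compile(code, "<vllm codegen>", "exec"), namespace)
fn = namespace["execution_fn"]
# Bind the flat submodule list once, as in the snippet above.
runner = partial(fn, _submods=submods_list)
```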
ProExpertProg left a comment:
Looks good overall although Gemini comments seem real
```
    Returns:
        A callable that executes the stitching logic.
    """
    ast.parse(code)  # validate syntax
```
Is this saved for inspection anywhere (cache/tlparse)?
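One hypothetical way to surface it for debugging, not necessarily what this PR does; the environment variable and file name below are made up for illustration:
```
import os

if (debug_dir := os.environ.get("VLLM_CODEGEN_DEBUG_DIR")) is not None:
    os.makedirs(debug_dir, exist_ok=True)
    # Persist the ast.parse-validated source so it can be inspected offline.
    with open(os.path.join(debug_dir, "execution_fn.py"), "w") as f:
        f.write(code)
```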