[compile] Invoke split FX graph by codegen.#38657

Open
zhxchen17 wants to merge 1 commit intovllm-project:mainfrom
zhxchen17:zhxchen17/execution_code

Conversation

@zhxchen17
Contributor

@zhxchen17 zhxchen17 commented Mar 31, 2026

Summary:

This PR reduces inference-loop runtime overhead by generating slightly faster Python
code and invoking it, instead of calling the FX graph directly, after compilation.

Context:

Today VllmBackend returns a callable in the form of an FX GraphModule containing multiple submodules, whose forward looks like:

```
def forward(self, ...):
    self.submod_0(...)
    self.submod_1(...)
    ...
```

FX graph execution has some overhead due to:

  1. getattr() calls to fetch submodules.
  2. Each submodule call pushes several CPython stack frames (e.g. nn.Module.__call__, then forward) before reaching the real kernels.
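To illustrate the two overhead sources above, here is a torch-free toy sketch (my own illustration, not vLLM code; `Sub.__call__` stands in for `nn.Module.__call__` dispatch). Pre-resolving callables into a list skips both the attribute lookup and one stack frame per call:

```python
# Toy illustration: attribute lookup + __call__ dispatch vs. a
# pre-bound callable list (the shape of the codegen'd fast path).
class Sub:
    def __call__(self, x):        # extra frame, like nn.Module.__call__
        return self.forward(x)

    def forward(self, x):
        return x + 1

class Graph:
    def __init__(self):
        self.submod_0 = Sub()

    def run_via_getattr(self, x):
        # getattr + __call__ + forward: three hops per submodule
        return getattr(self, "submod_0")(x)

g = Graph()
fast = [g.submod_0.forward]       # callables resolved once, up front

def run_direct(x):
    return fast[0](x)             # list index + one call, no getattr

assert g.run_via_getattr(1) == run_direct(1) == 2
```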

We address this by introducing a new codegen layer that runs after all compiler passes and right
before inference runtime. This codegen layer gives us full control over how the graph is executed.

Sample generated code:

```
    submod_0 = _submods[0](l_input_ids_, s72, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_)
    getitem = submod_0[0]
    getitem_1 = submod_0[1]
    getitem_2 = submod_0[2]
    getitem_3 = submod_0[3]
    getitem_4 = submod_0[4]
    submod_1 = _submods[1](getitem, s72, getitem_1, getitem_2, getitem_3)
    submod_2 = _submods[2](getitem_3, s72, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_4, l_self_modules_layers_modules_0_modules_mlp_modules_gate_up_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_weight_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_)
    getitem_5 = submod_2[0]
    getitem_6 = submod_2[1]
    getitem_7 = submod_2[2]
    getitem_8 = submod_2[3]
    getitem_9 = submod_2[4]
    submod_3 = _submods[3](getitem_5, s72, getitem_6, getitem_7, getitem_8)
```

This PR reduces runtime overhead regardless of whether VLLM_USE_AOT_COMPILE, VLLM_DISABLE_COMPILE_CACHE, or VLLM_USE_MEGA_AOT_ARTIFACT is enabled; the codegen path is used unconditionally in all paths.

In terms of caching, this PR stores two extra pieces of data on disk:

  1. Python execution code.
  2. FQN of each submodule.

When VLLM_USE_AOT_COMPILE=1, these will be loaded and optionally used depending on whether
VLLM_USE_MEGA_ARTIFACT is enabled.

Based on this change, it is possible to further reduce warm-start time by skipping
graph module serialization. To keep the code review manageable, we will do that in a separate
PR; this PR is self-contained and already reduces runtime overhead on its own.
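As a minimal sketch of the load path (my reconstruction, not the PR's exact code; `execution_fn` and `_submods` follow the naming visible in the sample above), the cached source string can be validated, exec'd, and rebound to the submodule list:

```python
# Sketch: turn a cached code string back into a callable, binding
# submodules via functools.partial. Assumed names, not vLLM's API.
import ast
from functools import partial

code = (
    "def execution_fn(x, *, _submods):\n"
    "    out = _submods[0](x)\n"
    "    return _submods[1](out)\n"
)

ast.parse(code)  # validate syntax before exec'ing
namespace = {}
exec(compile(code, "<generated>", "exec"), namespace)
fn = namespace["execution_fn"]

# Stand-in submodules; in vLLM these would be looked up by their FQNs.
submods = [lambda x: x + 1, lambda x: x * 2]
runner = partial(fn, _submods=submods)
assert runner(3) == 8  # (3 + 1) * 2
```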

Test Result:
Benchmark script: https://github.com/zhxchen17/scripts/blob/main/vllm/overhead_bench.py

Benchmarked on an AMD EPYC 9654 96-core processor with llama3-8b, using a config similar to pytorch/pytorch#177655.

Before the change, there is a ~20us gap between submodule calls:
Screenshot 2026-03-31 at 1 55 43 PM

After the change, the gap drops to ~3us:
Screenshot 2026-03-31 at 1 27 17 PM


Signed-off-by: zhxchen17 <zhxchen17@fb.com>


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a code generation mechanism to replace the interpreter-based execution of torch.fx.GraphModule with a plain Python function, aiming to reduce overhead from nn.Module.__call__ and attribute dispatch. The changes include a new codegen.py module and updates to the backend and caching logic to store and reconstruct the generated execution code. The review feedback highlights several critical robustness issues in the code generator, such as the lack of support for nested containers (lists, tuples, dicts), missing keyword argument handling in module calls, potential name collisions with internal variables, and syntax errors when handling empty parameter lists or unquoted indices.

Comment on lines +96 to +112

```
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    return repr(arg)


def _format_output(args: Any) -> str:
    """Format the output node's args as a return expression."""
    if isinstance(args, torch.fx.Node):
        return args.name
    if isinstance(args, (tuple, list)):
        items = ", ".join(_node_ref(a) for a in args)
        if isinstance(args, tuple):
            return f"({items},)" if len(args) == 1 else f"({items})"
        return f"[{items}]"
    return repr(args)
```

high

The _node_ref and _format_output functions do not handle nested containers (lists, tuples, or dicts) that contain torch.fx.Node objects. Since _node_ref falls back to repr(arg) for non-Node types, it will produce strings like "[<Node ...>]" for a list of nodes, which is invalid Python code and will cause a SyntaxError or NameError when the generated code is executed. Additionally, _format_output lacks support for dictionary return values, which are valid in FX graphs.

Suggested change

```
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    return repr(arg)


def _format_output(args: Any) -> str:
    """Format the output node's args as a return expression."""
    if isinstance(args, torch.fx.Node):
        return args.name
    if isinstance(args, (tuple, list)):
        items = ", ".join(_node_ref(a) for a in args)
        if isinstance(args, tuple):
            return f"({items},)" if len(args) == 1 else f"({items})"
        return f"[{items}]"
    return repr(args)
```

```
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference recursively."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    if isinstance(arg, list):
        return f"[{', '.join(_node_ref(x) for x in arg)}]"
    if isinstance(arg, tuple):
        items = ", ".join(_node_ref(x) for x in arg)
        return f"({items},)" if len(arg) == 1 else f"({items})"
    if isinstance(arg, dict):
        return "{" + ", ".join(f"{_node_ref(k)}: {_node_ref(v)}" for k, v in arg.items()) + "}"
    return repr(arg)


def _format_output(args: Any) -> str:
    """Format the output node's args as a return expression."""
    return _node_ref(args)
```
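To see the failure mode concretely, here is a torch-free toy (the `Node` class below stands in for `torch.fx.Node`): the flat fallback emits the list's repr, which is not valid Python source, while the recursive version emits referenceable names.

```python
# Toy demonstration of why repr() on a container of nodes produces
# invalid generated source, and how recursion fixes it.
class Node:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"<Node {self.name}>"

def node_ref_flat(arg):
    # Non-recursive fallback, like the original _node_ref.
    return arg.name if isinstance(arg, Node) else repr(arg)

def node_ref_recursive(arg):
    # Recursive variant, like the suggested fix.
    if isinstance(arg, Node):
        return arg.name
    if isinstance(arg, list):
        return "[" + ", ".join(node_ref_recursive(x) for x in arg) + "]"
    if isinstance(arg, tuple):
        items = ", ".join(node_ref_recursive(x) for x in arg)
        return f"({items},)" if len(arg) == 1 else f"({items})"
    return repr(arg)

nodes = [Node("a"), Node("b")]
assert node_ref_flat(nodes) == "[<Node a>, <Node b>]"  # not valid source
assert node_ref_recursive(nodes) == "[a, b]"           # valid reference list
```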

Comment on lines +45 to +65
elif node.op == "call_module":
target = node.target
if target not in submod_index:
submod_index[target] = len(submod_names)
submod_names.append(target)
idx = submod_index[target]
args_str = ", ".join(_node_ref(a) for a in node.args)
lines.append(f" {node.name} = _submods[{idx}]({args_str})")

elif node.op == "call_function" and node.target is operator.getitem:
source = _node_ref(node.args[0])
index = node.args[1]
lines.append(f" {node.name} = {source}[{index}]")

elif node.op == "output":
ret = _format_output(node.args[0])
lines.append(f" return {ret}")

params = ", ".join(param_names)
header = f"def execution_fn({params}, *, _submods):"
return "\n".join([header] + lines) + "\n", submod_names

high

The code generation logic has several robustness issues:

  1. Missing Kwargs: call_module nodes may have keyword arguments, which are currently ignored.
  2. Unquoted getitem Index: The index in operator.getitem is inserted directly into the f-string. If the index is a string (e.g., for dictionary unpacking), it will result in a NameError at runtime because it won't be quoted. Using _node_ref(index) (which uses repr()) fixes this.
  3. Name Collisions: The parameter name _submods could collide with placeholder names in the graph. Using a more unique name like __vllm_submods__ is safer.
  4. Empty Params Syntax: If the graph has no placeholders, the generated header def execution_fn(, *, _submods): would be a syntax error.
  5. Missing Imports: The generated code may reference torch objects (like dtypes or devices) via repr(), but torch is not imported in the generated function's scope.
```
        elif node.op == "call_module":
            target = node.target
            if target not in submod_index:
                submod_index[target] = len(submod_names)
                submod_names.append(target)
            idx = submod_index[target]
            args_str = ", ".join(_node_ref(a) for a in node.args)
            kwargs_str = ", ".join(f"{k}={_node_ref(v)}" for k, v in node.kwargs.items())
            all_args = ", ".join(filter(None, [args_str, kwargs_str]))
            lines.append(f"    {node.name} = __vllm_submods__[{idx}]({all_args})")

        elif node.op == "call_function" and node.target is operator.getitem:
            source = _node_ref(node.args[0])
            index = _node_ref(node.args[1])
            lines.append(f"    {node.name} = {source}[{index}]")

        elif node.op == "output":
            ret = _format_output(node.args[0])
            lines.append(f"    return {ret}")

    params_str = f"{', '.join(param_names)}, " if param_names else ""
    header = f"def execution_fn({params_str}*, __vllm_submods__):"
    return "import torch\n" + "\n".join([header] + lines) + "\n", submod_names
```
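Point 2 above (the unquoted getitem index) can be checked in isolation with plain Python: repr() quotes a string index so the generated subscript is valid source, whereas interpolating it raw would emit a bare name.

```python
# Toy check: interpolating a string index raw vs. via repr().
source, index = "submod_0", "logits"

bad = f"{source}[{index}]"         # "submod_0[logits]"  -> NameError at runtime
good = f"{source}[{repr(index)}]"  # "submod_0['logits']" -> valid dict subscript

assert bad == "submod_0[logits]"
assert good == "submod_0['logits']"
assert eval(good, {"submod_0": {"logits": 42}}) == 42
```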

```
        c.forward if isinstance(c, torch.fx.GraphModule) else c
        for c in (submod_callables[name] for name in submod_names)
    ]
    return partial(fn, _submods=submods_list)
```

high

This needs to be updated to match the safer internal parameter name used in the generated code to avoid collisions with graph node names.

Suggested change

```
    return partial(fn, _submods=submods_list)
```

```
    return partial(fn, __vllm_submods__=submods_list)
```

Collaborator

@ProExpertProg ProExpertProg left a comment


Looks good overall although Gemini comments seem real

```
    Returns:
        A callable that executes the stitching logic.
    """
    ast.parse(code)  # validate syntax
```
Collaborator


Is this saved for inspection anywhere (cache/tlparse)?
