Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions tests/compile/h100/test_startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,92 @@ def test_model_startup(monkeypatch, vllm_runner, fresh_vllm_cache, spec):

# Warm start — compiled artifacts loaded from disk cache.
_check_model_run(vllm_runner, spec, is_cold_start=False)


# ---------------------------------------------------------------------------
# compile_model (compile-only) cold start tests
# ---------------------------------------------------------------------------

# Model specs exercised by the compile-only cold-start tests below.
# The *_artifacts_* fields pin the expected deltas of the compilation
# cache counters for each model (checked in _compile_only_cold_start and
# _check_model_run).  Counts were presumably observed empirically per
# model architecture — update them if the compiler changes how graphs
# are partitioned.
COMPILE_ONLY_SPECS = [
    pytest.param(
        # Small MoE model used as-is (no config overrides).
        ModelStartupSpec(
            model="microsoft/Phi-tiny-MoE-instruct",
            hf_overrides={},
            cold_artifacts_saved=3,
            warm_artifacts_saved=0,
            warm_artifacts_loaded=3,
        ),
        id="phi_tiny_moe",
    ),
    pytest.param(
        # Large model shrunk via hf_overrides so the test fits on one GPU.
        ModelStartupSpec(
            model="openai/gpt-oss-120b",
            hf_overrides={
                "num_hidden_layers": 8,
                "hidden_size": 256,
                "intermediate_size": 512,
                "num_attention_heads": 8,
                "num_key_value_heads": 1,
                "num_local_experts": 8,
            },
            cold_artifacts_saved=3,
            warm_artifacts_saved=0,
            warm_artifacts_loaded=3,
        ),
        id="gpt_oss_120b",
    ),
    pytest.param(
        # Shares the shrink-to-small-MoE overrides defined earlier in this file.
        ModelStartupSpec(
            model="zai-org/GLM-4.5",
            hf_overrides=_SMALL_MOE_OVERRIDES,
            cold_artifacts_saved=4,
            warm_artifacts_saved=0,
            warm_artifacts_loaded=4,
        ),
        id="glm_4.5",
    ),
]


def _compile_only_cold_start(spec: ModelStartupSpec):
    """Cold start using compile_model (fake weights, no GPU memory).

    Runs compile-only compilation for ``spec.model`` and asserts that the
    number of newly saved compiled artifacts matches the spec's expectation.
    """
    from vllm.compile_only import compile_model

    # Snapshot the counter so we measure only this run's delta.
    counters_before = compilation_counter.clone()

    compile_model(
        spec.model,
        trust_remote_code=True,
        max_model_len=256,
        max_num_batched_tokens=1024,
        block_size=64,
        hf_overrides=spec.hf_overrides,
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode=CUDAGraphMode.NONE,
            pass_config=PassConfig(fuse_allreduce_rms=False),
        ),
    )

    saved = (
        compilation_counter.num_compiled_artifacts_saved
        - counters_before.num_compiled_artifacts_saved
    )
    print(f"\n=== COMPILE-ONLY COLD START for {spec.model} ===")
    print(f"  num_compiled_artifacts_saved={saved}")
    assert saved == spec.cold_artifacts_saved, f"cold_artifacts_saved: got {saved}"


@pytest.mark.parametrize("spec", COMPILE_ONLY_SPECS)
@fork_new_process_for_each_test
def test_compile_only_startup(monkeypatch, vllm_runner, fresh_vllm_cache, spec):
    """Compile-only cold start in a forked child, then a warm start.

    The cold start runs in a fork()ed subprocess so it happens before CUDA
    is initialized in this process; the warm start then verifies the disk
    cache produced by the child is loadable.
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # Cold start: compile-only in a forked child (fork before CUDA init).
    ctx = mp.get_context("fork")
    p = ctx.Process(target=_compile_only_cold_start, args=(spec,))
    p.start()
    # Bound the wait: an unbounded join() would hang CI forever if the
    # child deadlocks during compilation.  30 min is generous for these
    # shrunken models.
    p.join(timeout=1800)
    if p.is_alive():
        p.terminate()
        p.join()
        pytest.fail("Compile-only cold start timed out")
    assert p.exitcode == 0, "Compile-only cold start failed"

    # Warm start — compiled artifacts loaded from disk cache.
    _check_model_run(vllm_runner, spec, is_cold_start=False)
2 changes: 2 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import vllm.env_override # noqa: F401

MODULE_ATTRS = {
"compile_model": ".compile_only:compile_model",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
Expand All @@ -39,6 +40,7 @@
}

if typing.TYPE_CHECKING:
from vllm.compile_only import compile_model as compile_model
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
Expand Down
19 changes: 18 additions & 1 deletion vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
logger = init_logger(__name__)


class CompilationDone(Exception):
    """Signals that compile-only mode has finished compiling.

    Raised once the vLLM-compile cache has been populated: there is no
    need to actually execute the compiled code, so callers catch this to
    stop early.
    """


def make_copy_and_call(
sym_tensor_indices: list[int],
input_buffers: list[torch.Tensor | None],
Expand Down Expand Up @@ -990,7 +1000,14 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
# Compute config/compiler/code hashes once and reuse
config_hash = vllm_config.compute_hash()
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
forward_code_files = list(sorted(self.compilation_config.traced_files))
# Filter out PyTorch internal files — they are already covered
# by the torch version in env_factors.
torch_root = os.path.dirname(torch.__file__) + os.sep
forward_code_files = [
f
for f in sorted(self.compilation_config.traced_files)
if not f.startswith(torch_root)
]

logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
Expand Down
14 changes: 14 additions & 0 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,23 @@ def patched_inline_call(self_: Any) -> Any:
# AOT artifact.
self.save_aot_compiled_function()

# In compile-only mode, raise CompilationDone after all
# piecewise graphs are compiled and cache artifacts saved.
# This is caught in gpu_worker.compile_or_warm_up_model()
# to skip execution with fake tensors.
if self.compilation_config.compile_only:
from .backends import CompilationDone

raise CompilationDone

with monitor_profiling_run():
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
# Same as above for non-AOT path.
if self.compilation_config.compile_only:
from .backends import CompilationDone

raise CompilationDone
with monitor_torch_compile(
self.vllm_config,
"torch.compile and initial profiling/warmup "
Expand Down
90 changes: 90 additions & 0 deletions vllm/compile_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compile-only mode: populate vLLM's torch.compile cache
without loading real model weights or allocating KV caches.

The compile-only flag causes the model loader to be wrapped with
FakeTensorMode (see ``fake_loader.wrap_loader_with_fake``), so the
user's original ``load_format`` is preserved and the real loader's
full pipeline runs — just with fake tensors instead of real weights.
"""

import argparse
from typing import Literal

from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)


def compile_model(
    model: str,
    *,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    quantization: str | None = None,
    dtype: Literal["auto", "half", "float16", "bfloat16", "float", "float32"] = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> None:
    """Pre-populate vLLM's torch.compile cache for a model.

    Runs compilation using fake weights (zero GPU memory)
    so that vLLM's torch.compile cache is populated. Subsequent
    ``vllm serve`` or ``LLM(...)`` calls for the same model
    configuration will hit the warm cache and skip compilation.

    Args:
        model: HuggingFace model name or path.
        tensor_parallel_size: Number of tensor parallel GPUs.
        pipeline_parallel_size: Number of pipeline parallel stages.
        quantization: Quantization method (e.g. "fp8").
        dtype: Model dtype. Typed as the same Literal that
            ``EngineArgs.dtype`` accepts so mypy agrees with the
            downstream signature.
        trust_remote_code: Trust remote code from HuggingFace.
        **kwargs: Additional arguments passed to ``EngineArgs``.
    """
    # Imported lazily to keep `import vllm` cheap.
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        quantization=quantization,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
        # Compile-only only makes sense with compilation enabled.
        enforce_eager=False,
        **kwargs,
    )

    vllm_config = engine_args.create_engine_config(usage_context=UsageContext.LLM_CLASS)
    # Flag picked up by the model loader / worker to use fake tensors and
    # stop after compilation (see CompilationDone in compilation.backends).
    vllm_config.compilation_config.compile_only = True

    _run_compile_with_config(vllm_config)


def run_compile_only(args: argparse.Namespace) -> None:
    """Run compile-only mode from CLI arguments.

    Builds ``EngineArgs`` from the parsed CLI namespace, forces the
    compiled (non-eager) path, flags compile-only mode, and delegates
    to the shared compile driver.
    """
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs.from_cli_args(args)
    # Eager mode would skip torch.compile entirely, defeating the purpose.
    engine_args.enforce_eager = False

    config = engine_args.create_engine_config(usage_context=UsageContext.LLM_CLASS)
    config.compilation_config.compile_only = True
    _run_compile_with_config(config)


def _run_compile_with_config(vllm_config) -> None:
    """Shared compile-only logic.

    Loads plugins, builds the executor for ``vllm_config``, and triggers
    compilation via the ``compile_or_warm_up_model`` RPC. The executor is
    always shut down, even if compilation fails.
    """
    from vllm.plugins import load_general_plugins
    from vllm.v1.executor import Executor

    load_general_plugins()

    executor_class = Executor.get_class(vllm_config)
    executor = executor_class(vllm_config)
    try:
        executor.collective_rpc("compile_or_warm_up_model")
        logger.info("Compile-only mode complete. Cache populated.")
    finally:
        # Release worker processes / distributed state even when the RPC
        # raises partway through compilation; otherwise they leak.
        executor.shutdown()
8 changes: 8 additions & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,12 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""

compile_only: bool = False
"""If True, run in compile-only mode: only torch.compile
compilation, skip CUDA graph capture, kernel warmup, and sampler
warmup. Used to pre-populate vLLM's torch.compile cache without
allocating KV caches or setting up the full engine."""

fast_moe_cold_start: bool | None = None
"""Optimization for fast MOE cold start.

Expand Down Expand Up @@ -739,6 +745,8 @@ def compute_hash(self) -> str:
"static_forward_context",
"pass_config", # handled separately below
"dynamic_shapes_config", # handled separately below
# compile_only doesn't affect the compiled graph
"compile_only",
}

from vllm.config.utils import get_hash_factors, hash_factors
Expand Down
51 changes: 51 additions & 0 deletions vllm/entrypoints/cli/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.utils.argparse_utils import FlexibleArgumentParser

DESCRIPTION = """[Experimental] Populate vLLM's torch.compile cache for a model.

This command is experimental and a work in progress. Not all models and
configurations are supported yet.

This runs compilation using fake weights (zero GPU memory) so that
vLLM's torch.compile cache is populated. Subsequent ``vllm serve`` or
``LLM(...)`` calls for the same model will hit the warm cache and skip
compilation.
"""


class CompileSubcommand(CLISubcommand):
    """The ``compile`` subcommand for the vLLM CLI."""

    name = "compile"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # Imported lazily so `vllm --help` stays fast.
        from vllm.compile_only import run_compile_only

        # The positional model_tag (if given) takes precedence over --model.
        model_tag = getattr(args, "model_tag", None)
        if model_tag is not None:
            args.model = model_tag
        run_compile_only(args)

    def subparser_init(
        self, subparsers: argparse._SubParsersAction
    ) -> FlexibleArgumentParser:
        parser = subparsers.add_parser(
            self.name,
            help="[Experimental] Populate vLLM's torch.compile cache for a model.",
            description=DESCRIPTION,
            usage="vllm compile [model_tag] [options]",
        )
        # Reuse the full serve-style argument set so any engine flag works here.
        parser = make_arg_parser(parser)
        parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
        return parser


def cmd_init() -> list[CLISubcommand]:
    """Return the subcommands this module contributes to the vLLM CLI."""
    subcommands: list[CLISubcommand] = [CompileSubcommand()]
    return subcommands
2 changes: 2 additions & 0 deletions vllm/entrypoints/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.compile
import vllm.entrypoints.cli.launch
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
Expand All @@ -26,6 +27,7 @@ def main():
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.compile,
vllm.entrypoints.cli.launch,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
Expand Down
Loading
Loading