
Commit 3975eb6

Revert "[Startup] Parallelize torch/transformers import + weight prefetch + forkserver prewarm" (vllm-project#40438)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
1 parent 5a94a19 commit 3975eb6

3 files changed: 6 additions & 220 deletions


vllm/entrypoints/cli/main.py

Lines changed: 1 addition & 71 deletions
@@ -5,80 +5,10 @@
 Note that all future modules must be lazily loaded within main
 to avoid certain eager import breakage."""
 
-import contextlib
 import importlib.metadata
-import os
 import sys
-import threading as _threading
 
-
-# [startup] Kick off torch + transformers .so/module loading in a background
-# thread before we touch vllm.logger (which pulls vllm/__init__.py ->
-# vllm.env_override -> `import torch` on the main thread). Python import
-# lock serializes the same-module import across threads, but the .so dlopen
-# inside torch's init releases the GIL during file I/O. Main thread's
-# non-torch imports (vllm.envs submodules, stdlib, fastapi, etc.) can make
-# progress on the CPU while the background thread pays the ~2 s of cuda
-# .so loading. `import transformers` is also ~2 s of cold-disk work and
-# depends on torch; chain it after torch in the same thread so subsequent
-# `from transformers import ...` lines on the main thread hit a warm
-# module cache.
-def _bg_preload_torch() -> None:
-    try:
-        import torch  # noqa: F401
-    except Exception:
-        return
-    with contextlib.suppress(Exception):
-        import transformers  # noqa: F401
-
-
-_threading.Thread(
-    target=_bg_preload_torch, daemon=True, name="vllm-torch-preload"
-).start()
-
-
-# [startup] Pre-spawn EngineCore via forkserver preload, in a background
-# thread. Only fires for `vllm serve` (the only subcommand that spawns a
-# long-running EngineCore). The forkserver process is forked once and
-# preloaded with vllm.v1.engine.async_llm (~3-5 s of imports). When
-# AsyncLLM.from_vllm_config later runs, Process.start() forks from the
-# already-warm forkserver instead of paying spawn() cost (~5 s in child
-# for fresh Python + imports).
-#
-# Kicking the preload in a BG thread lets the ~3-5 s ensure_running cost
-# overlap with APIServer's argparse + config resolution (~5-10 s on cold
-# disk). Default cli_env_setup sets spawn; we override to forkserver
-# before that runs so the path is consistent.
-def _bg_prewarm_forkserver() -> None:
-    try:
-        import multiprocessing
-        import multiprocessing.forkserver as forkserver
-
-        # set_start_method MUST be called before ensure_running. It also
-        # can only be called once per process; any later override by
-        # vllm's build_async_engine_client will just see the existing
-        # setting.
-        multiprocessing.set_start_method("forkserver", force=False)
-        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
-        forkserver.ensure_running()
-    except Exception:
-        pass
-
-
-if len(sys.argv) > 1 and sys.argv[1] == "serve":
-    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver")
-    # daemon=True so early CLI exits (bad args, --help, import errors)
-    # don't hang waiting for ensure_running(). The forkserver subprocess
-    # itself is tracked by module-level state in multiprocessing.forkserver
-    # and survives this thread exiting; subsequent spawn() calls reuse it.
-    _threading.Thread(
-        target=_bg_prewarm_forkserver,
-        daemon=True,
-        name="vllm-forkserver-prewarm",
-    ).start()
-
-
-from vllm.logger import init_logger  # noqa: E402
+from vllm.logger import init_logger
 
 logger = init_logger(__name__)
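The pattern removed above overlaps a heavy import's disk I/O (which releases the GIL) with unrelated main-thread work. Below is a minimal standalone sketch of the same idea, using numpy as a stand-in for torch/transformers; the module choice and timing print are illustrative only, and numpy is assumed to be installed:

import threading
import time


def _bg_preload() -> None:
    # Best-effort: a failed preload must never break startup.
    try:
        import numpy  # noqa: F401  # stand-in for the heavy import
    except Exception:
        pass


t0 = time.perf_counter()
threading.Thread(target=_bg_preload, daemon=True, name="preload").start()

# The main thread makes progress on unrelated startup work while the
# background thread pays the import's file I/O.
import argparse  # noqa: E402

parser = argparse.ArgumentParser()

# A later import of the same module hits the warm sys.modules cache;
# Python's per-module import lock guarantees single initialization.
import numpy  # noqa: E402, F401

print(f"startup: {time.perf_counter() - t0:.3f}s")

The worst case is benign: if the main thread reaches its own import first, it simply takes the import lock and does the work itself, exactly as it would without the preload thread.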

vllm/entrypoints/openai/api_server.py

Lines changed: 2 additions & 142 deletions
@@ -12,7 +12,7 @@
 import warnings
 from argparse import Namespace
 from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager, suppress
+from contextlib import asynccontextmanager
 from typing import Any
 
 import uvloop
@@ -74,121 +74,6 @@
 _FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)
 
 
-def _startup_prefetch_weights(vllm_config: "VllmConfig") -> None:
-    """Kick off reading model weight shards into OS page cache from the
-    parent APIServer. EngineCore will read the same files a few seconds
-    later from the child; by then the kernel already has them ready.
-
-    All work (directory resolution, HF/ModelScope cache lookup, globbing,
-    and the reads themselves) runs inside the background thread so we do
-    not block the asyncio event loop.
-
-    Best-effort: any failure (unknown model location, permission, etc.) is
-    swallowed — vLLM's existing in-child prefetch then runs normally.
-    """
-    import threading
-
-    # Capture only the small scalar fields the thread needs. Avoid holding
-    # a reference to vllm_config (which contains unpicklable objects) for
-    # longer than necessary.
-    model_ref = vllm_config.model_config.model
-    revision = vllm_config.model_config.revision
-    download_dir = vllm_config.load_config.download_dir
-
-    def _prefetch_worker() -> None:
-        import glob
-        import os
-
-        from vllm import envs
-
-        candidate_dir: str | None = None
-
-        # 1. Local path?
-        if os.path.isdir(model_ref):
-            candidate_dir = model_ref
-        else:
-            # 2. HF / ModelScope repo id — resolve to the local cache
-            # snapshot dir using the same revision / cache_dir the weight
-            # loader will use, so we prefetch the right files.
-            try:
-                if envs.VLLM_USE_MODELSCOPE:
-                    from modelscope.hub.snapshot_download import (
-                        snapshot_download,
-                    )
-
-                    candidate_dir = snapshot_download(
-                        model_id=model_ref,
-                        revision=revision,
-                        cache_dir=download_dir,
-                        local_files_only=True,
-                    )
-                else:
-                    from huggingface_hub import snapshot_download
-
-                    candidate_dir = snapshot_download(
-                        repo_id=model_ref,
-                        revision=revision,
-                        cache_dir=download_dir,
-                        allow_patterns=[
-                            "*.safetensors",
-                            "*.bin",
-                            "*.json",
-                            "*tokenizer*",
-                        ],
-                        local_files_only=True,
-                    )
-            except Exception:
-                return  # not cached yet or not a known repo id
-
-        if not candidate_dir or not os.path.isdir(candidate_dir):
-            return
-
-        # Weight shards: large files, read into page cache.
-        shard_paths = sorted(
-            glob.glob(os.path.join(candidate_dir, "*.safetensors"))
-            + glob.glob(os.path.join(candidate_dir, "*.bin"))
-        )
-        # Tokenizer/config sidecars: small, but re-opened in the child and
-        # add synchronous open+read latency when the disk is cold.
-        sidecar_paths = sorted(
-            glob.glob(os.path.join(candidate_dir, "*.json"))
-            + glob.glob(os.path.join(candidate_dir, "tokenizer.model"))
-            + glob.glob(os.path.join(candidate_dir, "*tokenizer*"))
-        )
-        shard_paths.extend(sidecar_paths)
-        if not shard_paths:
-            return
-
-        logger.debug(
-            "Parent-side weight prefetch starting for %d files in %s",
-            len(shard_paths),
-            candidate_dir,
-        )
-
-        # Match vLLM's in-child prefetch block size + thread count.
-        block_size = 16 * 1024 * 1024  # 16 MB
-        # Read shards in parallel across 8 worker threads (bounded) to
-        # saturate multi-spindle / multi-queue storage without thrashing.
-        from concurrent.futures import ThreadPoolExecutor
-
-        def read_one(p: str) -> None:
-            try:
-                with open(p, "rb") as f:
-                    while f.read(block_size):
-                        pass
-            except Exception:
-                pass
-
-        with ThreadPoolExecutor(max_workers=8) as pool:
-            list(pool.map(read_one, shard_paths))
-
-    threading.Thread(
-        target=_prefetch_worker,
-        daemon=True,
-        name="vllm-parent-weight-prefetch",
-    ).start()
-
-
 @asynccontextmanager
 async def build_async_engine_client(
     args: Namespace,
@@ -200,10 +85,7 @@ async def build_async_engine_client(
         # The executor is expected to be mp.
         # Pre-import heavy modules in the forkserver process
         logger.debug("Setup forkserver with pre-imports")
-        # May already have been set by the CLI entry's async prewarm
-        # (vllm/entrypoints/cli/main.py); tolerate re-call.
-        with suppress(RuntimeError):
-            multiprocessing.set_start_method("forkserver", force=False)
+        multiprocessing.set_start_method("forkserver")
         multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
         forkserver.ensure_running()
         logger.debug("Forkserver setup complete!")
@@ -241,28 +123,6 @@ async def build_async_engine_client_from_engine_args(
     # Create the EngineConfig (determines if we can use V1).
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
 
-    # [startup] Start prefetching model weight shards into the OS page cache
-    # in a background thread from the PARENT APIServer process. EngineCore
-    # will page-fault on these same files ~10-15 s later (after fork + CUDA
-    # context + distributed init + model init). For large-weight cases
-    # (tens of GB) this parent-side head start meaningfully shrinks the
-    # prefetch+load phase that the engine's in-child prefetch otherwise
-    # barely overlaps.
-    #
-    # Skip in API-only workers that connect to an already-running EngineCore
-    # (multi-API-server / disaggregated setups): those processes never load
-    # weights, and if we prefetched from all of them we'd contend with the
-    # engine's own read. Presence of an `input_address` in client_config is
-    # the current marker that this worker is headless.
-    #
-    # Best-effort: if the model is a local path, glob for safetensors; if
-    # it's a repo-id, try to resolve via HF hub (or ModelScope) local cache.
-    # Any failure silently falls through to the existing in-child prefetch
-    # path. All I/O (incl. directory resolution) runs inside the BG thread
-    # so the asyncio event loop is never blocked.
-    if not (client_config and client_config.get("input_address")):
-        _startup_prefetch_weights(vllm_config)
-
     from vllm.v1.engine.async_llm import AsyncLLM
 
     async_llm: AsyncLLM | None = None

vllm/envs.py

Lines changed: 3 additions & 7 deletions
@@ -62,7 +62,7 @@
     VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
     VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
     VLLM_XLA_USE_SPMD: bool = False
-    VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn", "forkserver"] = "fork"
+    VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -765,13 +765,9 @@ def _get_or_set_default() -> str:
         int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "0"))
     ),
     # Use dedicated multiprocess context for workers.
-    # spawn, fork, and forkserver all work. forkserver is opt-in for fast
-    # startup when paired with the CLI's async prewarm (see
-    # vllm/entrypoints/cli/main.py) — the forkserver process is preloaded
-    # with vllm.v1.engine.async_llm and a subsequent EngineCore Process.start()
-    # forks from that warm sibling instead of paying spawn cost.
+    # Both spawn and fork work
     "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(
-        "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork", "forkserver"]
+        "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"]
     ),
     # Path to the cache for storing downloaded assets
     "VLLM_ASSETS_CACHE": lambda: os.path.expanduser(

0 commit comments