Skip to content

Commit fedc0b7

Browse files
authored
[chore] Upgrade vllm to 0.23.0 (NovaSky-AI#1800)
Signed-off-by: SumanthRH <sumanthrh99@gmail.com> Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
1 parent 34f6da9 commit fedc0b7

10 files changed

Lines changed: 1402 additions & 952 deletions

File tree

pyproject.toml

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ dependencies = [
1919
"tokenizers>=0.21.2",
2020
"transformers>=5.6.1,<=5.8.0",
2121
"typer>=0.17.4",
22-
# "wandb>=0.22.0",
2322
"peft==0.18.1",
2423
"hf_transfer",
2524
"cloudpathlib>=0.23.0",
@@ -106,7 +105,7 @@ skyrl-train = [
106105

107106
fsdp = [
108107
"skyrl[skyrl-train]",
109-
"vllm==0.20.2; sys_platform == 'linux'",
108+
"vllm==0.23.0; sys_platform == 'linux'",
110109
"vllm-router; sys_platform == 'linux'",
111110
# The `nixl` shim provides the `nixl` namespace and dispatches to `nixl_cu12`.
112111
# `nixl-cu12` ships the `nixl_cu12` module, but vLLM imports `nixl._api`.
@@ -117,9 +116,9 @@ fsdp = [
117116
"causal-conv1d; sys_platform == 'linux'",
118117
"flash-attn==2.8.3; sys_platform == 'linux'",
119118
"torch==2.11.0; sys_platform == 'linux'",
120-
"flashinfer-python==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
121-
"flashinfer-jit-cache==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
122-
"flashinfer-cubin==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
119+
"flashinfer-python==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
120+
"flashinfer-jit-cache==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
121+
"flashinfer-cubin==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
123122
"torchvision; sys_platform == 'linux'",
124123
]
125124

@@ -130,22 +129,22 @@ megatron = [
130129
"flash-linear-attention; sys_platform == 'linux'",
131130
"causal-conv1d; sys_platform == 'linux'",
132131
"mamba-ssm>=2.3.0; sys_platform == 'linux'",
133-
"vllm==0.20.2; sys_platform == 'linux'",
132+
"vllm==0.23.0; sys_platform == 'linux'",
134133
"vllm-router; sys_platform == 'linux'",
135134
# The `nixl` shim provides the `nixl` namespace and dispatches to `nixl_cu12`.
136135
# `nixl-cu12` ships the `nixl_cu12` module, but vLLM imports `nixl._api`.
137136
# The `nixl` package's metadata hard-depends on `nixl-cu13` too; that variant is overridden
138137
# out below (it would drag in the CUDA-13 stack and break the cu12 torch pin).
139138
"nixl; sys_platform == 'linux'",
140139
"torch==2.11.0; sys_platform == 'linux'",
141-
"flashinfer-python==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
140+
"flashinfer-python==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
142141
"torchvision; sys_platform == 'linux'",
143142
# megatron-bridge requires Python 3.12+; pin megatron-core to the same
144143
# constraint so both packages are consistently available (or absent).
145144
"megatron-bridge; sys_platform == 'linux' and python_version >= '3.12'",
146145
"megatron-core; sys_platform == 'linux' and python_version >= '3.12'",
147-
"flashinfer-jit-cache==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
148-
"flashinfer-cubin==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
146+
"flashinfer-jit-cache==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
147+
"flashinfer-cubin==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
149148
"nvidia-modelopt; sys_platform == 'linux'",
150149
]
151150

@@ -201,18 +200,6 @@ required-environments = [
201200
"sys_platform == 'darwin' and platform_machine == 'arm64'",
202201
]
203202

204-
constraint-dependencies = [
205-
"flashinfer-jit-cache==0.6.8.post1",
206-
"flashinfer-cubin==0.6.8.post1",
207-
# fastapi 0.137.0 refactored include_router() to store `_IncludedRouter` wrapper objects in
208-
# `app.routes`, which prometheus-fastapi-instrumentator (pulled in transitively by vLLM) cannot
209-
# handle: `_get_route_name` accesses `route.path` and raises
210-
# `AttributeError: '_IncludedRouter' object has no attribute 'path'`, so the vLLM server's /health
211-
# endpoint 500s and the server never becomes healthy. Cap below 0.137 until the instrumentator is
212-
# fixed. See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/370 and
213-
# https://github.com/vllm-project/vllm/issues/45596
214-
"fastapi<0.137",
215-
]
216203
# each backend should have separate dependencies that can potentially clash
217204
# megatron also clashes with the jax dependency from gpu and tpu extras
218205
conflicts = [
@@ -247,7 +234,12 @@ override-dependencies = [
247234
"transformer-engine-cu13; sys_platform == 'never'",
248235
# `nixl` hard-depends on both nixl-cu12 and nixl-cu13; drop the cu13 variant
249236
# so it doesn't pull the CUDA-13 stack and bump torch off the cu12 pin.
250-
"nixl-cu13; sys_platform == 'never'"
237+
"nixl-cu13; sys_platform == 'never'",
238+
# Megatron-Bridge pins flashinfer-python==0.6.8.post1, which conflicts with
239+
# our exact 0.6.12 pin (the version vLLM 0.23.0 requires). Override it to our version.
240+
"flashinfer-python==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
241+
"flashinfer-jit-cache==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
242+
"flashinfer-cubin==0.6.12; sys_platform == 'linux' and platform_machine == 'x86_64'",
251243
]
252244

253245
[tool.uv.extra-build-dependencies]
@@ -289,14 +281,14 @@ explicit = true
289281

290282
[[tool.uv.index]]
291283
name = "vllm-cu129"
292-
url = "https://wheels.vllm.ai/0.20.2/cu129"
284+
url = "https://wheels.vllm.ai/0.23.0/cu129"
293285
explicit = true
294286

295287
[tool.uv.sources]
296288
skyrl-gym = { path = "./skyrl-gym", editable = true }
297289
# Match torch's CUDA variant (cu128).
298290
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "sys_platform == 'linux'" }
299-
# vllm 0.20.2's PyPI wheel needs CUDA 13 (libcudart.so.13); the cu129 wheel
291+
# vllm 0.23.0's PyPI wheel needs CUDA 13 (libcudart.so.13); the cu129 wheel
300292
# links libcudart.so.12, which torch+cu128 supplies.
301293
vllm = [
302294
{ index = "vllm-cu129", marker = "sys_platform == 'linux'" },

skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,13 @@ async def start_weight_update(self, is_checkpoint_format: bool = True):
378378
engine = self._get_engine()
379379
return await asyncio.to_thread(
380380
engine.collective_rpc,
381-
"start_weight_update",
381+
"skyrl_start_weight_update",
382382
args=(is_checkpoint_format,),
383383
)
384384

385385
async def finish_weight_update(self):
386386
engine = self._get_engine()
387-
return await asyncio.to_thread(engine.collective_rpc, "finish_weight_update")
387+
return await asyncio.to_thread(engine.collective_rpc, "skyrl_finish_weight_update")
388388

389389

390390
class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
@@ -608,13 +608,13 @@ async def _teardown_weight_receiver(self):
608608
async def start_weight_update(self, is_checkpoint_format: bool = True):
609609
engine = self._get_engine()
610610
return await engine.collective_rpc(
611-
"start_weight_update",
611+
"skyrl_start_weight_update",
612612
args=(is_checkpoint_format,),
613613
)
614614

615615
async def finish_weight_update(self):
616616
engine = self._get_engine()
617-
return await engine.collective_rpc("finish_weight_update")
617+
return await engine.collective_rpc("skyrl_finish_weight_update")
618618

619619
# ----------------------------------------
620620
# Methods for handling OpenAI API requests

skyrl/backends/skyrl_train/inference_servers/layerwise_reload.py

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
layerwise reload once per weight sync rather than once per chunk.
77
"""
88

9+
import inspect
10+
from collections.abc import Callable
911
from typing import TYPE_CHECKING
1012

1113
import torch
@@ -14,40 +16,61 @@
1416
from vllm.config import ModelConfig, VllmConfig
1517
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
1618

17-
# Workaround for a vLLM layerwise-reload corruption affecting NemotronH/Mamba.
18-
# MambaMixer2 registers `conv_weights` as a non-persistent buffer that is a
19-
# view of `self.conv1d.weight.data` (shared storage). vLLM's reload code path
20-
# (model_executor/model_loader/reload/layerwise.py) materializes the buffer
21-
# into a fresh uninitialized GPU tensor and then runs
22-
# `kernel_conv_weights.data.copy_(fresh)` in `_copy_and_restore_kernel_tensors`.
23-
# Because the kernel buffer shares storage with `conv1d.weight.data`, this
24-
# writes garbage (NaN-bit-pattern bytes in bf16) into the conv1d weight,
25-
# corrupting all 23 Mamba layers after every weight sync.
26-
#
27-
# Adding "conv_weights" to vLLM's SKIP_TENSORS makes capture/restore/materialize
28-
# skip the buffer entirely, so the view stays intact and conv1d.weight is
29-
# preserved. Must be applied before `record_metadata_for_reloading` runs at
30-
# model construction; this module is imported by vLLM via
31-
# --worker-extension-cls before model init, so the import-time patch is
32-
# correctly ordered.
33-
# Remove this pending https://github.com/vllm-project/vllm/pull/42481 which should
34-
# be included in vLLM 0.21.0
35-
try:
36-
# Guarded import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
37-
from vllm.model_executor.model_loader.reload.meta import (
38-
SKIP_TENSORS as _VLLM_SKIP_TENSORS,
39-
)
40-
41-
_VLLM_SKIP_TENSORS.add("conv_weights")
42-
except ImportError:
43-
pass
19+
20+
def get_numel_loaded(weight_loader: Callable, args: inspect.BoundArguments) -> tuple[int, object]:
21+
"""
22+
Determine how many elements would be loaded by a weight loader call.
23+
24+
Args:
25+
weight_loader: used to load weights
26+
args: bound arguments to weight loader
27+
28+
Returns:
29+
number of elements loaded by the weight loader, the return value of the
30+
weight loader
31+
"""
32+
# Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
33+
from vllm.model_executor.model_loader.reload.meta import CopyCounter
34+
35+
with CopyCounter() as counter:
36+
return_value = weight_loader(*args.args, **args.kwargs)
37+
38+
# A weight loader fills a single destination parameter, so the number of
39+
# loaded elements is at most that parameter's size. Some loaders copy into
40+
# the parameter more than once -- e.g. ``composed_weight_loader`` runs an
41+
# in-place post-load transform (``param.copy_(fn(param))``) on top of the
42+
# initial copy -- which would make CopyCounter report twice the parameter
43+
# size. Over-counting inflates the layer's loaded-element total and can
44+
# finalize the layer before every parameter is loaded, silently dropping
45+
# the trailing parameter(s) (e.g. Mamba ``mixer.D``). Cap the count at the
46+
# destination size to keep the per-layer accounting correct.
47+
numel = counter.copied_numel
48+
param = args.arguments.get("param", None)
49+
if isinstance(param, torch.Tensor):
50+
numel = min(numel, param.numel())
51+
return numel, return_value
52+
53+
54+
def patch_numel_loaded():
55+
# vLLM's layerwise reload binds get_numel_loaded at import time
56+
# (`from .meta import get_numel_loaded`), so its call site in
57+
# layerwise.py uses the `layerwise` module's own binding. Rebind that
58+
# attribute (and the one on `meta`) to our patched version to substitute the symbol.
59+
from vllm.model_executor.model_loader.reload import layerwise as _layerwise
60+
from vllm.model_executor.model_loader.reload import meta as _meta
61+
62+
_layerwise.get_numel_loaded = get_numel_loaded
63+
_meta.get_numel_loaded = get_numel_loaded
64+
65+
66+
_PATCHED_LAYERWISE_NUMEL_LOADED = False
4467

4568

4669
class LayerwiseReloadWorkerMixin:
4770
"""Bracket a multi-chunk weight sync with one vLLM layerwise-reload init/finalize.
4871
49-
`start_weight_update` initializes the layerwise reload once; each chunk then loads
50-
its weights raw; `finish_weight_update` finalizes once over the whole weight set.
72+
`skyrl_start_weight_update` initializes the layerwise reload once; each chunk then loads
73+
its weights raw; `skyrl_finish_weight_update` finalizes once over the whole weight set.
5174
A per-chunk `reload_weights` is the wrong approach: it re-finalizes on every call
5275
and restores layers absent from that chunk, corrupting a multi-chunk sync.
5376
"""
@@ -57,7 +80,14 @@ class LayerwiseReloadWorkerMixin:
5780
model_config: "ModelConfig"
5881
device: torch.device
5982

60-
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
83+
# NOTE: named with a `skyrl_` prefix to avoid colliding with vLLM's own
84+
# Worker.start_weight_update / finish_weight_update (added in vllm-project/vllm
85+
# #39212, merge e3b65a5, shipped in vLLM 0.22.0+). vLLM injects the
86+
# worker-extension class as a *base* of Worker and asserts the extension
87+
# defines no attribute already present on Worker, so same-named methods abort
88+
# engine init. The skyrl_-prefixed variants keep SkyRL's IPC weight-sync path
89+
# (and the MoE set_current_vllm_config wrapping) intact alongside vLLM's native API.
90+
def skyrl_start_weight_update(self, is_checkpoint_format: bool = True) -> None:
6191
"""
6292
Prepare the model for a new weight update.
6393
@@ -74,10 +104,18 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
74104
"""
75105
if getattr(self, "_skyrl_weight_update_active", False):
76106
raise RuntimeError(
77-
"start_weight_update called while a weight update is "
78-
"already active. Call finish_weight_update first."
107+
"skyrl_start_weight_update called while a weight update is "
108+
"already active. Call skyrl_finish_weight_update first."
79109
)
80110

111+
# Ensure the get_numel_loaded patch is in effect before layerwise
112+
# reload runs.
113+
global _PATCHED_LAYERWISE_NUMEL_LOADED
114+
if not _PATCHED_LAYERWISE_NUMEL_LOADED:
115+
# use patched version, based on https://github.com/vllm-project/vllm/pull/44814
116+
patch_numel_loaded()
117+
_PATCHED_LAYERWISE_NUMEL_LOADED = True
118+
81119
if is_checkpoint_format:
82120
# Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
83121
from vllm.config import set_current_vllm_config
@@ -92,7 +130,7 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
92130
self._skyrl_is_checkpoint_format = is_checkpoint_format
93131
self._skyrl_weight_update_active = True
94132

95-
def finish_weight_update(self) -> None:
133+
def skyrl_finish_weight_update(self) -> None:
96134
"""
97135
Finalize the current weight update.
98136
@@ -101,7 +139,7 @@ def finish_weight_update(self) -> None:
101139
Must be called after all update_weights_ipc calls are done.
102140
"""
103141
if not getattr(self, "_skyrl_weight_update_active", False):
104-
raise RuntimeError("start_weight_update must be called before finish_weight_update.")
142+
raise RuntimeError("skyrl_start_weight_update must be called before skyrl_finish_weight_update.")
105143

106144
if self._skyrl_is_checkpoint_format:
107145
# Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.

skyrl/backends/skyrl_train/inference_servers/new_inference_worker_wrap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
enables chunked weight updates from training to inference using the
66
start/update/finish lifecycle:
77
8-
start_weight_update -> one or more update_weights_ipc -> finish_weight_update
8+
skyrl_start_weight_update -> one or more update_weights_ipc -> skyrl_finish_weight_update
99
1010
This separates the layerwise reload initialization/finalization from individual
1111
chunk transfers, allowing weights to be sent in bounded-memory chunks rather
@@ -40,9 +40,9 @@ class NewInferenceWorkerWrap(LayerwiseReloadWorkerMixin):
4040
vLLM worker extension for chunked weight sync (new inference path).
4141
4242
Provides a three-phase weight update protocol via collective_rpc:
43-
1. start_weight_update: Prepare model for receiving weights
43+
1. skyrl_start_weight_update: Prepare model for receiving weights
4444
2. update_weights_ipc: Receive and load one chunk of weights
45-
3. finish_weight_update: Finalize the model after all chunks
45+
3. skyrl_finish_weight_update: Finalize the model after all chunks
4646
4747
Attributes accessed from the host GPUWorker (via mixin inheritance):
4848
self.weight_transfer_engine
@@ -70,7 +70,7 @@ def update_weights_ipc(self, update_info: dict) -> None:
7070
- ipc_handles_pickled: b64(pickle({gpu_uuid: (func, args)}))
7171
"""
7272
if not getattr(self, "_skyrl_weight_update_active", False):
73-
raise RuntimeError("start_weight_update must be called before update_weights_ipc.")
73+
raise RuntimeError("skyrl_start_weight_update must be called before update_weights_ipc.")
7474

7575
if self.weight_transfer_engine is None:
7676
raise RuntimeError(
@@ -141,7 +141,7 @@ def update_weights_nccl(self, update_info: dict) -> None:
141141
https://github.com/vllm-project/vllm/pull/42577
142142
"""
143143
if not getattr(self, "_skyrl_weight_update_active", False):
144-
raise RuntimeError("start_weight_update must be called before update_weights_nccl.")
144+
raise RuntimeError("skyrl_start_weight_update must be called before update_weights_nccl.")
145145

146146
if self.weight_transfer_engine is None:
147147
raise RuntimeError(

skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,7 @@ async def start_weight_update(
11181118
"""
11191119
Start a new chunked weight update via /collective_rpc.
11201120
1121-
Calls the NewInferenceWorkerWrap.start_weight_update method on all
1121+
Calls the NewInferenceWorkerWrap.skyrl_start_weight_update method on all
11221122
workers. For checkpoint-format weights this initializes layerwise
11231123
reload. Must be called before any update_weights_ipc calls.
11241124
@@ -1132,7 +1132,7 @@ async def start_weight_update(
11321132
return await self._call_all_servers(
11331133
"/collective_rpc",
11341134
{
1135-
"method": "start_weight_update",
1135+
"method": "skyrl_start_weight_update",
11361136
"kwargs": {"is_checkpoint_format": is_checkpoint_format},
11371137
},
11381138
)
@@ -1145,8 +1145,8 @@ async def update_weights_ipc(
11451145
Send a single weight chunk via /collective_rpc.
11461146
11471147
Calls NewInferenceWorkerWrap.update_weights_ipc on all workers.
1148-
Can be called multiple times between start_weight_update and
1149-
finish_weight_update.
1148+
Can be called multiple times between skyrl_start_weight_update and
1149+
skyrl_finish_weight_update.
11501150
11511151
Args:
11521152
update_info: Dict with backend-specific update info (names,
@@ -1196,15 +1196,15 @@ async def finish_weight_update(self) -> Dict[str, Any]:
11961196
"""
11971197
Finish the current chunked weight update via /collective_rpc.
11981198
1199-
Calls NewInferenceWorkerWrap.finish_weight_update on all workers.
1199+
Calls NewInferenceWorkerWrap.skyrl_finish_weight_update on all workers.
12001200
For checkpoint-format weights, runs layerwise postprocessing.
12011201
12021202
Returns:
12031203
Dict mapping server_url to response.
12041204
"""
12051205
return await self._call_all_servers(
12061206
"/collective_rpc",
1207-
{"method": "finish_weight_update"},
1207+
{"method": "skyrl_finish_weight_update"},
12081208
)
12091209

12101210
async def load_lora_adapter(

0 commit comments

Comments
 (0)