Skip to content

Commit 5a527df

Browse files
jamesbrazaclaude
andauthored
[fix] Fix silent MoE corruption in the legacy multi-chunk weight sync (NovaSky-AI#1737)
Builds on NovaSky-AI#1685, which switched the legacy `WorkerWrap.load_weights` path to `reload_weights`. With that, `load_weights` ran a self-contained `reload_weights` per chunk, so vLLM's per-call `finalize_layerwise_processing` restored every layer absent from that chunk — silently corrupting any multi-chunk weight sync into MoE gibberish after the first sync ( NovaSky-AI#1680; upstream vllm-project/vllm#42821). This PR brackets the whole sync with a single layerwise-reload initialize/finalize via a shared `LayerwiseReloadWorkerMixin`, sharing this lifecycle with the 'new' inference path (`new_inference_worker_wrap.py`). Closes NovaSky-AI#1680. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ba94c5d commit 5a527df

9 files changed

Lines changed: 412 additions & 91 deletions

File tree

skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,12 @@ async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo")
354354
async def update_named_weights(self, request: WeightUpdateRequest):
355355
return await self._run_on_all_engines("update_named_weights", request=request)
356356

357+
async def start_weight_update(self, is_checkpoint_format: bool = True):
358+
return await self._run_on_all_engines("start_weight_update", is_checkpoint_format=is_checkpoint_format)
359+
360+
async def finish_weight_update(self):
361+
return await self._run_on_all_engines("finish_weight_update")
362+
357363
async def reset_prefix_cache(self):
358364
return await self._run_on_all_engines("reset_prefix_cache")
359365

skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo")
6969
async def update_named_weights(self, request: WeightUpdateRequest):
7070
return await self.inference_engine_actor.update_named_weights.remote(request)
7171

72+
async def start_weight_update(self, is_checkpoint_format: bool = True):
73+
return await self.inference_engine_actor.start_weight_update.remote(is_checkpoint_format=is_checkpoint_format)
74+
75+
async def finish_weight_update(self):
76+
return await self.inference_engine_actor.finish_weight_update.remote()
77+
7278
async def teardown(self):
7379
return await self.inference_engine_actor.teardown.remote()
7480

skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,18 @@ async def _teardown_weight_receiver(self):
372372
engine = self._get_engine()
373373
return await asyncio.to_thread(engine.collective_rpc, "teardown_weight_receiver")
374374

375+
async def start_weight_update(self, is_checkpoint_format: bool = True):
376+
engine = self._get_engine()
377+
return await asyncio.to_thread(
378+
engine.collective_rpc,
379+
"start_weight_update",
380+
args=(is_checkpoint_format,),
381+
)
382+
383+
async def finish_weight_update(self):
384+
engine = self._get_engine()
385+
return await asyncio.to_thread(engine.collective_rpc, "finish_weight_update")
386+
375387

376388
class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
377389
"""Asynchronous VLLM engine."""
@@ -591,6 +603,17 @@ async def _teardown_weight_receiver(self):
591603
engine = self._get_engine()
592604
return await engine.collective_rpc("teardown_weight_receiver")
593605

606+
async def start_weight_update(self, is_checkpoint_format: bool = True):
607+
engine = self._get_engine()
608+
return await engine.collective_rpc(
609+
"start_weight_update",
610+
args=(is_checkpoint_format,),
611+
)
612+
613+
async def finish_weight_update(self):
614+
engine = self._get_engine()
615+
return await engine.collective_rpc("finish_weight_update")
616+
594617
# ----------------------------------------
595618
# Methods for handling OpenAI API requests
596619
# ----------------------------------------
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Shared vLLM layerwise-reload lifecycle for SkyRL's vLLM worker-extension classes.
2+
3+
Provides `LayerwiseReloadWorkerMixin`, the start/finish bracket that both
4+
`vllm_worker.WorkerWrap` and
5+
`new_inference_worker_wrap.NewInferenceWorkerWrap` use to run vLLM's
6+
layerwise reload once per weight sync rather than once per chunk.
7+
"""
8+
9+
from typing import TYPE_CHECKING
10+
11+
import torch
12+
13+
if TYPE_CHECKING:
14+
from vllm.config import ModelConfig, VllmConfig
15+
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
16+
17+
# Workaround for a vLLM layerwise-reload corruption affecting NemotronH/Mamba.
18+
# MambaMixer2 registers `conv_weights` as a non-persistent buffer that is a
19+
# view of `self.conv1d.weight.data` (shared storage). vLLM's reload code path
20+
# (model_executor/model_loader/reload/layerwise.py) materializes the buffer
21+
# into a fresh uninitialized GPU tensor and then runs
22+
# `kernel_conv_weights.data.copy_(fresh)` in `_copy_and_restore_kernel_tensors`.
23+
# Because the kernel buffer shares storage with `conv1d.weight.data`, this
24+
# writes garbage (NaN-bit-pattern bytes in bf16) into the conv1d weight,
25+
# corrupting all 23 Mamba layers after every weight sync.
26+
#
27+
# Adding "conv_weights" to vLLM's SKIP_TENSORS makes capture/restore/materialize
28+
# skip the buffer entirely, so the view stays intact and conv1d.weight is
29+
# preserved. Must be applied before `record_metadata_for_reloading` runs at
30+
# model construction; this module is imported by vLLM via
31+
# --worker-extension-cls before model init, so the import-time patch is
32+
# correctly ordered.
33+
# Remove this pending https://github.com/vllm-project/vllm/pull/42481 which should
34+
# be included in vLLM 0.21.0
35+
try:
36+
# Guarded import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
37+
from vllm.model_executor.model_loader.reload.meta import (
38+
SKIP_TENSORS as _VLLM_SKIP_TENSORS,
39+
)
40+
41+
_VLLM_SKIP_TENSORS.add("conv_weights")
42+
except ImportError:
43+
pass
44+
45+
46+
class LayerwiseReloadWorkerMixin:
47+
"""Bracket a multi-chunk weight sync with one vLLM layerwise-reload init/finalize.
48+
49+
`start_weight_update` initializes the layerwise reload once; each chunk then loads
50+
its weights raw; `finish_weight_update` finalizes once over the whole weight set.
51+
A per-chunk `reload_weights` is the wrong approach: it re-finalizes on every call
52+
and restores layers absent from that chunk, corrupting a multi-chunk sync.
53+
"""
54+
55+
vllm_config: "VllmConfig"
56+
model_runner: "GPUModelRunner"
57+
model_config: "ModelConfig"
58+
device: torch.device
59+
60+
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
61+
"""
62+
Prepare the model for a new weight update.
63+
64+
For checkpoint-format weights, initializes the layerwise reload
65+
machinery which moves layers to meta device and wraps weight loaders
66+
to defer processing until all weights for each layer are loaded.
67+
68+
Must be called before any update_weights_ipc calls.
69+
70+
Args:
71+
is_checkpoint_format: True if incoming weights are in checkpoint
72+
format (need layerwise processing). False if weights are
73+
already in kernel format (direct copy).
74+
"""
75+
if getattr(self, "_skyrl_weight_update_active", False):
76+
raise RuntimeError(
77+
"start_weight_update called while a weight update is "
78+
"already active. Call finish_weight_update first."
79+
)
80+
81+
if is_checkpoint_format:
82+
# Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
83+
from vllm.config import set_current_vllm_config
84+
from vllm.model_executor.model_loader.reload import (
85+
initialize_layerwise_reload,
86+
)
87+
88+
model = self.model_runner.model
89+
with set_current_vllm_config(self.vllm_config), torch.device(self.device):
90+
initialize_layerwise_reload(model)
91+
92+
self._skyrl_is_checkpoint_format = is_checkpoint_format
93+
self._skyrl_weight_update_active = True
94+
95+
def finish_weight_update(self) -> None:
96+
"""
97+
Finalize the current weight update.
98+
99+
For checkpoint-format weights, runs layerwise postprocessing
100+
(quantization repacking, attention weight processing, etc.).
101+
Must be called after all update_weights_ipc calls are done.
102+
"""
103+
if not getattr(self, "_skyrl_weight_update_active", False):
104+
raise RuntimeError("start_weight_update must be called before finish_weight_update.")
105+
106+
if self._skyrl_is_checkpoint_format:
107+
# Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
108+
from vllm.config import set_current_vllm_config
109+
from vllm.model_executor.model_loader.reload import (
110+
finalize_layerwise_reload,
111+
)
112+
113+
model = self.model_runner.model
114+
with set_current_vllm_config(self.vllm_config), torch.device(self.device):
115+
finalize_layerwise_reload(model, self.model_config)
116+
117+
self._skyrl_weight_update_active = False
118+
self._skyrl_is_checkpoint_format = True

skyrl/backends/skyrl_train/inference_servers/new_inference_worker_wrap.py

Lines changed: 4 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -28,37 +28,14 @@
2828

2929
import torch
3030

31-
# Workaround for a vLLM layerwise-reload corruption affecting NemotronH/Mamba.
32-
# MambaMixer2 registers `conv_weights` as a non-persistent buffer that is a
33-
# view of `self.conv1d.weight.data` (shared storage). vLLM's reload code path
34-
# (model_executor/model_loader/reload/layerwise.py) materializes the buffer
35-
# into a fresh uninitialized GPU tensor and then runs
36-
# `kernel_conv_weights.data.copy_(fresh)` in `_copy_and_restore_kernel_tensors`.
37-
# Because the kernel buffer shares storage with `conv1d.weight.data`, this
38-
# writes garbage (NaN-bit-pattern bytes in bf16) into the conv1d weight,
39-
# corrupting all 23 Mamba layers after every weight sync.
40-
#
41-
# Adding "conv_weights" to vLLM's SKIP_TENSORS makes capture/restore/materialize
42-
# skip the buffer entirely, so the view stays intact and conv1d.weight is
43-
# preserved. Must be applied before `record_metadata_for_reloading` runs at
44-
# model construction; this module is imported by vLLM via
45-
# --worker-extension-cls before model init, so the import-time patch is
46-
# correctly ordered.
47-
# Remove this pending https://github.com/vllm-project/vllm/pull/42481 which should
48-
# be included in vLLM 0.21.0
49-
try:
50-
from vllm.model_executor.model_loader.reload.meta import (
51-
SKIP_TENSORS as _VLLM_SKIP_TENSORS,
52-
)
53-
54-
_VLLM_SKIP_TENSORS.add("conv_weights")
55-
except ImportError:
56-
pass
31+
from skyrl.backends.skyrl_train.inference_servers.layerwise_reload import (
32+
LayerwiseReloadWorkerMixin,
33+
)
5734

5835
VLLM_NEW_INFERENCE_WORKER_EXTENSION_CLS = f"{__name__}.NewInferenceWorkerWrap"
5936

6037

61-
class NewInferenceWorkerWrap:
38+
class NewInferenceWorkerWrap(LayerwiseReloadWorkerMixin):
6239
"""
6340
vLLM worker extension for chunked weight sync (new inference path).
6441
@@ -74,40 +51,6 @@ class NewInferenceWorkerWrap:
7451
self.device
7552
"""
7653

77-
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
78-
"""
79-
Prepare the model for a new weight update.
80-
81-
For checkpoint-format weights, initializes the layerwise reload
82-
machinery which moves layers to meta device and wraps weight loaders
83-
to defer processing until all weights for each layer are loaded.
84-
85-
Must be called before any update_weights_ipc calls.
86-
87-
Args:
88-
is_checkpoint_format: True if incoming weights are in checkpoint
89-
format (need layerwise processing). False if weights are
90-
already in kernel format (direct copy).
91-
"""
92-
if getattr(self, "_skyrl_weight_update_active", False):
93-
raise RuntimeError(
94-
"start_weight_update called while a weight update is "
95-
"already active. Call finish_weight_update first."
96-
)
97-
98-
if is_checkpoint_format:
99-
from vllm.config import set_current_vllm_config
100-
from vllm.model_executor.model_loader.reload import (
101-
initialize_layerwise_reload,
102-
)
103-
104-
model = self.model_runner.model
105-
with set_current_vllm_config(self.vllm_config), torch.device(self.device):
106-
initialize_layerwise_reload(model)
107-
108-
self._skyrl_is_checkpoint_format = is_checkpoint_format
109-
self._skyrl_weight_update_active = True
110-
11154
def update_weights_ipc(self, update_info: dict) -> None:
11255
"""
11356
Receive and load a single chunk of weights.
@@ -217,27 +160,3 @@ def update_weights_nccl(self, update_info: dict) -> None:
217160
)
218161

219162
torch.accelerator.synchronize()
220-
221-
def finish_weight_update(self) -> None:
222-
"""
223-
Finalize the current weight update.
224-
225-
For checkpoint-format weights, runs layerwise postprocessing
226-
(quantization repacking, attention weight processing, etc.).
227-
Must be called after all update_weights_ipc calls are done.
228-
"""
229-
if not getattr(self, "_skyrl_weight_update_active", False):
230-
raise RuntimeError("start_weight_update must be called before finish_weight_update.")
231-
232-
if self._skyrl_is_checkpoint_format:
233-
from vllm.config import set_current_vllm_config
234-
from vllm.model_executor.model_loader.reload import (
235-
finalize_layerwise_reload,
236-
)
237-
238-
model = self.model_runner.model
239-
with set_current_vllm_config(self.vllm_config), torch.device(self.device):
240-
finalize_layerwise_reload(model, self.model_config)
241-
242-
self._skyrl_weight_update_active = False
243-
self._skyrl_is_checkpoint_format = True

skyrl/backends/skyrl_train/inference_servers/vllm_worker.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818

1919
import torch
2020

21+
from skyrl.backends.skyrl_train.inference_servers.layerwise_reload import (
22+
LayerwiseReloadWorkerMixin,
23+
)
24+
2125
# Path to this worker extension class for use in CLI args (derived from module path)
2226
VLLM_WORKER_EXTENSION_CLS = f"{__name__}.WorkerWrap"
2327

2428

25-
class WorkerWrap:
29+
class WorkerWrap(LayerwiseReloadWorkerMixin):
2630
"""
2731
vLLM worker extension for SkyRL weight synchronization.
2832
@@ -32,7 +36,9 @@ class WorkerWrap:
3236
3337
Methods:
3438
init_weight_update_communicator: Initialize the weight receiver
35-
load_weights: Receive and load weights from trainer
39+
start_weight_update: Begin a sync; initialize vLLM layerwise reload once
40+
load_weights: Receive and load one chunk of weights from trainer
41+
finish_weight_update: End a sync; finalize vLLM layerwise reload once
3642
teardown_weight_receiver: Clean up weight receiver resources
3743
"""
3844

@@ -73,9 +79,15 @@ def init_weight_update_communicator(self, init_info: bytes):
7379

7480
def load_weights(self, request: bytes) -> None:
7581
"""
76-
Load weights using the receiver.
82+
Load one chunk of weights using the receiver.
7783
78-
This method is called via collective_rpc from the weight loader.
84+
Called via collective_rpc from the weight loader, once per chunk.
85+
When the sender brackets the sync with start_weight_update / finish_weight_update,
86+
the chunk is loaded raw and the single finalize runs vLLM's post-load weight
87+
processing exactly once over the whole weight set.
88+
Without a bracket, it falls back to a self-contained reload_weights
89+
(initialize + load + finalize in this one call), correct when the call
90+
carries the whole model so finalize sees every layer and restores none.
7991
8092
Args:
8193
request: Pickled bytes of WeightUpdateRequest.
@@ -92,8 +104,17 @@ def load_weights(self, request: bytes) -> None:
92104
for name, tensor in self._weight_receiver.receive_weights(request):
93105
weight_list.append((name, tensor))
94106

107+
weight_update_bracketed = getattr(self, "_skyrl_weight_update_active", False)
95108
with torch.device(self.device), set_current_vllm_config(self.vllm_config):
96-
self.model_runner.reload_weights(weights_iterator=iter(weight_list))
109+
if weight_update_bracketed:
110+
self.model_runner.model.load_weights(weights=weight_list)
111+
else:
112+
self.model_runner.reload_weights(weights_iterator=iter(weight_list))
113+
114+
if weight_update_bracketed:
115+
# Finish consuming IPC-backed tensors before the sender drops them on
116+
# its next barrier; matches NewInferenceWorkerWrap.update_weights_ipc
117+
torch.accelerator.synchronize()
97118

98119
for weight in weight_list:
99120
del weight

skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ async def _send_chunks_legacy(self, chunks: Iterable[WeightChunk]) -> None:
230230
if rank == 0:
231231
assert self._model_update_group is not None, "Rank 0 must have model_update_group"
232232

233+
# Bracket the whole sync with one layerwise-reload initialize/finalize so
234+
# per-chunk reloads don't restore non-chunk layers; see `vllm_worker.py.WorkerWrap` docs
235+
if rank == 0:
236+
await self._inference_client.start_weight_update(is_checkpoint_format=True)
237+
torch.distributed.barrier()
238+
233239
# All ranks iterate through chunks (weight extraction may involve collective ops)
234240
for chunk in chunks:
235241
# Only rank 0 sends request to inference engines
@@ -264,6 +270,10 @@ def broadcast_packed(t, group):
264270

265271
torch.distributed.barrier()
266272

273+
if rank == 0:
274+
await self._inference_client.finish_weight_update()
275+
torch.distributed.barrier()
276+
267277
def teardown(self) -> None:
268278
"""Destroy the process group used for weight transfer."""
269279
if self._model_update_group is not None and isinstance(

0 commit comments

Comments
 (0)