66layerwise reload once per weight sync rather than once per chunk.
77"""
88
9+ import inspect
10+ from collections .abc import Callable
911from typing import TYPE_CHECKING
1012
1113import torch
1416 from vllm .config import ModelConfig , VllmConfig
1517 from vllm .v1 .worker .gpu_model_runner import GPUModelRunner
1618
17- # Workaround for a vLLM layerwise-reload corruption affecting NemotronH/Mamba.
18- # MambaMixer2 registers `conv_weights` as a non-persistent buffer that is a
19- # view of `self.conv1d.weight.data` (shared storage). vLLM's reload code path
20- # (model_executor/model_loader/reload/layerwise.py) materializes the buffer
21- # into a fresh uninitialized GPU tensor and then runs
22- # `kernel_conv_weights.data.copy_(fresh)` in `_copy_and_restore_kernel_tensors`.
23- # Because the kernel buffer shares storage with `conv1d.weight.data`, this
24- # writes garbage (NaN-bit-pattern bytes in bf16) into the conv1d weight,
25- # corrupting all 23 Mamba layers after every weight sync.
26- #
27- # Adding "conv_weights" to vLLM's SKIP_TENSORS makes capture/restore/materialize
28- # skip the buffer entirely, so the view stays intact and conv1d.weight is
29- # preserved. Must be applied before `record_metadata_for_reloading` runs at
30- # model construction; this module is imported by vLLM via
31- # --worker-extension-cls before model init, so the import-time patch is
32- # correctly ordered.
33- # Remove this pending https://github.com/vllm-project/vllm/pull/42481 which should
34- # be included in vLLM 0.21.0
35- try :
36- # Guarded import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
37- from vllm .model_executor .model_loader .reload .meta import (
38- SKIP_TENSORS as _VLLM_SKIP_TENSORS ,
39- )
40-
41- _VLLM_SKIP_TENSORS .add ("conv_weights" )
42- except ImportError :
43- pass
19+
20+ def get_numel_loaded (weight_loader : Callable , args : inspect .BoundArguments ) -> tuple [int , object ]:
21+ """
22+ Determine how many elements would be loaded by a weight loader call.
23+
24+ Args:
25+ weight_loader: used to load weights
26+ args: bound arguments to weight loader
27+
28+ Returns:
29+ number of elements loaded by the weight loader, the return value of the
30+ weight loader
31+ """
32+ # Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
33+ from vllm .model_executor .model_loader .reload .meta import CopyCounter
34+
35+ with CopyCounter () as counter :
36+ return_value = weight_loader (* args .args , ** args .kwargs )
37+
38+ # A weight loader fills a single destination parameter, so the number of
39+ # loaded elements is at most that parameter's size. Some loaders copy into
40+ # the parameter more than once -- e.g. ``composed_weight_loader`` runs an
41+ # in-place post-load transform (``param.copy_(fn(param))``) on top of the
42+ # initial copy -- which would make CopyCounter report twice the parameter
43+ # size. Over-counting inflates the layer's loaded-element total and can
44+ # finalize the layer before every parameter is loaded, silently dropping
45+ # the trailing parameter(s) (e.g. Mamba ``mixer.D``). Cap the count at the
46+ # destination size to keep the per-layer accounting correct.
47+ numel = counter .copied_numel
48+ param = args .arguments .get ("param" , None )
49+ if isinstance (param , torch .Tensor ):
50+ numel = min (numel , param .numel ())
51+ return numel , return_value
52+
53+
54+ def patch_numel_loaded ():
55+ # vLLM's layerwise reload binds get_numel_loaded at import time
56+ # (`from .meta import get_numel_loaded`), so its call site at
57+ # layerwise.py uses the `layerwise` module's own binding. Rebind that
58+ # attribute to our patched version to substitute the symbol.
59+ from vllm .model_executor .model_loader .reload import layerwise as _layerwise
60+ from vllm .model_executor .model_loader .reload import meta as _meta
61+
62+ _layerwise .get_numel_loaded = get_numel_loaded
63+ _meta .get_numel_loaded = get_numel_loaded
64+
65+
66+ _PATCHED_LAYERWISE_NUMEL_LOADED = False
4467
4568
4669class LayerwiseReloadWorkerMixin :
4770 """Bracket a multi-chunk weight sync with one vLLM layerwise-reload init/finalize.
4871
49- `start_weight_update ` initializes the layerwise reload once; each chunk then loads
50- its weights raw; `finish_weight_update ` finalizes once over the whole weight set.
72+ `skyrl_start_weight_update ` initializes the layerwise reload once; each chunk then loads
73+ its weights raw; `skyrl_finish_weight_update ` finalizes once over the whole weight set.
5174 A per-chunk `reload_weights` is the wrong approach: it re-finalizes on every call
5275 and restores layers absent from that chunk, corrupting a multi-chunk sync.
5376 """
@@ -57,7 +80,14 @@ class LayerwiseReloadWorkerMixin:
5780 model_config : "ModelConfig"
5881 device : torch .device
5982
60- def start_weight_update (self , is_checkpoint_format : bool = True ) -> None :
83+ # NOTE: named with a `skyrl_` prefix to avoid colliding with vLLM's own
84+ # Worker.start_weight_update / finish_weight_update (added in vllm-project/vllm
85+ # #39212, merge e3b65a5, shipped in vLLM 0.22.0+). vLLM injects the
86+ # worker-extension class as a *base* of Worker and asserts the extension
87+ # defines no attribute already present on Worker, so same-named methods abort
88+ # engine init. The skyrl_-prefixed variants keep SkyRL's IPC weight-sync path
89+ # (and the MoE set_current_vllm_config wrapping) intact alongside vLLM's native API.
90+ def skyrl_start_weight_update (self , is_checkpoint_format : bool = True ) -> None :
6191 """
6292 Prepare the model for a new weight update.
6393
@@ -74,10 +104,18 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
74104 """
75105 if getattr (self , "_skyrl_weight_update_active" , False ):
76106 raise RuntimeError (
77- "start_weight_update called while a weight update is "
78- "already active. Call finish_weight_update first."
107+ "skyrl_start_weight_update called while a weight update is "
108+ "already active. Call skyrl_finish_weight_update first."
79109 )
80110
111+ # Ensure the get_numel_loaded patch is in effect before layerwise
112+ # reload runs.
113+ global _PATCHED_LAYERWISE_NUMEL_LOADED
114+ if not _PATCHED_LAYERWISE_NUMEL_LOADED :
115+ # use patched version, based on https://github.com/vllm-project/vllm/pull/44814
116+ patch_numel_loaded ()
117+ _PATCHED_LAYERWISE_NUMEL_LOADED = True
118+
81119 if is_checkpoint_format :
82120 # Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
83121 from vllm .config import set_current_vllm_config
@@ -92,7 +130,7 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
92130 self ._skyrl_is_checkpoint_format = is_checkpoint_format
93131 self ._skyrl_weight_update_active = True
94132
95- def finish_weight_update (self ) -> None :
133+ def skyrl_finish_weight_update (self ) -> None :
96134 """
97135 Finalize the current weight update.
98136
@@ -101,7 +139,7 @@ def finish_weight_update(self) -> None:
101139 Must be called after all update_weights_ipc calls are done.
102140 """
103141 if not getattr (self , "_skyrl_weight_update_active" , False ):
104- raise RuntimeError ("start_weight_update must be called before finish_weight_update ." )
142+ raise RuntimeError ("skyrl_start_weight_update must be called before skyrl_finish_weight_update ." )
105143
106144 if self ._skyrl_is_checkpoint_format :
107145 # Lazy import: vllm is a Linux-only optional dependency, so this module stays importable on macOS / CI.
0 commit comments