
Commit 2a4392b

Author: adityakamat24 (committed)
Add AutoWeightsLoader utility for simplified weight loading
- Add AutoWeightsLoader class to sglang.srt.models.utils
- Implements automatic module/parameter detection during weight loading
- Supports skip_prefixes, skip_substrs for filtering weights
- Supports ignore_unexpected_prefixes/suffixes for optional weights
- Auto-skips common unused weights (rotary_emb.inv_freq, etc.)
- Respects custom weight_loader methods on parameters
- Supports nested modules with custom load_weights methods
- Reduces code duplication across model implementations

Addresses #11864
1 parent 2031569 commit 2a4392b
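
A minimal sketch of how a model implementation might delegate its weight loading to this helper; the MyLlamaForCausalLM class, its config fields, and the "mtp." prefix are illustrative assumptions, not part of this commit:

import torch.nn as nn

from sglang.srt.models.utils import AutoWeightsLoader


class MyLlamaForCausalLM(nn.Module):
    # ... layers omitted; hypothetical model, for illustration only ...

    def load_weights(self, weights):
        loader = AutoWeightsLoader(
            self,
            # Skip the output head when it is tied to the input embeddings.
            skip_prefixes=["lm_head."] if self.config.tie_word_embeddings else [],
            # Tolerate optional checkpoint tensors this model does not instantiate.
            ignore_unexpected_prefixes=["mtp."],
        )
        # Returns the set of fully qualified parameter names that were loaded.
        return loader.load_weights(weights)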

File tree

1 file changed: +243 -0 lines


python/sglang/srt/models/utils.py

Lines changed: 243 additions & 0 deletions
@@ -12,15 +12,258 @@
 # limitations under the License.
 # ==============================================================================

+import itertools
+import logging
+from typing import Callable, Iterable, Optional, Tuple
+
 import torch
+import torch.nn as nn

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_cuda

+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()


+class AutoWeightsLoader:
+    """
+    Helper class to load weights into a torch.nn.Module. It automatically
+    detects child modules and parameters while iterating weights only once.
+
+    This simplifies model code by abstracting the common weight loading logic,
+    reducing code duplication across different model implementations.
+
+    Adapted from vLLM's AutoWeightsLoader implementation.
+
+    Args:
+        module: The root module to load weights into
+        skip_prefixes: List of weight name prefixes to skip
+        skip_substrs: List of substrings to skip in weight names
+        ignore_unexpected_prefixes: List of prefixes for unexpected weights to ignore
+        ignore_unexpected_suffixes: List of suffixes for unexpected weights to ignore
+
+    Example:
+        >>> def load_weights(self, weights):
+        ...     loader = AutoWeightsLoader(
+        ...         self,
+        ...         skip_prefixes=["lm_head"] if self.config.tie_word_embeddings else [],
+        ...         skip_substrs=["rotary_emb.inv_freq"],
+        ...     )
+        ...     return loader.load_weights(weights)
+    """
+
+    # Common weights that should be skipped (e.g., ColossalAI rotary embeddings)
+    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
+        "rotary_emb.inv_freq",
+        "rotary_emb.cos_cached",
+        "rotary_emb.sin_cached",
+    ]
+
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[list] = None,
+        skip_substrs: Optional[list] = None,
+        ignore_unexpected_prefixes: Optional[list] = None,
+        ignore_unexpected_suffixes: Optional[list] = None,
+    ) -> None:
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.skip_substrs = skip_substrs or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
+        # Always skip common rotary embedding weights
+        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
+
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        """Group weights by their first prefix component (before the first dot)."""
+        weights_by_parts = (
+            (weight_name.split(".", 1), weight_data)
+            for weight_name, weight_data in weights
+        )
+        for prefix, group in itertools.groupby(
+            weights_by_parts, key=lambda x: x[0][0]
+        ):
+            yield (
+                prefix,
+                (
+                    ("" if len(parts) == 1 else parts[1], weights_data)
+                    for parts, weights_data in group
+                ),
+            )
+
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        """Construct fully qualified name from prefix and rest."""
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
+        return ".".join((prefix, rest))
+
+    def _can_skip(self, qualname: str) -> bool:
+        """Check if parameter should be skipped based on skip rules."""
+        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
+            substr in qualname for substr in self.skip_substrs
+        )
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        """Check if unexpected weight can be ignored based on ignore rules."""
+        starts_with_ignored = any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes
+        )
+        ends_with_ignored = any(
+            qualname.endswith(s) for s in self.ignore_unexpected_suffixes
+        )
+        return starts_with_ignored or ends_with_ignored
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[str]:
+        """Load weights into a single parameter."""
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                logger.debug("Skipping weight %s", weight_qualname)
+                continue
+
+            if weight_name != "":
+                if self._can_ignore_unexpected(weight_qualname):
+                    logger.debug("Ignoring unexpected weight %s", weight_qualname)
+                    continue
+                raise ValueError(
+                    f"Attempted to load nested weight '{weight_qualname}' "
+                    f"into a single parameter '{base_prefix}'"
+                )
+
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, weight_data)
+            logger.debug(
+                "Loaded weight %s with shape %s", weight_qualname, param.shape
+            )
+            yield weight_qualname
+
+    def _add_loadable_non_param_tensors(
+        self, module: nn.Module, child_params: dict
+    ):
+        """Add tensor names not in model params (e.g., batchnorm statistics)."""
+        if isinstance(
+            module,
+            (
+                nn.BatchNorm1d,
+                nn.BatchNorm2d,
+                nn.BatchNorm3d,
+                nn.LazyBatchNorm1d,
+                nn.LazyBatchNorm2d,
+                nn.LazyBatchNorm3d,
+                nn.SyncBatchNorm,
+            ),
+        ):
+            module_state_dict = module.state_dict()
+            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
+                if stat_name in module_state_dict:
+                    child_params[stat_name] = module_state_dict[stat_name]
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[str]:
+        """Recursively load weights into a module and its children."""
+        # If module has a custom load_weights method, use it
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                loaded_params = module_load_weights(weights)
+                if loaded_params is None:
+                    logger.warning(
+                        "Unable to collect loaded parameters for module %s. "
+                        "Module.load_weights() should return an iterable of "
+                        "loaded parameter names.",
+                        module,
+                    )
+                else:
+                    yield from (
+                        self._get_qualname(base_prefix, x) for x in loaded_params
+                    )
+                return
+
+        # Get child modules and parameters
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+        self._add_loadable_non_param_tensors(module, child_params)
+
+        # Process weights grouped by prefix
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if child_prefix in child_modules:
+                if self._can_skip(prefix + "."):
+                    logger.debug("Skipping module %s", prefix)
+                    continue
+                yield from self._load_module(
+                    prefix, child_modules[child_prefix], child_weights
+                )
+            elif child_prefix in child_params:
+                if self._can_skip(prefix):
+                    logger.debug("Skipping param %s", prefix)
+                    continue
+                yield from self._load_param(
+                    prefix, child_params[child_prefix], child_weights
+                )
+            else:
+                # Check if we should skip or ignore this missing parameter
+                can_skip_module = self._can_skip(prefix + ".")
+                can_skip_param = self._can_skip(prefix)
+                if can_skip_module or can_skip_param:
+                    logger.debug("Skipping missing %s", prefix)
+                    continue

+                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
+                can_ignore_param = self._can_ignore_unexpected(prefix)
+                if can_ignore_module or can_ignore_param:
+                    logger.debug("Ignoring missing %s", prefix)
+                    continue
+
+                msg = (
+                    f"There is no module or parameter named '{prefix}' "
+                    f"in {type(self.module).__name__}"
+                )
+                raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> set:
+        """
+        Load weights into the module.
+
+        Args:
+            weights: Iterable of (name, tensor) tuples to load
+
+        Returns:
+            Set of weight names that were successfully loaded
+        """
+        # Filter out skippable weights early
+        weights = (
+            (name, weight) for name, weight in weights if not self._can_skip(name)
+        )
+        autoloaded_weights = set(self._load_module("", self.module, weights))
+        return autoloaded_weights
+
+
 if _is_cuda:
     from sgl_kernel import FusedSetKVBufferArg

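For reference, a minimal end-to-end sketch of how the loader routes weights into nested modules and auto-skips rotary caches; the ToyBlock/ToyModel classes and the dummy weight list are hypothetical, used only for illustration (assumes sglang is installed):

import torch
import torch.nn as nn

from sglang.srt.models.utils import AutoWeightsLoader


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()
        self.norm = nn.LayerNorm(4)


model = ToyModel()
weights = [
    ("block.proj.weight", torch.randn(4, 4)),
    ("norm.weight", torch.ones(4)),
    ("norm.bias", torch.zeros(4)),
    # Filtered out automatically via ROTARY_EMBEDS_UNUSED_WEIGHTS.
    ("block.rotary_emb.inv_freq", torch.randn(2)),
]
loaded = AutoWeightsLoader(model).load_weights(weights)
print(sorted(loaded))  # ['block.proj.weight', 'norm.bias', 'norm.weight']

Because weight names are grouped by their leading prefix and dispatched recursively, each parameter is visited at most once and the rotary cache entry never reaches a parameter.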