# limitations under the License.
# ==============================================================================

import itertools
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import is_cuda

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()


class AutoWeightsLoader:
    """
    Helper class to load weights into a torch.nn.Module. It automatically
    detects child modules and parameters while iterating over the weights
    only once.

    This simplifies model code by abstracting the common weight-loading logic,
    reducing code duplication across different model implementations.

    Adapted from vLLM's AutoWeightsLoader implementation.

    Args:
        module: The root module to load weights into.
        skip_prefixes: List of weight-name prefixes to skip.
        skip_substrs: List of substrings to skip in weight names.
        ignore_unexpected_prefixes: List of prefixes of unexpected weights to ignore.
        ignore_unexpected_suffixes: List of suffixes of unexpected weights to ignore.

    Example:
        >>> def load_weights(self, weights):
        ...     loader = AutoWeightsLoader(
        ...         self,
        ...         skip_prefixes=["lm_head"] if self.config.tie_word_embeddings else [],
        ...         skip_substrs=["rotary_emb.inv_freq"],
        ...     )
        ...     return loader.load_weights(weights)
    """

    # Common weights that should be skipped (e.g., ColossalAI rotary embeddings)
    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    ]

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[List[str]] = None,
        skip_substrs: Optional[List[str]] = None,
        ignore_unexpected_prefixes: Optional[List[str]] = None,
        ignore_unexpected_suffixes: Optional[List[str]] = None,
    ) -> None:
        self.module = module
        self.skip_prefixes = skip_prefixes or []
        # Copy so that extending the list below does not mutate the caller's list.
        self.skip_substrs = list(skip_substrs or [])
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
        # Always skip common rotary embedding weights
        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS

    def _groupby_prefix(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
        """Group weights by their first prefix component (before the first dot)."""
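        # Example: [("layers.0.w", t0), ("layers.1.w", t1), ("norm.w", t2)]
        # groups into ("layers", [("0.w", t0), ("1.w", t1)]) and ("norm", [("w", t2)]).
        # Note that itertools.groupby only merges *consecutive* items, so the
        # incoming weights are expected to arrive already grouped by prefix.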
        weights_by_parts = (
            (weight_name.split(".", 1), weight_data)
            for weight_name, weight_data in weights
        )
        for prefix, group in itertools.groupby(
            weights_by_parts, key=lambda x: x[0][0]
        ):
            yield (
                prefix,
                (
                    ("" if len(parts) == 1 else parts[1], weights_data)
                    for parts, weights_data in group
                ),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Construct a fully qualified name from prefix and rest."""
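        # e.g. ("model", "layers.0.weight") -> "model.layers.0.weight";
        # an empty prefix or rest collapses to the other component.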
        if prefix == "":
            return rest
        if rest == "":
            return prefix
        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        """Check if a parameter should be skipped based on the skip rules."""
        return any(qualname.startswith(p) for p in self.skip_prefixes) or any(
            substr in qualname for substr in self.skip_substrs
        )

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        """Check if an unexpected weight can be ignored based on the ignore rules."""
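        # Unlike the skip rules (weights we deliberately decline to load),
        # the ignore rules tolerate checkpoint entries that have no
        # counterpart in the model at all.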
        starts_with_ignored = any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes
        )
        ends_with_ignored = any(
            qualname.endswith(s) for s in self.ignore_unexpected_suffixes
        )
        return starts_with_ignored or ends_with_ignored

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Load weights into a single parameter."""
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                logger.debug("Skipping weight %s", weight_qualname)
                continue

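            # An empty weight_name means the checkpoint entry maps directly
            # onto this parameter; any non-empty remainder is a nested name
            # that a leaf parameter cannot absorb.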
            if weight_name != "":
                if self._can_ignore_unexpected(weight_qualname):
                    logger.debug("Ignoring unexpected weight %s", weight_qualname)
                    continue
                raise ValueError(
                    f"Attempted to load nested weight '{weight_qualname}' "
                    f"into a single parameter '{base_prefix}'"
                )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight_data)
            logger.debug(
                "Loaded weight %s with shape %s", weight_qualname, param.shape
            )
            yield weight_qualname

    def _add_loadable_non_param_tensors(
        self, module: nn.Module, child_params: Dict[str, torch.Tensor]
    ) -> None:
        """Add tensor names not in model params (e.g., batchnorm statistics)."""
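        # These statistics are registered buffers rather than parameters, so
        # named_parameters() does not report them even though checkpoints
        # still carry them.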
        if isinstance(
            module,
            (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
            ),
        ):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var", "num_batches_tracked"):
                if stat_name in module_state_dict:
                    child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[str]:
        """Recursively load weights into a module and its children."""
        # If the module has a custom load_weights method, delegate to it
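        # (e.g., a nested sub-model that manages its own parameters). The
        # root module is excluded, since it is the one that invoked this
        # loader in the first place.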
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                loaded_params = module_load_weights(weights)
                if loaded_params is None:
                    logger.warning(
                        "Unable to collect loaded parameters for module %s. "
                        "Module.load_weights() should return an iterable of "
                        "loaded parameter names.",
                        module,
                    )
                else:
                    yield from (
                        self._get_qualname(base_prefix, x) for x in loaded_params
                    )
                return

        # Get child modules and parameters
        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))
        self._add_loadable_non_param_tensors(module, child_params)

        # Process weights grouped by prefix
        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if child_prefix in child_modules:
                if self._can_skip(prefix + "."):
                    logger.debug("Skipping module %s", prefix)
                    continue
                yield from self._load_module(
                    prefix, child_modules[child_prefix], child_weights
                )
            elif child_prefix in child_params:
                if self._can_skip(prefix):
                    logger.debug("Skipping param %s", prefix)
                    continue
                yield from self._load_param(
                    prefix, child_params[child_prefix], child_weights
                )
            else:
                # Check if we should skip or ignore this missing parameter
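                # Try both forms: "prefix." matches a skipped submodule's
                # entire subtree, while the bare prefix matches an individual
                # parameter name.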
                can_skip_module = self._can_skip(prefix + ".")
                can_skip_param = self._can_skip(prefix)
                if can_skip_module or can_skip_param:
                    logger.debug("Skipping missing %s", prefix)
                    continue

                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
                can_ignore_param = self._can_ignore_unexpected(prefix)
                if can_ignore_module or can_ignore_param:
                    logger.debug("Ignoring missing %s", prefix)
                    continue

                msg = (
                    f"There is no module or parameter named '{prefix}' "
                    f"in {type(self.module).__name__}"
                )
                raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """
        Load weights into the module.

        Args:
            weights: Iterable of (name, tensor) tuples to load.

        Returns:
            The set of weight names that were successfully loaded.
        """
        # Filter out skippable weights early
        weights = (
            (name, weight) for name, weight in weights if not self._can_skip(name)
        )
        autoloaded_weights = set(self._load_module("", self.module, weights))
        return autoloaded_weights


if _is_cuda:
    from sgl_kernel import FusedSetKVBufferArg