|
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +""" |
| 15 | +Since, https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on meta |
| 16 | +device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate is |
| 17 | +somewhat still a soft dependency, we copy the functions here to be used natively in Transformers. |
| 18 | +
|
| 19 | +The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and the |
| 20 | +`find_tied_parameters` was copied from `accelerate.utils.modeling.py` |
| 21 | +""" |

from contextlib import contextmanager

from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```
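
    Parameters created under this context live on the meta device and carry no data. As a minimal sketch of how they
    can later be materialized (the actual weight loading is out of scope here), PyTorch's `Module.to_empty` allocates
    uninitialized storage on a real device:

    ```python
    with init_empty_weights():
        tst = nn.Linear(10, 10)

    print(tst.weight.is_meta)  # True: the weight has a shape and dtype but no storage
    tst = tst.to_empty(device="cpu")  # allocate uninitialized CPU storage; real weights still need to be loaded
    ```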

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
    Make sure to override the default `device_map` parameter for [`load_checkpoint_and_dispatch`], otherwise dispatch
    is not called.

    </Tip>
    """
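    # Delegate to `init_on_device` with the meta device: parameters (and optionally buffers) get registered on
    # `torch.device("meta")`, so they carry shape and dtype information but no allocated storage.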
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the specified device while initializing.

    Example:

    ```python
    import torch
    import torch.nn as nn
    from accelerate import init_on_device

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    if include_buffers:
        with device:
            yield
        return

    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

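    # Replacement for `nn.Module.register_parameter`: register the parameter as usual, then move it to the target
    # device, preserving the parameter subclass and its `requires_grad` flag.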
    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    # Patch tensor creation
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

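    # Wrap a tensor constructor (e.g. `torch.empty`) so that it always allocates on the target device, overriding any
    # `device` argument passed by the caller.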
    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

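    # Apply the patches for the duration of the context and always restore the original functions on exit.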
    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)


def find_tied_parameters(model: "nn.Module", **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are not used by this implementation and you should ignore them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names, where each inner list groups names that are tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
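
    In Transformers models, the most common tied group is an input embedding shared with the LM head. A minimal
    sketch (the exact parameter names depend on the architecture):

    ```py
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("gpt2")
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    find_tied_parameters(model)  # e.g. [['lm_head.weight', 'transformer.wte.weight']]
    ```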
    """

    # get ALL model parameters and their names
    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))

    # get ONLY the unique named parameters,
    # if a parameter is tied and has multiple names, it will be included only once
    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))

    # the difference of the two sets gives us the tied parameter names
    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
    # which names refer to the same parameter. To identify this, we need to group them together.
    tied_param_groups = {}
    for tied_param_name in tied_param_names:
        tied_param = all_named_parameters[tied_param_name]
        for param_name, param in no_duplicate_named_parameters.items():
            # compare if the parameters are the same object; if so, group their names together
            if param is tied_param:
                if param_name not in tied_param_groups:
                    tied_param_groups[param_name] = []
                tied_param_groups[param_name].append(tied_param_name)

    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]