Commit e569f9a

Merge branch 'main' into tp-model_saving-fix
2 parents db55a17 + 08f3677 commit e569f9a

File tree

4 files changed: +222 -17 lines changed

src/transformers/integrations/accelerate.py

Lines changed: 196 additions & 0 deletions

@@ -0,0 +1,196 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on the meta
device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate is
still somewhat of a soft dependency, we copy the functions here to be used natively in Transformers.

The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`.
"""

from contextlib import contextmanager

from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
    called.

    </Tip>
    """
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_on_device

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    if include_buffers:
        with device:
            yield
        return

    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    # Patch tensor creation
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)


def find_tied_parameters(model: "nn.Module", **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
    them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names being all tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """

    # get ALL model parameters and their names
    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))

    # get ONLY unique named parameters;
    # if a parameter is tied and has multiple names, it will be included only once
    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))

    # the difference of the two sets will give us the tied parameters
    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
    # which names refer to the same parameter. To identify this, we need to group them together.
    tied_param_groups = {}
    for tied_param_name in tied_param_names:
        tied_param = all_named_parameters[tied_param_name]
        for param_name, param in no_duplicate_named_parameters.items():
            # compare if parameters are the same; if so, group their names together
            if param is tied_param:
                if param_name not in tied_param_groups:
                    tied_param_groups[param_name] = []
                tied_param_groups[param_name].append(tied_param_name)

    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
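
For a sense of how the two helpers above fit together, here is a minimal usage sketch (not part of the commit; it assumes the new module is importable as `transformers.integrations.accelerate`, and the toy model is made up for the example):

import torch.nn as nn

from transformers.integrations.accelerate import find_tied_parameters, init_empty_weights

# Build a toy model directly on the meta device: no real memory is allocated.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Tie the two weight matrices, then recover the tie by name.
model[1].weight = model[0].weight

print(model[0].weight.device)       # meta
print(find_tied_parameters(model))  # [['0.weight', '1.weight']]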

src/transformers/integrations/flex_attention.py

Lines changed: 6 additions & 1 deletion

@@ -100,6 +100,11 @@ def make_flex_block_causal_mask(
     Returns:
         BlockMask
     """
+    batch_size, total_seq_len = attention_mask_2d.shape
+    if not key_length:
+        key_length = total_seq_len
+    if not query_length:
+        query_length = total_seq_len
     attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
     device = attention_mask_2d.device
     document_ids = attention_mask_2d.clone()

@@ -139,7 +144,7 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         mask_mod = causal_mask_mod
     return create_block_causal_mask_flex(
         mask_mod=mask_mod,
-        B=1,
+        B=batch_size,
         H=None,  # attention head
         Q_LEN=query_length,
         KV_LEN=key_length,
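
To make the new defaulting logic concrete, here is a standalone sketch of just the added lines (not the full `make_flex_block_causal_mask`; the mask shape is made up for the example):

import torch

attention_mask_2d = torch.ones(4, 128, dtype=torch.long)  # (batch_size, seq_len)
key_length, query_length = None, None

# Added behaviour: derive the batch size, and fall back to the mask length
# whenever the optional lengths are not provided.
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
    key_length = total_seq_len
if not query_length:
    query_length = total_seq_len

# The mask is then padded on the key dimension, and the block mask is built with
# B=batch_size (previously hard-coded to 1), so every sequence in the batch is covered.
padded = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
print(batch_size, query_length, key_length, padded.shape)  # 4 128 128 torch.Size([4, 256])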

src/transformers/modeling_utils.py

Lines changed: 15 additions & 12 deletions

@@ -57,6 +57,7 @@
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.accelerate import find_tied_parameters, init_empty_weights
 from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward

@@ -131,12 +132,11 @@
 
 
 if is_accelerate_available():
-    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+    from accelerate import dispatch_model, infer_auto_device_map
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
         extract_model_from_parallel,
-        find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
         load_offloaded_weights,
@@ -3730,19 +3730,14 @@ def float(self, *args):
 
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        # With deepspeed, we cannot initialize the model on meta device
         if is_deepspeed_zero3_enabled():
             init_contexts = [no_init_weights()]
+            # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:
                 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-                init_contexts.extend(
-                    [
-                        deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                        set_zero3_state(),
-                    ]
-                )
+                init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
             elif is_quantized:
-                init_contexts.append(set_quantized_state())
+                init_contexts.extend([init_empty_weights(), set_quantized_state()])
         else:
             init_contexts = [no_init_weights(), init_empty_weights()]
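
For readers less familiar with this pattern: `get_init_context` returns a list of context managers that the caller enters together around model construction. A minimal sketch of the plain (non-DeepSpeed, non-quantized) case, using `contextlib.ExitStack` as a generic stand-in for however the list is entered:

from contextlib import ExitStack

import torch.nn as nn

from transformers.integrations.accelerate import init_empty_weights

# Sketch only: the real list may also contain no_init_weights() and friends.
init_contexts = [init_empty_weights()]

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    model = nn.Linear(16, 16)  # parameters are created on the meta device

print(model.weight.device)  # meta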

@@ -4151,6 +4146,10 @@ def from_pretrained(
         if device_map is not None:
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
+            if not is_accelerate_available():
+                raise ValueError(
+                    "Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
+                )
 
         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
         if load_in_4bit or load_in_8bit:
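
The user-facing effect of the new check: passing a `device_map` (or `tp_plan`) without accelerate installed now fails immediately with an explicit error. A typical call that goes through this path (the checkpoint name below is only a placeholder):

from transformers import AutoModelForCausalLM

# Requires `pip install accelerate`; otherwise from_pretrained raises the ValueError added above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
    device_map="auto",
)
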
@@ -4811,7 +4810,11 @@ def _load_pretrained_model(
                 continue
 
             map_location = "cpu"
-            if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
+            if (
+                shard_file.endswith(".safetensors")
+                and not is_hqq_or_bnb
+                and not (is_deepspeed_zero3_enabled() and not is_quantized)
+            ):
                 map_location = "meta"
             elif (
                 device_map is not None

@@ -4833,7 +4836,7 @@ def _load_pretrained_model(
             # Fix the key names
             state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
 
-            if is_deepspeed_zero3_enabled():
+            if is_deepspeed_zero3_enabled() and not is_quantized:
                 error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
             # Skip it with fsdp on ranks other than 0
             elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
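
The `map_location` change can be read as a small decision rule. A hypothetical helper that mirrors it (ignoring the separate device_map/disk-offload `elif` branch that follows in the real code):

def pick_map_location(shard_file: str, is_hqq_or_bnb: bool, zero3_enabled: bool, is_quantized: bool) -> str:
    # safetensors shards can stay on "meta" unless HQQ/bnb is involved, or
    # non-quantized DeepSpeed ZeRO-3 needs the materialized tensors.
    if (
        shard_file.endswith(".safetensors")
        and not is_hqq_or_bnb
        and not (zero3_enabled and not is_quantized)
    ):
        return "meta"
    return "cpu"

print(pick_map_location("model-00001-of-00002.safetensors", False, True, True))   # meta (new behaviour)
print(pick_map_location("model-00001-of-00002.safetensors", False, True, False))  # cpu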

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 5 additions & 4 deletions

@@ -356,7 +356,7 @@ def forward(
             attn_scales = (
                 torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
             )
-            attn_scales = attn_scales.view((*input_shape, 1, 1))
+            attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1))  # batch size > 1
             query_states = (query_states * attn_scales).to(query_states.dtype)
 
         query_states = query_states.transpose(1, 2)
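
A quick shape check of why the old `view` failed for batch size > 1: `attn_scales` is computed from `cache_position`, so it only has seq_len elements and has to be broadcast over the batch dimension (`floor_scale` and `attn_scale` are placeholder values here):

import torch

batch, seq_len = 2, 5
input_shape = (batch, seq_len)
cache_position = torch.arange(seq_len)
floor_scale, attn_scale = 8192.0, 0.1  # placeholders for self.floor_scale / self.attn_scale

attn_scales = torch.log(torch.floor((cache_position.float() + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0

# Old: attn_scales.view((*input_shape, 1, 1)) -> error, 5 elements cannot fill a (2, 5, 1, 1) view.
# New: reshape to (1, seq_len, 1, 1), then expand over the batch dimension.
attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1))
print(attn_scales.shape)  # torch.Size([2, 5, 1, 1])
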
@@ -692,6 +692,7 @@ def forward(
                     position_ids,
                     past_key_values,
                     output_attentions,
+                    False,  # output_router_logits is False
                     use_cache,
                     cache_position,
                     freq_cis,

@@ -1375,6 +1376,7 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
+                    freqs_ci,
                     attention_mask,
                     output_attentions,
                 )

@@ -1445,7 +1447,7 @@ def forward(self, hidden_states):
 
 class Llama4VisionModel(Llama4PreTrainedModel):
     base_model_prefix = "vision_model"
-    _no_split_modules = ["Llama4VisionAttention"]
+    _no_split_modules = ["Llama4VisionEncoderLayer"]
     config_class = Llama4VisionConfig
 
     def __init__(self, config: Llama4VisionConfig):

@@ -1754,8 +1756,7 @@ def forward(
             )
 
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
-            inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)
-
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
 
         outputs = self.language_model(
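
The last change swaps the in-place `masked_scatter_` for the out-of-place `masked_scatter`, which returns a new tensor instead of mutating `inputs_embeds` (in-place updates of tensors that autograd still needs can raise errors, so the out-of-place form is generally the safer default). A minimal sketch with toy shapes:

import torch

inputs_embeds = torch.zeros(4, 3)  # 4 tokens, hidden size 3
expanded_mask = torch.tensor([[True], [False], [True], [False]]).expand(-1, 3)
projected_vision_flat = torch.ones(2, 3)  # values for the 2 masked ("image") tokens

# Out-of-place: the original inputs_embeds tensor is left untouched.
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
print(inputs_embeds)  # rows 0 and 2 are filled with ones, rows 1 and 3 stay zero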
