
Commit d35c02d

S1ro1 authored and faaany committed
Feat: save_pretrained for tensor parallel (and other parallelisms) models (huggingface#37919)
* tmp: initial save pretrained with dtensors
* Feat: add correctness tests
* Refactor: version checks
* Temp: 1:1 checkpoint llama4
* refactor
* Tests
* Feat: works
* Style
* Feat: version checks + minor fixes
* Style
* Fix: version checks in tests
* Feat: move more stuff into tensor_parallel.py
1 parent 6e827ea commit d35c02d

File tree: 7 files changed, +271 -12 lines changed

src/transformers/integrations/tensor_parallel.py

Lines changed: 126 additions & 6 deletions
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
     return [single_size] * blocks


+def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
+    """
+    Get the TP style for a parameter from the TP plan.
+
+    The TP plan is a dictionary that maps parameter names to TP styles.
+    The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
+    """
+    generic_param_name = re.sub(r"\d+", "*", parameter_name)
+    if generic_param_name in tp_plan:
+        return tp_plan[generic_param_name]
+    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
+        return tp_plan[generic_param_name.rsplit(".", 1)[0]]
+    else:
+        return None
+
+
 str_to_torch_dtype = {
     "BOOL": torch.bool,
     "U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     return tensor.to(str_to_torch_dtype[slice_dtype])


+def repack_weights(
+    packed_parameter: torch.Tensor,
+    sharded_dim: int,  # the dimension index in the global tensor that was sharded
+    world_size: int,
+    num_blocks: int = 2,
+) -> torch.Tensor:
+    """
+    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
+
+    For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
+    DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
+    along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
+    This is the inverse operation to get_packed_weights.
+
+    Args:
+        packed_parameter: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
+        sharded_dim: The dimension index in `packed_parameter` that was originally sharded.
+        world_size: The tensor parallel world size.
+        num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).
+
+    Returns:
+        The reordered tensor in canonical packed format.
+    """
+
+    if num_blocks != 2:
+        raise ValueError(
+            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
+        )
+
+    actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
+    total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
+    original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
+    shard_chunk_size = original_block_size_on_dim // world_size
+
+    prefix_shape = packed_parameter.shape[:actual_sharded_dim]
+    suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
+
+    tensor_view = packed_parameter.view(
+        *prefix_shape,
+        world_size,
+        num_blocks,
+        shard_chunk_size,
+        *suffix_shape,
+    )
+
+    # Permute to bring num_blocks first, then world_size, then shard_chunk_size.
+    # This groups all chunks of G together, then all chunks of U together.
+    # Target order of these middle dimensions: (num_blocks, world_size, shard_chunk_size)
+    # Current order of the view's middle dimensions: (world_size, num_blocks, shard_chunk_size)
+    # Absolute indices of the dimensions to be permuted (world_size, num_blocks)
+    axis_ws_abs = len(prefix_shape)
+    axis_npp_abs = len(prefix_shape) + 1
+
+    permute_order = list(range(tensor_view.ndim))
+    permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
+
+    tensor_permuted = tensor_view.permute(*permute_order)
+
+    # Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
+    # The final shape is the same as that of `packed_parameter`.
+    final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
+
+    return final_ordered_tensor
+
+
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if dim == 0:
         size_ = empty_param.shape[0]
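A small self-contained sanity check of the reordering (not part of the PR): a toy 8x4 weight packed as [gate; up] along dim 0 is first laid out in the interleaved order that `full_tensor()` can produce with two ranks, then repacked into the canonical order. The import path assumes a transformers build containing this commit:

```python
import torch

from transformers.integrations.tensor_parallel import repack_weights

world_size, num_blocks = 2, 2
gate = torch.arange(16, dtype=torch.float32).reshape(4, 4)        # canonical gate rows
up = 100 + torch.arange(16, dtype=torch.float32).reshape(4, 4)    # canonical up rows

# Interleaved layout along dim 0: [gate shard 0, up shard 0, gate shard 1, up shard 1]
interleaved = torch.cat([gate[:2], up[:2], gate[2:], up[2:]], dim=0)

repacked = repack_weights(interleaved, sharded_dim=0, world_size=world_size, num_blocks=num_blocks)

# Canonical packed layout: all gate rows first, then all up rows.
assert torch.equal(repacked, torch.cat([gate, up], dim=0))
```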
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
     raise ValueError(f"Unsupported parallel style value: {style}")


+def convert_local_tensor_to_dtensor(
+    parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
+) -> DTensor:
+    """
+    Converts a local variant of weights to a DTensor with the corresponding placements. Should only ever be done right before saving the model.
+    """
+    _, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
+    tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
+    if not tp_style:
+        return parameter
+
+    if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
+        return parameter
+    # TODO: this logic should be wrapped in a function, it is copied from the corresponding TP classes.
+    if tp_style == "local_packed_rowwise":
+        placements = [Shard(-1)]
+    elif tp_style == "local_rowwise":
+        if param_type == "bias":
+            placements = [Replicate()]
+        else:
+            placements = [Shard(-1)]
+    elif tp_style == "local_colwise":
+        if param_type == "bias":
+            placements = [Shard(-1)]
+        else:
+            placements = [Shard(-2)]
+    return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
+
+
+def replace_state_dict_local_with_dtensor(
+    state_dict: dict[str, torch.Tensor],
+    tp_plan: dict[str, str],
+    device_mesh,
+) -> dict[str, torch.Tensor]:
+    """
+    Replaces all tensors that were sharded with a `local_*` strategy with DTensors, so that their proper (global) size can be determined.
+    """
+    for key, value in state_dict.items():
+        if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
+            state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
+    return state_dict
+
+
 def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
     """
     Add hooks to the module holding the layer. Meaning:
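As a rough illustration of what the `local_colwise` branch does (a sketch under stated assumptions, not code from the PR): wrapping a rank-local shard as a DTensor records the global layout, which is what lets the saving path reason about the parameter's full size.

```python
# Assumptions: torch >= 2.5, launched with `torchrun --nproc-per-node 2 demo.py`;
# CPU + gloo are used only to keep the sketch runnable anywhere.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2,))

# Each rank holds a local 2x4 slice of a 4x4 weight that was sharded along dim -2,
# mirroring what the "local_colwise" branch above records for non-bias parameters.
local_shard = torch.full((2, 4), float(dist.get_rank()))

# Wrapping the local slice as a DTensor records the global layout, so the saving
# code can later recover the full 4x4 tensor with .full_tensor().
dweight = DTensor.from_local(local_shard, mesh, [Shard(-2)], run_check=False)
print(dweight.shape)          # torch.Size([4, 4]) -- the global shape
print(dweight.full_tensor())  # rank 0's rows followed by rank 1's rows

dist.destroy_process_group()
```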
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan
     module_to_tp = model.get_submodule(param_name)
-    current_module_plan = None
     rank = int(rank)
-    generic_param_name = re.sub(r"\d+", "*", parameter_name)
-    if generic_param_name in tp_plan:
-        current_module_plan = tp_plan[generic_param_name]
-    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
-        current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
+
+    current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

     # Add hooks to the module if not done yet
     # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)

src/transformers/modeling_utils.py

Lines changed: 27 additions & 2 deletions
@@ -63,6 +63,9 @@
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
     SUPPORTED_TP_STYLES,
+    _get_parameter_tp_plan,
+    repack_weights,
+    replace_state_dict_local_with_dtensor,
     shard_and_distribute_module,
     verify_tp_plan,
 )
@@ -123,6 +126,7 @@
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     ENV_VARS_TRUE_VALUES,
+    is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
@@ -168,6 +172,9 @@
 _is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()

+if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+    from torch.distributed.tensor import DTensor
+

 def is_fsdp_enabled():
     return (
@@ -3413,6 +3420,12 @@ def save_pretrained(
         if safe_serialization and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

+        # we need to check against tp_size, not tp_plan, as tp_plan is replaced by the class-level one
+        if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
+            raise ImportError(
+                "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
+            )
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -3540,6 +3553,10 @@ def save_pretrained(
         # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
         # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
         state_dict = self._fix_state_dict_keys_on_save(state_dict)
+        # If the model was sharded, we cannot properly determine the sizes of tensors that used a `local_*` strategy,
+        # therefore we replace them with equivalently sharded DTensors
+        if self._tp_size is not None:
+            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

         if safe_serialization:
             # Safetensors does not allow tensor aliasing.
@@ -3548,7 +3565,7 @@ def save_pretrained(
             for name, tensor in state_dict.items():
                 # Sometimes in the state_dict we have non-tensor objects.
                 # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                if isinstance(tensor, torch.Tensor):
+                if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
                     ptrs[id_tensor_storage(tensor)].append(name)
                 else:
                     # In the non-tensor case, fall back to the pointer of the object itself
@@ -3658,7 +3675,14 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                shard[tensor] = state_dict[tensor].contiguous()
+                if isinstance(state_dict[tensor], DTensor):
+                    full_tensor = state_dict[tensor].full_tensor()
+                    # to get the correctly ordered tensor we need to repack if packed
+                    if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
+                        full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
+                    shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
+                else:
+                    shard[tensor] = state_dict[tensor].contiguous()
                 # delete reference, see https://github.com/huggingface/transformers/pull/34890
                 del state_dict[tensor]

@@ -4606,6 +4630,7 @@ def _assign_original_dtype(module):

         # record tp degree the model sharded to
         model._tp_size = tp_size
+        model._device_mesh = device_mesh

         # make sure token embedding weights are still tied if needed
         model.tie_weights()
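Taken together, the modeling_utils changes enable the workflow below. This is a hedged end-to-end sketch, not code from the PR: the checkpoint id is only an example, and it assumes `huggingface_hub >= 0.31.4` plus a multi-GPU launch such as `torchrun --nproc-per-node 4 save_tp.py`.

```python
import torch

from transformers import AutoModelForCausalLM

# tp_plan="auto" shards the model with tensor parallelism across the ranks
# started by torchrun.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # example checkpoint; any TP-supported model should work
    torch_dtype=torch.bfloat16,
    tp_plan="auto",
)

# With this commit, save_pretrained gathers DTensors back into full tensors
# (repacking packed projections such as gate_up_proj), so the checkpoint
# written to disk is a regular, rank-independent one.
model.save_pretrained("./llama-tp-checkpoint", safe_serialization=True)
```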

src/transformers/pytorch_utils.py

Lines changed: 7 additions & 0 deletions
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
     """
+    if is_torch_greater_or_equal_than_2_0:
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            local_tensor = tensor.to_local()
+            return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
+
     if tensor.device.type == "xla" and is_torch_xla_available():
         # NOTE: xla tensors dont have storage
         # use some other unique id to distinguish.

src/transformers/testing_utils.py

Lines changed: 16 additions & 0 deletions
@@ -97,6 +97,7 @@
     is_grokadamw_available,
     is_hadamard_available,
     is_hqq_available,
+    is_huggingface_hub_greater_or_equal,
     is_ipex_available,
     is_jieba_available,
     is_jinja_available,
@@ -542,6 +543,21 @@ def decorator(test_case):
     return decorator


+def require_huggingface_hub_greater_or_equal(version: str):
+    """
+    Decorator marking a test that requires huggingface_hub version >= `version`.
+
+    These tests are skipped when huggingface_hub version is less than `version`.
+    """
+
+    def decorator(test_case):
+        return unittest.skipUnless(
+            is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_flash_attn(test_case):
     """
     Decorator marking a test that requires Flash Attention.
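A quick illustration of how the new decorator is meant to be used in a test module (the test class and method below are hypothetical, not from this commit):

```python
import unittest

from transformers.testing_utils import require_huggingface_hub_greater_or_equal


class SaveTensorParallelTest(unittest.TestCase):
    @require_huggingface_hub_greater_or_equal("0.31.4")
    def test_needs_recent_hub(self):
        # Runs only when huggingface_hub >= 0.31.4 is installed; otherwise the
        # test is reported as skipped rather than failing.
        self.assertTrue(True)
```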

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@
     is_habana_gaudi1,
     is_hadamard_available,
     is_hqq_available,
+    is_huggingface_hub_greater_or_equal,
     is_in_notebook,
     is_ipex_available,
     is_jieba_available,

src/transformers/utils/import_utils.py

Lines changed: 13 additions & 0 deletions
@@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
     return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


+@lru_cache()
+def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
+    if not _is_package_available("huggingface_hub"):
+        return False
+
+    if accept_dev:
+        return version.parse(
+            version.parse(importlib.metadata.version("huggingface_hub")).base_version
+        ) >= version.parse(library_version)
+    else:
+        return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version)
+
+
 def is_torchdistx_available():
     return _torchdistx_available
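For completeness, a small usage sketch of the new check (output depends on the locally installed huggingface_hub; like `is_torch_greater_or_equal`, `accept_dev=True` compares only the base version, so a pre-release build such as 0.32.0.dev0 would satisfy a 0.32.0 requirement):

```python
from transformers.utils import is_huggingface_hub_greater_or_equal

print(is_huggingface_hub_greater_or_equal("0.31.4"))
print(is_huggingface_hub_greater_or_equal("0.32.0", accept_dev=True))
```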
