118 changes: 112 additions & 6 deletions src/transformers/integrations/tensor_parallel.py
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks


def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
"""
Get the TP style for a parameter from the TP plan.

The TP plan is a dictionary that maps parameter names to TP styles.
Layer indices in `parameter_name` are replaced with "*" before the lookup; if the resulting generic name is not in the plan, its parent module name is tried (e.g. "layers.1.weight" matches either "layers.*.weight" or "layers.*").
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
else:
return None
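
A small usage sketch of the matcher above; the plan keys and parameter names are made up for illustration, but the style strings are ones used elsewhere in this module:

from transformers.integrations.tensor_parallel import _get_parameter_tp_plan

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.mlp.gate_up_proj.weight": "local_packed_rowwise",
}

# Layer indices are generalized to "*", then the full name and its parent module are tried in turn.
assert _get_parameter_tp_plan("model.layers.3.mlp.gate_up_proj.weight", tp_plan) == "local_packed_rowwise"
assert _get_parameter_tp_plan("model.layers.0.self_attn.q_proj.weight", tp_plan) == "colwise"
assert _get_parameter_tp_plan("model.embed_tokens.weight", tp_plan) is None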


str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim)
return tensor.to(str_to_torch_dtype[slice_dtype])


def repack_weights(
packed_parameter: torch.Tensor,
sharded_dim: int, # The dimension index in the global tensor that was sharded
world_size: int,
num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
Collaborator: nice this is mega useful!

For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
This is an inverse operation to get_packed_weights.

Args:
packed_parameter: The tensor reconstructed from a DTensor (e.g. via `.full_tensor().contiguous()`).
sharded_dim: The dimension index in `packed_parameter` that was originally sharded.
world_size: The tensor parallel world size.
num_blocks: The number of projections that were packed together (e.g. 2 for gate_up_proj).

Returns:
The reordered tensor in canonical packed format.
"""

if num_blocks != 2:
raise ValueError(
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
)

actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
shard_chunk_size = original_block_size_on_dim // world_size

prefix_shape = packed_parameter.shape[:actual_sharded_dim]
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]

tensor_view = packed_parameter.view(
*prefix_shape,
world_size,
num_blocks,
shard_chunk_size,
*suffix_shape,
)

# Permute to bring num_blocks first, then world_size, then shard_chunk_size.
# This groups all chunks of G together, then all chunks of U together.
# Target order of the middle dimensions: (num_blocks, world_size, shard_chunk_size)
# Current order in the view: (world_size, num_blocks, shard_chunk_size)
# Absolute indices of the two dimensions to swap (world_size and num_blocks):
axis_ws_abs = len(prefix_shape)
axis_npp_abs = len(prefix_shape) + 1

permute_order = list(range(tensor_view.ndim))
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]

tensor_permuted = tensor_view.permute(*permute_order)

# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
# The final shape is the same as `packed_parameter`.
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)

return final_ordered_tensor
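
A toy sketch of the reordering above, using a 1-D tensor and world_size=2 so the interleaved layout described in the docstring is easy to follow:

import torch

from transformers.integrations.tensor_parallel import repack_weights

gate = torch.arange(0, 4, dtype=torch.float32)   # canonical gate block: [0, 1, 2, 3]
up = torch.arange(10, 14, dtype=torch.float32)   # canonical up block:   [10, 11, 12, 13]

# Layout produced by DTensor.full_tensor() on a packed + sharded weight: [G0, U0, G1, U1]
interleaved = torch.cat([gate[:2], up[:2], gate[2:], up[2:]])

canonical = repack_weights(interleaved, sharded_dim=0, world_size=2, num_blocks=2)
assert torch.equal(canonical, torch.cat([gate, up]))  # back to [G_all, U_all]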


def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if dim == 0:
size_ = empty_param.shape[0]
@@ -578,6 +659,35 @@ def translate_to_torch_parallel_style(style: str):
raise ValueError(f"Unsupported parallel style value: {style}")


def convert_local_tensor_to_dtensor(
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
) -> DTensor:
"""
Converts a local (sharded) tensor to a DTensor with the corresponding placements. This should only ever be done right before saving the model.
"""
param_type = parameter_name.rsplit(".", 1)[-1] if "." in parameter_name else parameter_name
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
if not tp_style:
return parameter

if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
return parameter
# TODO: this logic should be wrapped in a helper function; it is copied from the corresponding TP classes.
if tp_style == "local_packed_rowwise":
placements = [Shard(-1)]
elif tp_style == "local_rowwise":
if param_type == "bias":
placements = [Replicate()]
else:
placements = [Shard(-1)]
elif tp_style == "local_colwise":
if param_type == "bias":
placements = [Shard(-1)]
else:
placements = [Shard(-2)]
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
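
A hedged usage sketch of the helper above; it assumes the snippet is launched with torchrun so a process group can be initialized, and the layer name, plan entry and shapes are made up:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from transformers.integrations.tensor_parallel import convert_local_tensor_to_dtensor

dist.init_process_group("gloo")
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

tp_plan = {"model.layers.*.mlp.down_proj.weight": "local_rowwise"}
local_weight = torch.randn(16, 64 // world_size)  # this rank's shard, split along the last dim

dtensor = convert_local_tensor_to_dtensor(
    local_weight, "model.layers.0.mlp.down_proj.weight", mesh, tp_plan
)
print(dtensor.placements)           # sharded along the last dim, per the local_rowwise branch above
print(dtensor.full_tensor().shape)  # torch.Size([16, 64]): the reconstructed global weight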


def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
"""
Add hooks to the module holding the layer. Meaning:
@@ -632,13 +742,9 @@ def shard_and_distribute_module(
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]

current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
40 changes: 38 additions & 2 deletions src/transformers/modeling_utils.py
@@ -63,6 +63,9 @@
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
convert_local_tensor_to_dtensor,
repack_weights,
shard_and_distribute_module,
)
from .loss.loss_utils import LOSS_MAPPING
@@ -121,6 +124,7 @@
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_huggingface_hub_greater_or_equal,
is_sagemaker_mp_enabled,
is_torch_fx_proxy,
is_torchdynamo_compiling,
@@ -166,6 +170,9 @@
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor


def is_fsdp_enabled():
return (
@@ -3368,6 +3375,12 @@ def save_pretrained(
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

# We need to check against `_tp_size`, not `_tp_plan`, as `_tp_plan` gets replaced by the class-level plan
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
raise ImportError(
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
)

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -3483,6 +3496,9 @@ def save_pretrained(
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict)
# If the model was sharded, we cannot properly determine the sizes of tensors for which a `local_*` strategy was used,
# so we replace them with equivalently sharded DTensors
state_dict = self._replace_state_dict_local_with_dtensor(state_dict)

if safe_serialization:
# Safetensors does not allow tensor aliasing.
@@ -3491,7 +3507,7 @@ def save_pretrained(
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, (torch.Tensor, DTensor)):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
@@ -3601,7 +3617,14 @@ def save_pretrained(
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# if the weight was packed, repack it to restore the canonical ordering
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]

@@ -4530,6 +4553,7 @@ def _assign_original_dtype(module):

# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4717,6 +4741,18 @@ def _fix_state_dict_keys_on_save(self, state_dict):
"""
return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

def _replace_state_dict_local_with_dtensor(self, state_dict):
"""
Replaces all tensors that were sharded with a `local_*` strategy with DTensors, so that the state dict can be saved.
"""
if self._tp_size is None:
return state_dict
# TODO: optimize this to avoid iterating over all tensors
for key, value in state_dict.items():
Collaborator: as much as possible should be hidden from this file and in TP
Contributor Author: Fixed and moved to tensor_parallel.py

if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
Collaborator: yeah we are gonna iterate over all the weights when saving anyways not sure we need it
Contributor Author: Actually we need to do this before split_torch_state_dict_into_shards is called, as that needs to have local tensors as dtensors to properly work. We iterate to save later.

return state_dict
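
To make the ordering constraint from the thread above concrete, a hedged sketch (plan_tp_shards is a hypothetical helper, not the verbatim save_pretrained body; split_torch_state_dict_into_shards is the huggingface_hub planner mentioned in the discussion):

from huggingface_hub import split_torch_state_dict_into_shards


def plan_tp_shards(model):
    # The local_* tensors must already be DTensors at this point; otherwise the planner
    # below would see per-rank local shapes instead of the global ones.
    state_dict = model._replace_state_dict_local_with_dtensor(model.state_dict())
    return split_torch_state_dict_into_shards(state_dict)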

@classmethod
def _load_pretrained_model(
cls,
7 changes: 7 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if is_torch_greater_or_equal_than_2_0:
from torch.distributed.tensor import DTensor

if isinstance(tensor, DTensor):
local_tensor = tensor.to_local()
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes

if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
16 changes: 16 additions & 0 deletions src/transformers/testing_utils.py
@@ -95,6 +95,7 @@
is_grokadamw_available,
is_hadamard_available,
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -540,6 +541,21 @@ def decorator(test_case):
return decorator


def require_huggingface_hub_greater_or_equal(version: str):
"""
Decorator marking a test that requires huggingface_hub version >= `version`.

These tests are skipped when the installed huggingface_hub version is lower than `version`.
"""

def decorator(test_case):
return unittest.skipUnless(
is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
)(test_case)

return decorator
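
A hedged usage sketch of the new decorator; the test class and method names are hypothetical, and the version matches the check added to save_pretrained in this PR:

import unittest

from transformers.testing_utils import require_huggingface_hub_greater_or_equal


class TensorParallelSaveTest(unittest.TestCase):
    @require_huggingface_hub_greater_or_equal("0.31.4")
    def test_save_pretrained_with_tp(self):
        ...  # runs only when huggingface_hub >= 0.31.4, reported as skipped otherwise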


def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -155,6 +155,7 @@
is_habana_gaudi1,
is_hadamard_available,
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_in_notebook,
is_ipex_available,
is_jieba_available,
13 changes: 13 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -1070,6 +1070,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


@lru_cache()
def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
if not _is_package_available("huggingface_hub"):
return False

if accept_dev:
return version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version
) >= version.parse(library_version)
else:
return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version)
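
A quick sketch of how the new check behaves (the version strings are illustrative):

from transformers.utils import is_huggingface_hub_greater_or_equal

# False when huggingface_hub is not installed or is older than the requested version.
print(is_huggingface_hub_greater_or_equal("0.31.4"))

# With accept_dev=True only the installed base version is compared,
# so e.g. an installed 0.32.0.dev0 satisfies "0.32.0".
print(is_huggingface_hub_greater_or_equal("0.32.0", accept_dev=True))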


def is_torchdistx_available():
return _torchdistx_available
