132 changes: 126 additions & 6 deletions src/transformers/integrations/tensor_parallel.py
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks


def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
"""
Get the TP style for a parameter from the TP plan.

The TP plan is a dictionary that maps parameter names to TP styles.
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
else:
return None
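
For illustration, a few lookups against a made-up plan (the entries below are examples, not a plan shipped with any model):

tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}

_get_parameter_tp_plan("model.layers.3.self_attn.q_proj", tp_plan)         # "colwise" (generic-name match)
_get_parameter_tp_plan("model.layers.3.self_attn.q_proj.weight", tp_plan)  # "colwise" (falls back to the parent module's entry)
_get_parameter_tp_plan("model.embed_tokens.weight", tp_plan)               # None (no matching entry)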


str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
return tensor.to(str_to_torch_dtype[slice_dtype])


def repack_weights(
packed_parameter: torch.Tensor,
sharded_dim: int, # The dimension index in the global tensor that was sharded
world_size: int,
num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.

For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
This is an inverse operation to get_packed_weights.

Args:
packed_parameter: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
sharded_dim: The dimension index in packed_parameter that was originally sharded.
world_size: The tensor parallel world size.
num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).

Returns:
The reordered tensor in canonical packed format.
"""

if num_blocks != 2:
raise ValueError(
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
)

actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
shard_chunk_size = original_block_size_on_dim // world_size

prefix_shape = packed_parameter.shape[:actual_sharded_dim]
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]

tensor_view = packed_parameter.view(
*prefix_shape,
world_size,
num_blocks,
shard_chunk_size,
*suffix_shape,
)

# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
# This groups all chunks of G together, then all chunks of U together.
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
axis_ws_abs = len(prefix_shape)
axis_npp_abs = len(prefix_shape) + 1

permute_order = list(range(tensor_view.ndim))
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]

tensor_permuted = tensor_view.permute(*permute_order)

# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
# The final shape is the same as that of packed_parameter.
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)

return final_ordered_tensor
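
A minimal, self-contained sketch of the reordering this performs (values and sizes are made up for illustration; assumes a transformers install that includes this function):

import torch
from transformers.integrations.tensor_parallel import repack_weights

# Two packed blocks (gate, up), world_size=2, two rows per block per rank.
gate = torch.arange(0, 4).float().reshape(4, 1)     # canonical gate rows
up = torch.arange(100, 104).float().reshape(4, 1)   # canonical up rows

# What DTensor.full_tensor() yields for a packed, row-sharded weight:
# each rank holds [gate_shard_i, up_shard_i], so concatenating ranks interleaves the blocks.
interleaved = torch.cat([gate[:2], up[:2], gate[2:], up[2:]], dim=0)  # [G0, U0, G1, U1]

repacked = repack_weights(interleaved, sharded_dim=0, world_size=2, num_blocks=2)
assert torch.equal(repacked, torch.cat([gate, up], dim=0))  # [G0, G1, U0, U1]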


def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if dim == 0:
size_ = empty_param.shape[0]
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
raise ValueError(f"Unsupported parallel style value: {style}")


def convert_local_tensor_to_dtensor(
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
) -> Union[torch.Tensor, DTensor]:
"""
Converts a local variant of weights to a DTensor with the corresponding placements. This should only ever be done right before saving the model.
"""
param_type = parameter_name.rsplit(".", 1)[-1] if "." in parameter_name else parameter_name
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
if not tp_style:
return parameter

if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
return parameter
# TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
if tp_style == "local_packed_rowwise":
placements = [Shard(-1)]
elif tp_style == "local_rowwise":
if param_type == "bias":
placements = [Replicate()]
else:
placements = [Shard(-1)]
elif tp_style == "local_colwise":
if param_type == "bias":
placements = [Shard(-1)]
else:
placements = [Shard(-2)]
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
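
A sketch of the intended call, assuming a distributed run launched with torchrun on GPUs; the plan entry, parameter name and sizes are illustrative:

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers.integrations.tensor_parallel import convert_local_tensor_to_dtensor

world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
device_mesh = init_device_mesh("cuda", (world_size,))
tp_plan = {"model.layers.*.mlp.down_proj": "local_rowwise"}

local_weight = torch.randn(4096, 11008 // world_size)  # this rank's local shard
dtensor_weight = convert_local_tensor_to_dtensor(
    local_weight, "model.layers.0.mlp.down_proj.weight", device_mesh, tp_plan
)
# local_rowwise weights map to Shard(-1), so the DTensor reports the full global shape.
print(dtensor_weight.shape)  # torch.Size([4096, 11008])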


def replace_state_dict_local_with_dtensor(
state_dict: dict[str, torch.Tensor],
tp_plan: dict[str, str],
device_mesh,
) -> dict[str, torch.Tensor]:
"""
Replaces every tensor that was sharded with a `local_*` strategy by an equivalent DTensor, so that its proper (global) size can be determined.
"""
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
return state_dict
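
This mirrors how the save_pretrained changes below use it; a condensed sketch (the model object is hypothetical, the attributes are the ones set during TP loading):

state_dict = model.state_dict()
if model._tp_size is not None:
    state_dict = replace_state_dict_local_with_dtensor(state_dict, model._tp_plan, model._device_mesh)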


def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
"""
Add hooks to the module holding the layer. Meaning:
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]

current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
29 changes: 27 additions & 2 deletions src/transformers/modeling_utils.py
@@ -63,6 +63,9 @@
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module,
verify_tp_plan,
)
@@ -123,6 +126,7 @@
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_huggingface_hub_greater_or_equal,
is_sagemaker_mp_enabled,
is_torch_fx_proxy,
is_torchdynamo_compiling,
@@ -168,6 +172,9 @@
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor


def is_fsdp_enabled():
return (
@@ -3413,6 +3420,12 @@ def save_pretrained(
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

# we need to check against tp_size, not tp_plan, since tp_plan gets replaced by the class-level plan
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
raise ImportError(
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
)

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -3540,6 +3553,10 @@ def save_pretrained(
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict)
# If the model was sharded, we cannot properly determine the sizes of tensors for which a `local_*` strategy was used,
# so we replace them with equivalently sharded DTensors
if self._tp_size is not None:
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

if safe_serialization:
# Safetensors does not allow tensor aliasing.
@@ -3548,7 +3565,7 @@ def save_pretrained(
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
@@ -3658,7 +3675,14 @@ def save_pretrained(
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor, we need to repack it if it was packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]

@@ -4606,6 +4630,7 @@ def _assign_original_dtype(module):

# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh

# make sure token embedding weights are still tied if needed
model.tie_weights()
7 changes: 7 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if is_torch_greater_or_equal_than_2_0:
from torch.distributed.tensor import DTensor

if isinstance(tensor, DTensor):
local_tensor = tensor.to_local()
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes

if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
16 changes: 16 additions & 0 deletions src/transformers/testing_utils.py
@@ -97,6 +97,7 @@
is_grokadamw_available,
is_hadamard_available,
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -542,6 +543,21 @@ def decorator(test_case):
return decorator


def require_huggingface_hub_greater_or_equal(version: str):
"""
Decorator marking a test that requires huggingface_hub version >= `version`.

These tests are skipped when huggingface_hub version is less than `version`.
"""

def decorator(test_case):
return unittest.skipUnless(
is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
)(test_case)

return decorator
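
A hypothetical usage, in the style of the other require_* decorators in this file:

@require_huggingface_hub_greater_or_equal("0.31.4")
def test_save_pretrained_with_tensor_parallel(self):
    ...  # test body elided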


def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -167,6 +167,7 @@
is_habana_gaudi1,
is_hadamard_available,
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_in_notebook,
is_ipex_available,
is_jieba_available,
13 changes: 13 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


@lru_cache()
def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False):
if not _is_package_available("huggingface_hub"):
return False

if accept_dev:
return version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version
) >= version.parse(library_version)
else:
return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version)
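
Illustrative calls (the results depend on the installed huggingface_hub version):

is_huggingface_hub_greater_or_equal("0.31.4")                   # compares the installed release directly
is_huggingface_hub_greater_or_equal("0.32.0", accept_dev=True)  # a 0.32.0.dev0 install also counts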


def is_torchdistx_available():
return _torchdistx_available

Expand Down
Loading