118 changes: 112 additions & 6 deletions src/transformers/integrations/tensor_parallel.py
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
return [single_size] * blocks


def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
"""
Get the TP style for a parameter from the TP plan.

The TP plan is a dictionary that maps parameter names to TP styles.
The parameter name can be a generic name with layer indices replaced by wildcards (e.g. "layer_*.weight") or a specific name (e.g. "layer_1.weight").
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
else:
return None
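
For illustration, a minimal sketch of how the lookup resolves a concrete parameter name (the plan keys and parameter names below are hypothetical, not taken from this PR):

# Hypothetical plan entry; real plans come from the model's `_tp_plan`.
tp_plan = {"model.layers.*.mlp.gate_up_proj": "local_packed_rowwise"}

# Digits are replaced by "*" first, then the full generic name and the parent module name are tried.
_get_parameter_tp_plan("model.layers.3.mlp.gate_up_proj.weight", tp_plan)
# -> "local_packed_rowwise" (matched via the parent-module key once ".weight" is split off)
_get_parameter_tp_plan("model.layers.3.input_layernorm.weight", tp_plan)
# -> None (neither the generic name nor the parent-module name is in the plan)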


str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
return tensor.to(str_to_torch_dtype[slice_dtype])


def repack_weights(
packed_parameter: torch.Tensor,
sharded_dim: int, # The dimension index in the global tensor that was sharded
world_size: int,
num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
Review comment (Collaborator): nice this is mega useful!

For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
This is an inverse operation to get_packed_weights.

Args:
packed_parameter: The tensor reconstructed from a DTensor (e.g., via .full_tensor().contiguous()).
sharded_dim: The dimension index in packed_parameter that was originally sharded.
world_size: The tensor parallel world size.
num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).

Returns:
The reordered tensor in canonical packed format.
"""

if num_blocks != 2:
raise ValueError(
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
)

actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
shard_chunk_size = original_block_size_on_dim // world_size

prefix_shape = packed_parameter.shape[:actual_sharded_dim]
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]

tensor_view = packed_parameter.view(
*prefix_shape,
world_size,
num_blocks,
shard_chunk_size,
*suffix_shape,
)

# Permute to bring num_blocks first, then world_size, then shard_chunk_size.
# This groups all chunks of G together, then all chunks of U together.
# Target order of these middle dimensions: (num_blocks, world_size, shard_chunk_size)
# Current order of the view's middle dimensions: (world_size, num_blocks, shard_chunk_size)
# Absolute indices of the two dimensions to swap (world_size and num_blocks, i.e. the packed projections)
axis_ws_abs = len(prefix_shape)
axis_npp_abs = len(prefix_shape) + 1

permute_order = list(range(tensor_view.ndim))
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]

tensor_permuted = tensor_view.permute(*permute_order)

# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
# The final shape is the same as packed_parameter's.
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)

return final_ordered_tensor
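
As a hedged, minimal illustration of the reordering above (values are made up; with world_size=2 and num_blocks=2, each rank contributed one gate chunk and one up chunk):

import torch

# All-gathered layout along the sharded dim: [G0 | U0 | G1 | U1]
interleaved = torch.tensor([0.0, 1.0, 10.0, 11.0, 2.0, 3.0, 12.0, 13.0])
canonical = repack_weights(interleaved, sharded_dim=0, world_size=2, num_blocks=2)
# canonical is tensor([0., 1., 2., 3., 10., 11., 12., 13.]), i.e. [G_all | U_all]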


def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if dim == 0:
size_ = empty_param.shape[0]
@@ -578,6 +659,35 @@ def translate_to_torch_parallel_style(style: str):
raise ValueError(f"Unsupported parallel style value: {style}")


def convert_local_tensor_to_dtensor(
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
) -> DTensor:
"""
Converts a locally sharded tensor into a DTensor with the corresponding placements. This should only ever be done right before saving the model.
"""
param_type = parameter_name.rsplit(".", 1)[1] if "." in parameter_name else parameter_name
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
if not tp_style:
return parameter

if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
return parameter
# TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
if tp_style == "local_packed_rowwise":
placements = [Shard(-1)]
elif tp_style == "local_rowwise":
if param_type == "bias":
placements = [Replicate()]
else:
placements = [Shard(-1)]
elif tp_style == "local_colwise":
if param_type == "bias":
placements = [Shard(-1)]
else:
placements = [Shard(-2)]
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
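
A hedged usage sketch (assumes an initialized process group and a 1-D device mesh, e.g. under torchrun with 2 processes; the plan key, parameter name, and shapes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh("cuda", (2,))  # tensor-parallel world size of 2
tp_plan = {"model.layers.*.self_attn.q_proj": "local_colwise"}

# This rank's half of a (2048, 2048) column-parallel weight, sharded on dim 0, i.e. Shard(-2).
local_shard = torch.randn(1024, 2048, device="cuda")
dtensor = convert_local_tensor_to_dtensor(
    local_shard, "model.layers.0.self_attn.q_proj.weight", device_mesh, tp_plan
)
# dtensor.full_tensor() now reassembles the global (2048, 2048) weight, which is what saving relies on.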


def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
"""
Add hooks to the module holding the layer. Meaning:
Expand Down Expand Up @@ -632,13 +742,9 @@ def shard_and_distribute_module(
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]

current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
32 changes: 30 additions & 2 deletions src/transformers/modeling_utils.py
@@ -63,6 +63,9 @@
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
convert_local_tensor_to_dtensor,
repack_weights,
shard_and_distribute_module,
)
from .loss.loss_utils import LOSS_MAPPING
@@ -166,6 +169,8 @@
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor

def is_fsdp_enabled():
return (
@@ -3483,6 +3488,9 @@ def save_pretrained(
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict)
# If the model was sharded with TP, we cannot properly determine the sizes of tensors that a `local_*` strategy was used for,
# so we replace them with equivalently sharded DTensors
state_dict = self._replace_state_dict_local_with_dtensor(state_dict)

if safe_serialization:
# Safetensors does not allow tensor aliasing.
@@ -3491,7 +3499,7 @@ def save_pretrained(
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
@@ -3601,7 +3609,14 @@ def save_pretrained(
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, 4, 2)
Review comment (Contributor Author): note to self: replace the hardcoded 4 with world_size
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]
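
Putting the branch above together, a hedged sketch of what happens to a single packed DTensor at save time (the key name is illustrative, and the hard-coded 4 stands in for a TP world size of 4, as the author's note above points out):

key = "model.layers.0.mlp.gate_up_proj.weight"      # packed, sharded with local_packed_rowwise
value = state_dict[key]
if isinstance(value, DTensor):
    gathered = value.full_tensor()                  # interleaved [G0, U0, G1, U1, ...] on the last dim
    gathered = repack_weights(gathered, -1, 4, 2)   # reorder to [G0..G3, U0..U3]; 4 == TP world size here
    shard[key] = gathered.contiguous()              # contiguous only after the permutation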

Expand Down Expand Up @@ -4530,6 +4545,7 @@ def _assign_original_dtype(module):

# record tp degree the model sharded to
model._tp_size = tp_size
model._device_mesh = device_mesh

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4717,6 +4733,18 @@ def _fix_state_dict_keys_on_save(self, state_dict):
"""
return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

def _replace_state_dict_local_with_dtensor(self, state_dict):
"""
Replaces all tensors that were sharded with a `local_*` strategy with DTensors to make saving possible.
"""
if not self._tp_plan:
return state_dict
# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
Review comment (Collaborator): yeah we are gonna iterate over all the weights when saving anyways not sure we need it
Review comment (Contributor Author): Actually we need to do this before split_torch_state_dict_into_shards is called, as that needs to have local tensors as dtensors to properly work. We iterate to save later.
return state_dict

@classmethod
def _load_pretrained_model(
cls,
7 changes: 7 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if is_torch_greater_or_equal_than_2_0:
from torch.distributed.tensor import DTensor

if isinstance(tensor, DTensor):
local_tensor = tensor.to_local()
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes

if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
83 changes: 79 additions & 4 deletions tests/tensor_parallel/test_tensor_parallel.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import tempfile
import textwrap

from transformers import is_torch_available
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
from transformers.testing_utils import (
TestCasePlus,
get_torch_dist_unique_port,
@@ -28,19 +30,51 @@
import torch


class TestTensorParallelUtils(TestCasePlus):
def test_packed_unpacked_conversion(self):
WORLD_SIZE = 2
PACKED_BLOCK_SIZE = 800
SHARDING_DIM = 2
NUM_BLOCKS = 2

original_packed_weights = torch.randn(4, 512, 2 * PACKED_BLOCK_SIZE)
original_packed_weights.get_dtype = lambda: "F32"  # get_packed_weights expects a safetensors slice-like object exposing get_dtype()
empty_param = torch.empty(4, 512, 2 * PACKED_BLOCK_SIZE)

class MockDeviceMesh:
def size(self):
return WORLD_SIZE

mock_mesh = (
MockDeviceMesh()
) # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run

packed_weights_0 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 0, SHARDING_DIM)
packed_weights_1 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 1, SHARDING_DIM)

# simulate all gather of sharded weights
packed_weights = torch.cat([packed_weights_0, packed_weights_1], dim=SHARDING_DIM)
unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)

assert torch.allclose(unpacked_weights, original_packed_weights)


# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2

def torchrun(self, script: str):
def torchrun(self, script: str, is_torchrun: bool = True):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
cmd = (
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
if is_torchrun:
cmd = (
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
else:
cmd = ["python", tmp.name]

# Note that the subprocess will be waited for here, and raise an error if not successful
try:
@@ -88,6 +122,47 @@ def test_model_forward(self):
)
self.torchrun(script_to_run)

def test_model_save(self):
from safetensors import safe_open

with tempfile.TemporaryDirectory() as tmp_dir:
for is_torchrun in [True, False]:
script_to_run = textwrap.dedent(
f"""
import torch
import os
from transformers import AutoModelForCausalLM

model_id = "JackFram/llama-68m"
kwargs = dict()

if os.environ.get("RANK", None) is not None:
kwargs["tp_plan"] = "auto"
result_dir = "{tmp_dir}/tp"
else:
result_dir = "{tmp_dir}/nontp"

model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
model.save_pretrained(result_dir)
"""
)
self.torchrun(script_to_run, is_torchrun=is_torchrun)

non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")

for filename in os.listdir(non_tp_model_path):
if not filename.endswith(".safetensors"):
continue

non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
for non_tp_key in non_tp_model.keys():
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
tp_tensor = tp_model.get_tensor(non_tp_key)
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
del non_tp_tensor, tp_tensor


@require_torch_multi_gpu
class TestTensorParallelCuda(TestTensorParallel):