Skip to content
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3491,7 +3491,7 @@ def save_pretrained(
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not all versions of torch have DTensor we need to protect this a tad bit

ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
Expand Down Expand Up @@ -3601,7 +3601,10 @@ def save_pretrained(
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
shard[tensor] = state_dict[tensor].full_tensor().contiguous()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine!
Wondering if we cannot also delete the tensor from the model once saved? TLDR we want the most efficient way to make sure we clean the model while saving.
Also we should check if the tensor is replicated or not, if so we don't need to get the full_tensor!

Moreover, for local plans, we need to manually gather, as the tensors are not DTensors

Copy link
Contributor Author

@S1ro1 S1ro1 May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re 1: we probably can do that, let me see if we get any meaningful savings from it.

Re 2: full_tensor on placements=Replicate() is a no-op (returns itself), except of some if/else checks in torch source, so I'm pretty sure there's no need to do the checks ourselves for the sake of readability. Relevant src here

Re 3: Sounds good, let me take a look at that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re 2 this is recent it seems! 3 weeks ago! wanna make sure all versions have it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That commit seems like a minor change that triggered it, commit moving DTensor to public API (which is probably the oldest one we support anyway) already has it: here

else:
shard[tensor] = state_dict[tensor].contiguous()
# delete reference, see https://github.com/huggingface/transformers/pull/34890
del state_dict[tensor]

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if is_torch_greater_or_equal_than_2_0:
from torch.distributed.tensor import DTensor

if isinstance(tensor, DTensor):
local_tensor = tensor.to_local()
return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes

if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
Expand Down
53 changes: 49 additions & 4 deletions tests/tensor_parallel/test_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import tempfile
import textwrap
Expand All @@ -32,15 +33,18 @@
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2

def torchrun(self, script: str):
def torchrun(self, script: str, is_torchrun: bool = True):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
cmd = (
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
if is_torchrun:
cmd = (
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
else:
cmd = ["python", tmp.name]

# Note that the subprocess will be waited for here, and raise an error if not successful
try:
Expand Down Expand Up @@ -88,6 +92,47 @@ def test_model_forward(self):
)
self.torchrun(script_to_run)

def test_model_save(self):
from safetensors import safe_open

with tempfile.TemporaryDirectory() as tmp_dir:
for is_torchrun in [True, False]:
script_to_run = textwrap.dedent(
f"""
import torch
import os
from transformers import AutoModelForCausalLM

model_id = "JackFram/llama-68m"
kwargs = dict()

if os.environ.get("RANK", None) is not None:
kwargs["tp_plan"] = "auto"
result_dir = "{tmp_dir}/tp"
else:
result_dir = "{tmp_dir}/nontp"

model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
model.save_pretrained(result_dir)
"""
)
self.torchrun(script_to_run, is_torchrun=is_torchrun)

non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")

for filename in os.listdir(non_tp_model_path):
if not filename.endswith(".safetensors"):
continue

non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
for non_tp_key in non_tp_model.keys():
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
tp_tensor = tp_model.get_tensor(non_tp_key)
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
del non_tp_tensor, tp_tensor


@require_torch_multi_gpu
class TestTensorParallelCuda(TestTensorParallel):
Expand Down