From 8dd34122c43e4ebd192c2cbac5315838aff5f520 Mon Sep 17 00:00:00 2001
From: Xiao YU
Date: Tue, 13 May 2025 12:39:32 -0400
Subject: [PATCH] Support TP for save_pretrained()

---
 src/transformers/modeling_utils.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f994d9b08769..724e64a1a4a6 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3536,15 +3536,30 @@ def save_pretrained(
         # Safetensors does not allow tensor aliasing.
         # We're going to remove aliases before saving
         ptrs = collections.defaultdict(list)
+        rank = int(os.getenv("RANK", "-1"))
         for name, tensor in state_dict.items():
             # Sometimes in the state_dict we have non-tensor objects.
             # e.g. in bitsandbytes we have some `str` objects in the state_dict
             if isinstance(tensor, torch.Tensor):
-                ptrs[id_tensor_storage(tensor)].append(name)
+                if isinstance(tensor, DTensor):
+                    # When working under tensor parallelism, the DTensor has to be gathered back into a full tensor.
+                    # Move the full tensor to CPU, since rank 0's GPU memory might not be large enough for a large model.
+                    # Smaller models might be fine with a full copy on GPU; this will be optimized in a follow-up.
+                    tensor = tensor.full_tensor().to("cpu")
+                if rank <= 0:
+                    state_dict[name] = tensor
+                    ptrs[id_tensor_storage(tensor)].append(name)
+                else:
+                    # On ranks > 0 the full tensor is not needed, so delete it to save memory
+                    del tensor
             else:
                 # In the non-tensor case, fall back to the pointer of the object itself
                 ptrs[id(tensor)].append(name)
 
+        if rank > 0:
+            del state_dict
+            return
+
         # These are all the pointers of shared tensors
         if hasattr(self, "hf_device_map"):
             # if the model has offloaded parameters, we must check using find_tied_parameters()
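
A minimal usage sketch of the behavior this patch targets, assuming a model loaded through transformers' tensor-parallel path (tp_plan="auto") and launched with torchrun so that each process has the RANK environment variable set; the model id and output directory below are placeholders, not part of the patch:

    # save_tp_model.py -- illustrative only; launch with: torchrun --nproc-per-node=2 save_tp_model.py
    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint; weights are sharded as DTensors across the torchrun ranks.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        torch_dtype=torch.bfloat16,
        tp_plan="auto",
    )

    # With the patch applied, rank 0 gathers each DTensor into a full CPU tensor and writes
    # the checkpoint, while ranks > 0 drop their state_dict and return early.
    model.save_pretrained("./llama-3.2-1b-tp-saved")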