diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f994d9b08769..724e64a1a4a6 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3536,15 +3536,30 @@ def save_pretrained(
             # Safetensors does not allow tensor aliasing.
             # We're going to remove aliases before saving
             ptrs = collections.defaultdict(list)
+            rank = int(os.getenv("RANK", "-1"))
             for name, tensor in state_dict.items():
                 # Sometimes in the state_dict we have non-tensor objects.
                 # e.g. in bitsandbytes we have some `str` objects in the state_dict
                 if isinstance(tensor, torch.Tensor):
-                    ptrs[id_tensor_storage(tensor)].append(name)
+                    if isinstance(tensor, DTensor):
+                        # When working under tensor parallelism, the DTensor must be gathered back into a full tensor.
+                        # Move the full tensor to CPU, since rank 0's GPU memory might not be large enough for a large model.
+                        # Smaller models might be fine keeping a full copy on GPU; this will be optimized in a next step.
+                        tensor = tensor.full_tensor().to("cpu")
+                    if rank <= 0:
+                        state_dict[name] = tensor
+                        ptrs[id_tensor_storage(tensor)].append(name)
+                    else:
+                        # On ranks > 0 the gathered tensor is not needed; delete it to save memory
+                        del tensor
                 else:
                     # In the non-tensor case, fall back to the pointer of the object itself
                     ptrs[id(tensor)].append(name)

+            if rank > 0:
+                del state_dict
+                return
+
             # These are all the pointers of shared tensors
             if hasattr(self, "hf_device_map"):
                 # if the model has offloaded parameters, we must check using find_tied_parameters()
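
For context, here is a minimal, standalone sketch of the gather-on-rank-0 pattern the diff relies on. It is not part of the PR: it assumes the public `torch.distributed.tensor` API (PyTorch ≥ 2.4; older versions expose it as `torch.distributed._tensor`), a CPU/Gloo setup for simplicity, and a hypothetical `weight.pt` output path.

```python
# Sketch only: shard a tensor across ranks as a DTensor, then gather it back
# with full_tensor() so rank 0 can serialize a plain torch.Tensor.
# Launch with: torchrun --nproc_per_node=2 dtensor_save_sketch.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor  # PyTorch >= 2.4

dist.init_process_group("gloo")  # use "nccl" for GPU tensor parallelism
rank = int(os.getenv("RANK", "-1"))

# Build a 1-D device mesh and shard a toy weight along dim 0.
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
weight = torch.arange(8, dtype=torch.float32).reshape(4, 2)
dtensor = distribute_tensor(weight, mesh, placements=[Shard(0)])

# full_tensor() is a collective, so every rank must call it,
# mirroring the diff where the gather happens before the rank check.
full = dtensor.full_tensor().to("cpu")

if rank <= 0:
    torch.save({"weight": full}, "weight.pt")  # only rank 0 writes to disk
else:
    del full  # other ranks drop the gathered copy to save memory

dist.barrier()
dist.destroy_process_group()
```

The ordering mirrors the diff: the gather runs unconditionally because `full_tensor()` is a collective that all ranks must join; only the rank check afterwards decides whether the result is kept and written or immediately discarded.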