diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ce4cdab8b865..60fac1f3b180 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3662,6 +3662,17 @@ def save_pretrained(
         if self._tp_size is not None:
             state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
+        # if using tensor parallel we need to gather the tensors in state dict
+        gathered_state_dict = {}
+        for key, value in state_dict.items():
+            if hasattr(value, "_local_tensor"):
+                gathered_state_dict[key] = value.to_local().cpu()
+            else:
+                gathered_state_dict[key] = value.cpu()
+
+        del state_dict
+        state_dict = gathered_state_dict
+
         if safe_serialization:
             # TODO: fix safe_serialization for tied weights
             # Safetensors does not allow tensor aliasing.
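
For context, the pattern this hunk adds can be exercised on its own. Below is a minimal sketch of the same loop factored into a helper; the helper name `materialize_state_dict` is hypothetical (not part of the patch), and the only assumption is the one the patch itself makes: DTensor-like values expose `_local_tensor` and `to_local()`. With plain tensors, only the `else` branch runs, so the snippet works without any distributed setup.

```python
import torch


def materialize_state_dict(state_dict):
    """Mirror the patched loop: unwrap DTensor-like values and move
    every tensor to CPU before serialization."""
    gathered = {}
    for key, value in state_dict.items():
        # DTensor instances carry a `_local_tensor` attribute;
        # `to_local()` returns this rank's local shard of the tensor.
        if hasattr(value, "_local_tensor"):
            gathered[key] = value.to_local().cpu()
        else:
            # Plain tensors are simply moved to CPU.
            gathered[key] = value.cpu()
    return gathered


# Plain tensors exercise the else branch; under tensor parallelism,
# sharded parameters wrapped as DTensors would take the first branch.
sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
print({k: v.device for k, v in materialize_state_dict(sd).items()})
```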