From 65ecabd8a127982241c79d86c711d9c68654ca5b Mon Sep 17 00:00:00 2001 From: Sung Ching Liu Date: Wed, 26 Feb 2025 11:44:04 -0500 Subject: [PATCH 1/4] fix model-saving tp --- src/transformers/modeling_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce4cdab8b865..729b98873e98 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3662,6 +3662,16 @@ def save_pretrained( if self._tp_size is not None: state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh) + # if using tensor parallel we need to gather the tensors in state dict + gathered_state_dict = {} + for key, value in state_dict.items(): + if hasattr(value, "_local_tensor"): + gathered_state_dict[key] = value.full_tensor() + gathered_state_dict[key] = value + + del state_dict + state_dict = gathered_state_dict + if safe_serialization: # TODO: fix safe_serialization for tied weights # Safetensors does not allow tensor aliasing. 
From ae5cb307f7a5736e3b4252e87ae72108bb0adcfb Mon Sep 17 00:00:00 2001 From: Sung Ching Liu Date: Wed, 26 Feb 2025 11:46:28 -0500 Subject: [PATCH 2/4] use to_local() instead of full_tensor() for DTensor values --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 729b98873e98..5cbd3ca0efc8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3666,7 +3666,7 @@ def save_pretrained( gathered_state_dict = {} for key, value in state_dict.items(): if hasattr(value, "_local_tensor"): - gathered_state_dict[key] = value.full_tensor() + gathered_state_dict[key] = value.to_local() gathered_state_dict[key] = value del state_dict From e846b1c595a3c57b7d6d38ed3815339af5b8b7d0 Mon Sep 17 00:00:00 2001 From: Sung Ching Liu Date: Wed, 26 Feb 2025 14:05:40 -0500 Subject: [PATCH 3/4] fix bugs in gathering state dict --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5cbd3ca0efc8..16bafaf7dde3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3667,7 +3667,8 @@ def save_pretrained( for key, value in state_dict.items(): if hasattr(value, "_local_tensor"): gathered_state_dict[key] = value.to_local() - gathered_state_dict[key] = value + else: + gathered_state_dict[key] = value del state_dict state_dict = gathered_state_dict From ee271a0a86cc7160331a8ca2c65a222d28262130 Mon Sep 17 00:00:00 2001 From: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:53:25 -0500 Subject: [PATCH 4/4] add .cpu() when gathering tensor --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 16bafaf7dde3..60fac1f3b180 100644 ---
a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3666,9 +3666,9 @@ def save_pretrained( gathered_state_dict = {} for key, value in state_dict.items(): if hasattr(value, "_local_tensor"): - gathered_state_dict[key] = value.to_local() + gathered_state_dict[key] = value.to_local().cpu() else: - gathered_state_dict[key] = value + gathered_state_dict[key] = value.cpu() del state_dict state_dict = gathered_state_dict