Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919
Changes from 8 commits
```diff
@@ -63,6 +63,9 @@
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
     SUPPORTED_TP_STYLES,
+    _get_parameter_tp_plan,
+    convert_local_tensor_to_dtensor,
+    repack_weights,
     shard_and_distribute_module,
 )
 from .loss.loss_utils import LOSS_MAPPING
```
```diff
@@ -166,6 +169,8 @@
 _is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()

+if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+    from torch.distributed.tensor import DTensor

 def is_fsdp_enabled():
     return (
```
```diff
@@ -3483,6 +3488,9 @@ def save_pretrained(
         # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
         # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
         state_dict = self._fix_state_dict_keys_on_save(state_dict)
+        # If the model was sharded, we cannot properly determine the sizes of tensors for which a `local_*`
+        # strategy was used, so we replace them with equivalently sharded DTensors.
+        state_dict = self._replace_state_dict_local_with_dtensor(state_dict)

         if safe_serialization:
             # Safetensors does not allow tensor aliasing.
```
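Why replacing `local_*`-sharded tensors with DTensors helps: a plain local shard only knows its own (local) shape, while a DTensor also carries the device mesh and sharding placement, so the global tensor can be reconstructed at save time. The sketch below illustrates that idea with made-up shapes and mesh layout; it is not the PR's `convert_local_tensor_to_dtensor` helper.

```python
# Minimal sketch (assumed shapes, mesh layout, and row-sharded placement); run under
# `torchrun --nproc-per-node 4 ...` so that a process group exists.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))  # 4-way tensor parallelism
local_shard = torch.randn(1024, 4096, device="cuda")           # this rank's slice of a [4096, 4096] weight

# Wrapping the local shard records the mesh and the sharding placement, which is exactly
# the metadata save_pretrained needs to recover the global tensor via `.full_tensor()`.
dtensor = DTensor.from_local(local_shard, mesh, placements=[Shard(0)])
print(dtensor.shape)  # torch.Size([4096, 4096]) -- the global shape, not the local one
```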
```diff
@@ -3491,7 +3499,7 @@ def save_pretrained(
             for name, tensor in state_dict.items():
                 # Sometimes in the state_dict we have non-tensor objects.
                 # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                if isinstance(tensor, torch.Tensor):
+                if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
                     ptrs[id_tensor_storage(tensor)].append(name)
                 else:
                     # In the non-tensor case, fall back to the pointer of the object itself
```
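For reference, the aliasing check above keys tensors by their underlying storage via `id_tensor_storage` (from `transformers.pytorch_utils`). The toy example below shows the behaviour it relies on; the tensors are assumptions for illustration, not part of the PR.

```python
# Small illustration of storage-based aliasing detection.
import torch
from transformers.pytorch_utils import id_tensor_storage

weight = torch.randn(16, 16)
tied = weight       # e.g. input embeddings and lm_head sharing one storage
view = weight[:8]   # a view also shares the same underlying storage

# Tensors sharing a storage map to the same key, so only one of them is serialized
# (safetensors forbids aliasing).
print(id_tensor_storage(weight) == id_tensor_storage(tied))  # True
print(id_tensor_storage(weight) == id_tensor_storage(view))  # True
```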
```diff
@@ -3601,7 +3609,14 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                shard[tensor] = state_dict[tensor].contiguous()
+                if isinstance(state_dict[tensor], DTensor):
+                    full_tensor = state_dict[tensor].full_tensor()
+                    # to get the correctly ordered tensor we need to repack if packed
+                    if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
+                        full_tensor = repack_weights(full_tensor, -1, 4, 2)
+
+                    shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
+                else:
+                    shard[tensor] = state_dict[tensor].contiguous()
                 # delete reference, see https://github.com/huggingface/transformers/pull/34890
                 del state_dict[tensor]
```
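About the `local_packed_rowwise` branch: after `full_tensor()` concatenates the per-rank shards, packed projections (for example fused gate/up weights) come back interleaved per rank, and `repack_weights(full_tensor, -1, 4, 2)` restores the original packed order. The helper below is a hypothetical plain-torch re-implementation of that reordering, written only to illustrate the idea; the argument meanings (packed dim, tp degree, number of packed projections) are inferred from the call site, not taken from the real `repack_weights` signature.

```python
import torch

def repack_packed_rowwise(gathered: torch.Tensor, dim: int, tp_size: int, num_packed: int) -> torch.Tensor:
    """Hypothetical illustration, not transformers' `repack_weights`.

    A packed parameter [p0 | p1 | ...] sharded per rank comes back from the gather as
    [p0_rank0, p1_rank0, p0_rank1, p1_rank1, ...]; this reorders the blocks to the
    original packed layout [p0_rank0, p0_rank1, ..., p1_rank0, p1_rank1, ...].
    """
    dim = dim % gathered.ndim
    block = gathered.shape[dim] // (tp_size * num_packed)  # one projection slice from one rank
    blocks = torch.split(gathered, block, dim=dim)         # rank-major order: [p0_r0, p1_r0, p0_r1, p1_r1, ...]
    reordered = [blocks[r * num_packed + p] for p in range(num_packed) for r in range(tp_size)]
    return torch.cat(reordered, dim=dim)

# Toy check with tp_size=2, num_packed=2 and block size 1 along the last dim:
# gathered [g0, u0, g1, u1] should become [g0, g1, u0, u1].
toy = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
print(repack_packed_rowwise(toy, -1, tp_size=2, num_packed=2))  # tensor([[0., 2., 1., 3.]])
```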
```diff
@@ -4530,6 +4545,7 @@ def _assign_original_dtype(module):

        # record the tp degree the model was sharded to
        model._tp_size = tp_size
+       model._device_mesh = device_mesh

        # make sure token embedding weights are still tied if needed
        model.tie_weights()
```
```diff
@@ -4717,6 +4733,18 @@ def _fix_state_dict_keys_on_save(self, state_dict):
        """
        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

+    def _replace_state_dict_local_with_dtensor(self, state_dict):
+        """
+        Replaces all tensors that were sharded with a `local_*` strategy with DTensors to make saving possible.
+        """
+        if not self._tp_plan:
+            return state_dict
+        # TODO: optimize this to avoid iterating over all entries
+        for key, value in state_dict.items():
+            if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
+                state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
+
+        return state_dict
+
     @classmethod
     def _load_pretrained_model(
         cls,
```
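To make the end-to-end behaviour concrete, here is a hedged usage sketch of what this PR enables: load a model with the existing `tp_plan="auto"` API under `torchrun`, then call `save_pretrained`, which now gathers `local_*` / DTensor shards back into full tensors before writing. The model id and output path are placeholders, not taken from the PR.

```python
# Usage sketch (placeholder model id and path); launch with `torchrun --nproc-per-node 4 save_tp.py`.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",      # placeholder checkpoint
    tp_plan="auto",             # shard the model with tensor parallelism across the launched ranks
    torch_dtype=torch.bfloat16,
)

# With this change, sharded parameters are gathered back into full tensors at save time,
# so the files written to disk match an unsharded checkpoint.
model.save_pretrained("./saved-tp-model")
```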
Review comment: nice, this is mega useful!