Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919
@@ -3491,7 +3491,7 @@ def save_pretrained(
         for name, tensor in state_dict.items():
             # Sometimes in the state_dict we have non-tensor objects.
             # e.g. in bitsandbytes we have some `str` objects in the state_dict
-            if isinstance(tensor, torch.Tensor):
+            if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
                 ptrs[id_tensor_storage(tensor)].append(name)
             else:
                 # In the non-tensor case, fall back to the pointer of the object itself
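For context, a minimal sketch of the shared-storage bookkeeping this hunk extends to `DTensor`: tied weights (e.g. input embeddings and an LM head) are views over the same storage, so grouping parameter names by storage identity lets `save_pretrained` write the data only once. Here `untyped_storage().data_ptr()` stands in for the library's `id_tensor_storage` helper; this is an illustration, not the PR's code.

```python
from collections import defaultdict

import torch

embed = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # weight tying: both parameters share one storage

state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
    # Group names by underlying storage pointer, as the hunk above does
    # via id_tensor_storage(tensor).
    ptrs[tensor.untyped_storage().data_ptr()].append(name)

shared = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared)  # one storage pointer mapped to ["embed.weight", "lm_head.weight"]
```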
@@ -3601,7 +3601,10 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                shard[tensor] = state_dict[tensor].contiguous()
+                if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
+                    shard[tensor] = state_dict[tensor].full_tensor().contiguous()
+                else:
+                    shard[tensor] = state_dict[tensor].contiguous()
                 # delete reference, see https://github.com/huggingface/transformers/pull/34890
                 del state_dict[tensor]
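A minimal sketch of why the new branch calls `.full_tensor()`, assuming a recent torch that exposes the public `torch.distributed.tensor` API and two processes launched with `torchrun --nproc_per_node=2`: each rank's `DTensor` holds only its local shard, and `full_tensor()` gathers the shards back into an ordinary `torch.Tensor` that can be made contiguous and written to the checkpoint shard.

```python
# Sketch only: run with `torchrun --nproc_per_node=2 demo.py`.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cpu", (2,))                   # 1-D device mesh over 2 ranks
weight = torch.arange(8, dtype=torch.float32).reshape(4, 2)
dweight = distribute_tensor(weight, mesh, [Shard(0)])  # rows split across the ranks

print(dweight.to_local().shape)   # torch.Size([2, 2]): only this rank's shard
full = dweight.full_tensor()      # all-gather into a plain, replicated torch.Tensor
assert torch.equal(full, weight)  # safe to .contiguous() and serialize
```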
Not all versions of torch have `DTensor`; we need to protect this a tad bit.
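One possible shape for that guard, as a hedged sketch rather than the PR's code (the `is_dtensor` helper name is hypothetical; the only assumption is that `torch.distributed.tensor.DTensor` fails to import on torch builds that lack it):

```python
try:
    from torch.distributed.tensor import DTensor  # only present in recent torch releases
except ImportError:
    DTensor = None


def is_dtensor(tensor) -> bool:
    # On torch builds without DTensor, nothing can be a DTensor.
    return DTensor is not None and isinstance(tensor, DTensor)
```

The two `isinstance(..., torch.distributed.tensor.DTensor)` checks in the diff could then go through such a helper instead of touching the attribute directly.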