Skip to content

Commit 861d2cc

Browse files
ankitageorgepytorchmergebot
authored andcommitted
Add a param for save format in Storage Writer (pytorch#150025)
Summary: add a param to specify to the storage writer how to save tensors. Write now the only options are safetensors and torch.save. Test Plan: (lintrunner) [[email protected] /data/users/ankitageorge/fbsource/fbcode/caffe2 (1d57cb27b)]$ buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/distributed/checkpoint:test_hf_storage File changed: fbcode//caffe2/torch/distributed/checkpoint/filesystem.py Buck UI: https://www.internalfb.com/buck2/e80cc963-e34a-4876-b6f4-7ce2794e48dd Test UI: https://www.internalfb.com/intern/testinfra/testrun/3659174965882569 Network: Up: 32KiB Down: 1.9KiB (reSessionID-ef9fa764-a40a-451b-ab58-08eabe7a9422) Executing actions. Remaining 0/4 3.4s exec time total Command: test. Finished 2 local Time elapsed: 19.6s Tests finished: Pass 4. Fail 0. Fatal 0. Skip 0. Build failure 0 Reviewed By: saumishr Differential Revision: D70271943 Pull Request resolved: pytorch#150025 Approved by: https://github.com/saumishr
1 parent c53bc61 commit 861d2cc

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

Diff for: torch/distributed/checkpoint/_fsspec_filesystem.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FileSystemBase,
1616
FileSystemReader,
1717
FileSystemWriter,
18+
SerializationFormat,
1819
)
1920

2021

@@ -115,6 +116,7 @@ def __init__(
115116
per_thread_copy_ahead: int = 10_000_000,
116117
overwrite: bool = True,
117118
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
119+
serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
118120
**kwargs,
119121
) -> None:
120122
"""
@@ -139,6 +141,7 @@ def __init__(
139141
per_thread_copy_ahead,
140142
overwrite=overwrite,
141143
_extensions=_extensions,
144+
serialization_format=serialization_format,
142145
)
143146
self.fs = FileSystem()
144147
self.path = self.fs.init_path(path, **kwargs)

Diff for: torch/distributed/checkpoint/_hf_storage.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_FqnToFileMapping,
1212
_HuggingFaceLoadPlanner,
1313
)
14+
from torch.distributed.checkpoint.filesystem import SerializationFormat
1415
from torch.distributed.checkpoint.metadata import (
1516
BytesStorageMetadata,
1617
Metadata,
@@ -64,7 +65,11 @@ def __init__(
6465
if HfFileSystem.protocol not in fsspec.available_protocols():
6566
fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem)
6667

67-
super().__init__(path=path, token=token)
68+
super().__init__(
69+
path=path,
70+
token=token,
71+
serialization_format=SerializationFormat.SAFETENSORS,
72+
)
6873
self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping
6974

7075
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
@@ -99,7 +104,7 @@ def write_data(
99104
(self.fs.concat_path(self.path, file_name), file_name, write_items)
100105
)
101106

102-
return super()._write_data(planner, file_queue, safe_tensors=True)
107+
return super()._write_data(planner, file_queue)
103108

104109
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
105110
metadata_to_write = {}

Diff for: torch/distributed/checkpoint/filesystem.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Generator, Iterable, Iterator, Sequence
1414
from contextlib import contextmanager
1515
from dataclasses import dataclass
16+
from enum import Enum
1617
from io import UnsupportedOperation
1718
from pathlib import Path
1819
from typing import Any, Callable, cast, IO, Optional, Union
@@ -49,7 +50,13 @@
4950
from torch.futures import Future
5051

5152

52-
__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"]
53+
__all__ = [
54+
"FileSystemWriter",
55+
"FileSystemReader",
56+
"FileSystem",
57+
"FileSystemBase",
58+
"SerializationFormat",
59+
]
5360

5461
_metadata_fn: str = ".metadata"
5562

@@ -72,6 +79,11 @@ class _StoragePrefix:
7279
prefix: str
7380

7481

82+
class SerializationFormat(Enum):
83+
TORCH_SAVE = "torch_save"
84+
SAFETENSORS = "safetensors"
85+
86+
7587
DEFAULT_SUFFIX = ".distcp"
7688

7789

@@ -298,7 +310,7 @@ def _write_item(
298310
data: Union[io.BytesIO, torch.Tensor],
299311
write_item: WriteItem,
300312
storage_key: str,
301-
safe_tensors: bool = False,
313+
serialization_format: SerializationFormat,
302314
) -> WriteResult:
303315
offset = stream.tell()
304316

@@ -312,12 +324,14 @@ def _write_item(
312324
else:
313325
assert isinstance(data, torch.Tensor)
314326
assert data.device == torch.device("cpu")
315-
if not safe_tensors:
327+
if serialization_format == SerializationFormat.TORCH_SAVE:
316328
torch.save(data, transform_to)
317329

318330
transform_to.close()
319331

320-
if not safe_tensors or isinstance(data, io.BytesIO):
332+
if serialization_format == SerializationFormat.TORCH_SAVE or isinstance(
333+
data, io.BytesIO
334+
):
321335
length = stream.tell() - offset
322336
else:
323337
length = data.numel() * data.element_size()
@@ -349,7 +363,7 @@ def _write_files_from_queue(
349363
inflight_threshhold: int,
350364
use_fsync: bool,
351365
thread_count: int,
352-
safe_tensors: bool,
366+
serialization_format: SerializationFormat,
353367
) -> None:
354368
try:
355369
while True:
@@ -397,7 +411,7 @@ def _write_files_from_queue(
397411
data,
398412
write_item,
399413
storage_key,
400-
safe_tensors,
414+
serialization_format,
401415
)
402416
)
403417

@@ -411,12 +425,12 @@ def _write_files_from_queue(
411425
tensor,
412426
write_item,
413427
storage_key,
414-
safe_tensors,
428+
serialization_format,
415429
)
416430
)
417431
tensor_dict[write_item.index.fqn] = tensor
418432

419-
if safe_tensors:
433+
if serialization_format == SerializationFormat.SAFETENSORS:
420434
from safetensors.torch import save # type: ignore[import-not-found]
421435

422436
stream.write(save(tensor_dict))
@@ -549,6 +563,7 @@ def __init__(
549563
per_thread_copy_ahead: int = 10_000_000,
550564
overwrite: bool = True,
551565
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
566+
serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
552567
*args: Any,
553568
**kwargs: Any,
554569
) -> None:
@@ -576,6 +591,7 @@ def __init__(
576591
self.save_id = _generate_uuid()
577592
self.overwrite = overwrite
578593
self.transforms = _StorageWriterTransforms(_extensions)
594+
self.serialization_format = serialization_format
579595

580596
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
581597
if checkpoint_id:
@@ -638,7 +654,6 @@ def _write_data(
638654
self,
639655
planner: SavePlanner,
640656
file_queue: queue.Queue,
641-
safe_tensors: bool = False,
642657
) -> Future[list[WriteResult]]:
643658
result_queue: queue.Queue = queue.Queue()
644659

@@ -655,7 +670,7 @@ def _write_data(
655670
self.per_thread_copy_ahead,
656671
self.sync_files,
657672
self.thread_count,
658-
safe_tensors,
673+
self.serialization_format,
659674
),
660675
)
661676
t.start()
@@ -670,7 +685,7 @@ def _write_data(
670685
inflight_threshhold=self.per_thread_copy_ahead,
671686
use_fsync=self.sync_files,
672687
thread_count=self.thread_count,
673-
safe_tensors=safe_tensors,
688+
serialization_format=self.serialization_format,
674689
)
675690

676691
for t in threads:
@@ -892,6 +907,7 @@ def __init__(
892907
cache_staged_state_dict: bool = False,
893908
overwrite: bool = True,
894909
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
910+
serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
895911
) -> None:
896912
"""
897913
Initialize the writer pointing to `path`.
@@ -919,6 +935,7 @@ def __init__(
919935
per_thread_copy_ahead=per_thread_copy_ahead,
920936
overwrite=overwrite,
921937
_extensions=_extensions,
938+
serialization_format=serialization_format,
922939
)
923940
BlockingAsyncStager.__init__(
924941
self,

0 commit comments

Comments
 (0)