-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added in a new streaming_save/load implementations that don't use an extra layer of zipfile serialization Co-authored-by: Tristan Rice <[email protected]> Co-authored-by: Krishn Parasar <[email protected]>
- Loading branch information
1 parent
0c4ccf9
commit d2a43ac
Showing
4 changed files
with
216 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import pickle | ||
from dataclasses import dataclass | ||
from io import BufferedIOBase | ||
from typing import Dict, Tuple | ||
|
||
import torch | ||
|
||
|
||
@dataclass | ||
class _Entry: | ||
key: str | ||
dtype: object | ||
is_storage: bool | ||
length: int | ||
|
||
|
||
class _InMemoryStateDict: | ||
def __init__(self) -> None: | ||
self.records: Dict[str, Tuple[object, int]] = {} | ||
|
||
def write_record(self, key: str, data: object, length: int) -> None: | ||
self.records[key] = (data, length) | ||
|
||
def write_to(self, f: BufferedIOBase) -> None: | ||
entries = [] | ||
for key, (data, length) in self.records.items(): | ||
entries.append( | ||
_Entry( | ||
key=key, | ||
is_storage=isinstance(data, torch.UntypedStorage), | ||
dtype=type(data), | ||
length=length, | ||
) | ||
) | ||
|
||
pickle.dump(entries, f) | ||
|
||
for key, (data, length) in self.records.items(): | ||
if isinstance(data, bytes): | ||
f.write(data) | ||
elif isinstance(data, str): | ||
f.write(data.encode("utf-8")) | ||
elif isinstance(data, torch.UntypedStorage): | ||
data._write_file(f, False, False, 1) | ||
else: | ||
raise TypeError(f"unknown type: {type(data)}") | ||
|
||
def read_from(self, f: BufferedIOBase) -> None: | ||
entries = pickle.load(f) | ||
|
||
for entry in entries: | ||
data = f.read(entry.length) | ||
if entry.is_storage: | ||
storage = torch.frombuffer( | ||
data, | ||
dtype=torch.uint8, | ||
).untyped_storage() | ||
|
||
self.records[entry.key] = ( | ||
storage, | ||
entry.length, | ||
) | ||
else: | ||
self.records[entry.key] = (data, entry.length) | ||
|
||
def has_record(self, key: str) -> bool: | ||
return key in self.records | ||
|
||
def get_record(self, key: str) -> object: | ||
return self.records[key][0] | ||
|
||
def get_storage_from_record( | ||
self, key: str, _length: int, _type: int | ||
) -> torch.Tensor: | ||
return torch.tensor(self.records[key][0], dtype=torch.uint8) | ||
|
||
def serialization_id(self) -> str: | ||
return "torchft" | ||
|
||
|
||
def streaming_save(obj: object, f: BufferedIOBase) -> None: | ||
out = _InMemoryStateDict() | ||
torch.serialization._save( | ||
obj, | ||
zip_file=out, | ||
pickle_module=pickle, | ||
pickle_protocol=2, | ||
_disable_byteorder_record=False, | ||
) | ||
out.write_to(f) | ||
|
||
|
||
def streaming_load(f: BufferedIOBase) -> object: | ||
out = _InMemoryStateDict() | ||
out.read_from(f) | ||
return torch.serialization._load( | ||
zip_file=out, | ||
map_location=None, | ||
pickle_module=pickle, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from io import BytesIO | ||
from typing import cast | ||
from unittest import TestCase | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor | ||
|
||
from torchft.serialization import streaming_load, streaming_save | ||
|
||
|
||
class MyClass: | ||
def __init__(self, a: int) -> None: | ||
self.a = a | ||
|
||
def __eq__(self, other: "MyClass") -> bool: | ||
return self.a == other.a | ||
|
||
|
||
class TestCheckpointingSerialization(TestCase): | ||
def test_scalar_tensor(self) -> None: | ||
tensor = torch.tensor(42, dtype=torch.int32) | ||
state_dict = {"scalar": tensor} | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
torch.testing.assert_close(result, state_dict) | ||
|
||
def test_strided_tensor(self) -> None: | ||
base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4) | ||
strided_tensor = base_tensor[::2, ::2] | ||
state_dict = {"strided": strided_tensor} | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
torch.testing.assert_close(result, state_dict) | ||
|
||
def test_tensor_with_offset(self) -> None: | ||
base_tensor = torch.arange(10, dtype=torch.float64) | ||
offset_tensor = base_tensor[2:] | ||
state_dict = {"offset": offset_tensor} | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
torch.testing.assert_close(result, state_dict) | ||
|
||
def test_nested_tensors(self) -> None: | ||
tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32) | ||
tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64) | ||
state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}} | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
torch.testing.assert_close(result, state_dict) | ||
|
||
def test_various_data_types(self) -> None: | ||
tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) | ||
tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16) | ||
tensor_bool = torch.tensor([True, False, True], dtype=torch.bool) | ||
state_dict = { | ||
"float32": tensor_float32, | ||
"int16": tensor_int16, | ||
"bool": tensor_bool, | ||
} | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
torch.testing.assert_close(result, state_dict) | ||
|
||
def test_dtensor(self) -> None: | ||
dist.init_process_group( | ||
backend="gloo", rank=0, world_size=1, store=dist.HashStore() | ||
) | ||
|
||
device_mesh = DeviceMesh("cpu", 1) | ||
tensor = torch.randn(4, 4, device="cuda") | ||
dtensor = distribute_tensor(tensor, device_mesh, []) | ||
state_dict = dtensor | ||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = cast(DTensor, streaming_load(file)) | ||
torch.testing.assert_close(result.to_local(), state_dict.to_local()) | ||
|
||
def test_python_object(self) -> None: | ||
state_dict = { | ||
"obj": MyClass(42), | ||
} | ||
|
||
file = BytesIO() | ||
streaming_save(state_dict, file) | ||
file.seek(0) | ||
|
||
result = streaming_load(file) | ||
self.assertEqual(result, state_dict) |