Commit d2a43ac

d4l3k and Krishn1412 committed
Use streaming transfers
Added new streaming_save/streaming_load implementations that don't use an extra layer of zipfile serialization.

Co-authored-by: Tristan Rice <[email protected]>
Co-authored-by: Krishn Parasar <[email protected]>
1 parent 0c4ccf9 commit d2a43ac
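For context, a minimal round trip through the new API looks roughly like this (a sketch based on the functions this commit adds; the example state dict is made up for illustration):

    from io import BytesIO

    import torch

    from torchft.serialization import streaming_load, streaming_save

    # Any picklable object is accepted; tensor storage is streamed as raw bytes.
    state_dict = {"weight": torch.randn(4, 4), "step": 10}

    buf = BytesIO()
    streaming_save(state_dict, buf)  # no zipfile layer around the payload
    buf.seek(0)

    restored = streaming_load(buf)
    torch.testing.assert_close(restored["weight"], state_dict["weight"])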

File tree

4 files changed: +216 −7 lines changed

torchft/checkpointing.py

+5 −7
@@ -20,11 +20,10 @@
 from contextlib import contextmanager
 from datetime import timedelta
 from http.server import BaseHTTPRequestHandler
-from typing import Generator, Generic, List, Optional, TypeVar
-
-import torch
+from typing import Generator, Generic, List, Optional, TypeVar, cast
 
 from torchft.http import _IPv6HTTPServer
+from torchft.serialization import streaming_load, streaming_save
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -161,7 +160,7 @@ def do_GET(self):
 
             state_dict = ckpt_server._state_dict
 
-            torch.save(state_dict, self.wfile)
+            streaming_save(state_dict, self.wfile)
         except Exception as e:
             logger.exception(
                 f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -198,9 +197,8 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()
 
         reader = io.BytesIO(data)
-        # We have to set weights_only to False as there are some non-tensor
-        # states like lr_scheduler.
-        return torch.load(reader, weights_only=False)
+        state_dict = streaming_load(reader)
+        return cast(T, state_dict)
 
     def address(self) -> str:
         """

torchft/checkpointing_test.py

+5
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import threading
+import unittest
 import urllib.error
 from datetime import timedelta
 from unittest import TestCase
@@ -103,3 +104,7 @@ def test_timed_acquire(self) -> None:
             pass
 
         self.assertTrue(lock.locked())
+
+
+if __name__ == "__main__":
+    unittest.main()
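With the new __main__ guard, this test file can also be run directly via python torchft/checkpointing_test.py, in addition to the usual test runner.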

torchft/serialization.py

+100
@@ -0,0 +1,100 @@
import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Dict, Tuple

import torch


@dataclass
class _Entry:
    key: str
    dtype: object
    is_storage: bool
    length: int


class _InMemoryStateDict:
    """In-memory stand-in for the zip-file reader/writer interface that
    torch.serialization._save/_load expect, so no zipfile layer is needed."""

    def __init__(self) -> None:
        self.records: Dict[str, Tuple[object, int]] = {}

    def write_record(self, key: str, data: object, length: int) -> None:
        self.records[key] = (data, length)

    def write_to(self, f: BufferedIOBase) -> None:
        # Header: a pickled list of _Entry metadata records.
        entries = []
        for key, (data, length) in self.records.items():
            entries.append(
                _Entry(
                    key=key,
                    is_storage=isinstance(data, torch.UntypedStorage),
                    dtype=type(data),
                    length=length,
                )
            )

        pickle.dump(entries, f)

        # Body: each record's raw bytes, in the same order as the header.
        for key, (data, length) in self.records.items():
            if isinstance(data, bytes):
                f.write(data)
            elif isinstance(data, str):
                f.write(data.encode("utf-8"))
            elif isinstance(data, torch.UntypedStorage):
                data._write_file(f, False, False, 1)
            else:
                raise TypeError(f"unknown type: {type(data)}")

    def read_from(self, f: BufferedIOBase) -> None:
        entries = pickle.load(f)

        for entry in entries:
            data = f.read(entry.length)
            if entry.is_storage:
                storage = torch.frombuffer(
                    data,
                    dtype=torch.uint8,
                ).untyped_storage()

                self.records[entry.key] = (
                    storage,
                    entry.length,
                )
            else:
                self.records[entry.key] = (data, entry.length)

    def has_record(self, key: str) -> bool:
        return key in self.records

    def get_record(self, key: str) -> object:
        return self.records[key][0]

    def get_storage_from_record(
        self, key: str, _length: int, _type: int
    ) -> torch.Tensor:
        return torch.tensor(self.records[key][0], dtype=torch.uint8)

    def serialization_id(self) -> str:
        return "torchft"


def streaming_save(obj: object, f: BufferedIOBase) -> None:
    """Saves obj to the file-like object without zipfile framing."""
    out = _InMemoryStateDict()
    torch.serialization._save(
        obj,
        zip_file=out,
        pickle_module=pickle,
        pickle_protocol=2,
        _disable_byteorder_record=False,
    )
    out.write_to(f)


def streaming_load(f: BufferedIOBase) -> object:
    """Loads an object written by streaming_save from the file-like object."""
    out = _InMemoryStateDict()
    out.read_from(f)
    return torch.serialization._load(
        zip_file=out,
        map_location=None,
        pickle_module=pickle,
    )
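The stream layout this produces is flat: a pickled list of _Entry records up front, then each record's raw payload bytes in the same order. As a rough illustration, the header alone can be inspected without reconstructing any tensors (a sketch; the record keys are whatever torch.serialization._save decides to write):

    import pickle
    from io import BytesIO

    import torch

    from torchft.serialization import streaming_save

    buf = BytesIO()
    streaming_save({"w": torch.zeros(4)}, buf)
    buf.seek(0)

    entries = pickle.load(buf)  # the _Entry header written by write_to()
    for entry in entries:
        payload = buf.read(entry.length)  # raw bytes for this record
        print(entry.key, entry.is_storage, entry.length)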

torchft/serialization_test.py

+106
@@ -0,0 +1,106 @@
from io import BytesIO
from typing import cast
from unittest import TestCase

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor

from torchft.serialization import streaming_load, streaming_save


class MyClass:
    def __init__(self, a: int) -> None:
        self.a = a

    def __eq__(self, other: "MyClass") -> bool:
        return self.a == other.a


class TestCheckpointingSerialization(TestCase):
    def test_scalar_tensor(self) -> None:
        tensor = torch.tensor(42, dtype=torch.int32)
        state_dict = {"scalar": tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_strided_tensor(self) -> None:
        base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
        strided_tensor = base_tensor[::2, ::2]
        state_dict = {"strided": strided_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_tensor_with_offset(self) -> None:
        base_tensor = torch.arange(10, dtype=torch.float64)
        offset_tensor = base_tensor[2:]
        state_dict = {"offset": offset_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_nested_tensors(self) -> None:
        tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32)
        tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64)
        state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_various_data_types(self) -> None:
        tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
        tensor_bool = torch.tensor([True, False, True], dtype=torch.bool)
        state_dict = {
            "float32": tensor_float32,
            "int16": tensor_int16,
            "bool": tensor_bool,
        }
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_dtensor(self) -> None:
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )

        device_mesh = DeviceMesh("cpu", 1)
        # CPU tensor to match the gloo backend and the cpu device mesh.
        tensor = torch.randn(4, 4, device="cpu")
        dtensor = distribute_tensor(tensor, device_mesh, [])
        state_dict = dtensor
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = cast(DTensor, streaming_load(file))
        torch.testing.assert_close(result.to_local(), state_dict.to_local())

    def test_python_object(self) -> None:
        state_dict = {
            "obj": MyClass(42),
        }

        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        self.assertEqual(result, state_dict)
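These tests have no __main__ guard, so they need a runner; something like python -m pytest torchft/serialization_test.py should work, assuming pytest is available (the invocation is an assumption, not part of this commit).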
