Skip to content

Commit 4d0622e

Browse files
committed
Use streaming transfers
1) Added the write_state_dict and read_state_dict implementations to checkpointing.py 2) Replaced the existing torch.save/torch.load calls with them 3) Added unit tests for write_state_dict/read_state_dict for all the different possible types of torch tensors 4) Added a checksum to read/write_state_dict that uses hashlib.sha256
1 parent c0acce1 commit 4d0622e

File tree

2 files changed

+156
-5
lines changed

2 files changed

+156
-5
lines changed

torchft/checkpointing.py

+87-3
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,99 @@
1818
import urllib.request
1919
from http.server import BaseHTTPRequestHandler
2020
from typing import Callable, Generic, TypeVar
21+
from dataclasses import dataclass
22+
import pickle
23+
from io import BufferedIOBase
24+
from typing import Tuple
25+
import struct
2126

2227
import torch
23-
28+
from torch.utils._pytree import tree_flatten, tree_unflatten
29+
from hashlib import sha256
2430
from torchft.http import _IPv6HTTPServer
2531

2632
logger: logging.Logger = logging.getLogger(__name__)
2733

2834
T = TypeVar("T")
2935

3036

37+
@dataclass
class TensorMetadata:
    """
    Placeholder stored in the pickled header in place of a tensor.

    It records what is needed to rebuild the tensor as a strided view over
    the raw storage bytes that follow the header in the stream.
    """

    # Size in bytes of the tensor's entire backing storage, as written.
    nbytes: int
    # Element type used to reinterpret the raw storage bytes.
    dtype: torch.dtype
    # Value of Tensor.storage_offset() for the original tensor.
    storage_offset: int
    # Value of Tensor.size() for the original tensor.
    size: Tuple[int, ...]
    # Value of Tensor.stride() for the original tensor.
    stride: Tuple[int, ...]
44+
45+
46+
def write_state_dict(state_dict: object, f: BufferedIOBase) -> None:
    """
    Serialize ``state_dict`` to the file-like object ``f``.

    Stream layout, in order:
      1. an 8-byte little-endian signed integer holding the combined length
         of items 2 and 3,
      2. the pickled ``(leaf_values, tree_spec)`` pair, where every tensor
         leaf has been replaced by a ``TensorMetadata`` record,
      3. the hex-encoded SHA-256 digest of item 2,
      4. the raw bytes of each tensor's backing storage, in leaf order.

    Args:
        state_dict: arbitrary pytree whose leaves may be tensors.
        f: writable binary stream.
    """
    leaves, spec = tree_flatten(state_dict)

    backing_storages = []
    header_values = []
    for leaf in leaves:
        # Non-tensor leaves travel inside the pickled header unchanged.
        if not isinstance(leaf, torch.Tensor):
            header_values.append(leaf)
            continue

        backing = leaf.untyped_storage()
        backing_storages.append(backing)
        # The whole storage is shipped, so the view parameters are recorded
        # to let the reader rebuild exactly this tensor on top of it.
        header_values.append(
            TensorMetadata(
                nbytes=backing.nbytes(),
                dtype=leaf.dtype,
                storage_offset=leaf.storage_offset(),
                size=leaf.size(),
                stride=leaf.stride(),
            )
        )

    meta_buf = pickle.dumps((header_values, spec))
    digest = sha256(meta_buf).hexdigest().encode("utf-8")

    f.write(struct.pack("<q", len(meta_buf) + len(digest)))
    f.write(meta_buf)
    f.write(digest)

    for backing in backing_storages:
        # NOTE(review): private torch API; the arguments are presumably
        # (f, is_real_file, save_size, element_size) -- confirm against the
        # installed torch version.
        backing._write_file(f, False, False, 1)
81+
82+
83+
def _read_exact(f: BufferedIOBase, nbytes: int, what: str) -> bytes:
    """Read exactly ``nbytes`` bytes from ``f``, raising ValueError on a short read."""
    data = f.read(nbytes)
    if len(data) != nbytes:
        raise ValueError(
            f"Truncated stream while reading {what}: "
            f"expected {nbytes} bytes, got {len(data)}."
        )
    return data


def read_state_dict(f: BufferedIOBase) -> object:
    """
    Read a state_dict from the file-like object.

    Expected stream layout, in order: an 8-byte little-endian signed length,
    the pickled ``(leaf_values, tree_spec)`` header, the hex SHA-256 digest of
    that header, then the raw storage bytes of each tensor in leaf order.

    Args:
        f: readable binary stream positioned at the start of the payload.

    Returns:
        The reconstructed pytree. Tensors are rebuilt with their original
        dtype, size, stride and storage offset on top of writable buffers.

    Raises:
        ValueError: if the stream is truncated, the length header is
            invalid, or the header checksum does not match.
    """
    # Length of a hex-encoded SHA-256 digest (64 characters).
    checksum_len = sha256().digest_size * 2

    (total_length,) = struct.unpack("<q", _read_exact(f, 8, "length header"))
    if total_length < checksum_len:
        raise ValueError(
            f"Invalid header length {total_length}; data may be corrupted."
        )
    meta_buf = _read_exact(f, total_length - checksum_len, "metadata")
    checksum = _read_exact(f, checksum_len, "checksum")

    # Compare as bytes so a corrupted digest cannot fail in .decode() first.
    if checksum != sha256(meta_buf).hexdigest().encode("utf-8"):
        raise ValueError("Checksum mismatch! Data may be corrupted.")

    # NOTE(security): pickle.loads can execute arbitrary code. The checksum
    # travels in the same stream, so it only detects accidental corruption
    # and does not authenticate the sender. Only read from trusted peers.
    non_tensor_values, spec = pickle.loads(meta_buf)

    values = []
    for value in non_tensor_values:
        if not isinstance(value, TensorMetadata):
            values.append(value)
            continue

        if value.nbytes == 0:
            # torch.frombuffer rejects zero-length buffers, so an empty
            # storage is materialized directly.
            base = torch.empty(0, dtype=value.dtype)
        else:
            # Read into a mutable buffer: a tensor created over immutable
            # bytes must not be written to, and this also avoids a copy.
            buf = bytearray(value.nbytes)
            got = f.readinto(buf)
            if got != value.nbytes:
                raise ValueError(
                    "Truncated stream while reading tensor data: "
                    f"expected {value.nbytes} bytes, got {got}."
                )
            base = torch.frombuffer(buf, dtype=value.dtype)

        values.append(
            torch.as_strided(
                base,
                size=value.size,
                stride=value.stride,
                storage_offset=value.storage_offset,
            )
        )

    return tree_unflatten(values, spec)
113+
31114
class CheckpointServer(Generic[T]):
32115
"""
33116
This is an HTTP server that can be used to transfer checkpoints
@@ -69,7 +152,7 @@ def do_GET(self):
69152

70153
sd = state_dict()
71154

72-
torch.save(sd, self.wfile)
155+
write_state_dict(sd, self.wfile)
73156

74157
def err(self, msg: str) -> None:
75158
logger.error(msg)
@@ -100,7 +183,8 @@ def load_from_address(cls, address: str) -> T:
100183
data = f.read()
101184

102185
reader = io.BytesIO(data)
103-
return torch.load(reader, weights_only=True)
186+
state_dict = read_state_dict(reader)
187+
return state_dict
104188

105189
def address(self) -> str:
106190
"""

torchft/checkpointing_test.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import unittest
78
import urllib.error
89
from unittest import TestCase
910
from unittest.mock import MagicMock
10-
11-
from torchft.checkpointing import CheckpointServer
11+
from io import BytesIO
12+
import torch
13+
from typing import Tuple
14+
from checkpointing import CheckpointServer, TensorMetadata, write_state_dict, read_state_dict
1215

1316

1417
class TestCheckpointing(TestCase):
@@ -33,3 +36,67 @@ def test_checkpoint_server(self) -> None:
3336
CheckpointServer.load_from_address(addr)
3437

3538
server.shutdown()
39+
40+
def setUp(self) -> None:
    """Create the in-memory buffer that the round-trip tests write to and read from."""
    self.file = BytesIO()
42+
43+
def test_scalar_tensor(self):
    """A zero-dimensional int32 tensor survives a write/read round trip."""
    expected = torch.tensor(42, dtype=torch.int32)

    write_state_dict({'scalar': expected}, self.file)
    # Rewind so the reader starts at the beginning of what was written.
    self.file.seek(0)
    restored = read_state_dict(self.file)

    round_tripped = restored['scalar']
    self.assertTrue(torch.equal(round_tripped, expected))
51+
52+
def test_strided_tensor(self):
    """A non-contiguous view (every other row and column) round-trips."""
    grid = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    expected = grid[::2, ::2]

    write_state_dict({'strided': expected}, self.file)
    # Rewind so the reader starts at the beginning of what was written.
    self.file.seek(0)
    restored = read_state_dict(self.file)

    round_tripped = restored['strided']
    self.assertTrue(torch.equal(round_tripped, expected))
61+
62+
def test_tensor_with_offset(self):
    """A slice that starts partway into its storage round-trips."""
    full = torch.arange(10, dtype=torch.float64)
    expected = full[2:]

    write_state_dict({'offset': expected}, self.file)
    # Rewind so the reader starts at the beginning of what was written.
    self.file.seek(0)
    restored = read_state_dict(self.file)

    round_tripped = restored['offset']
    self.assertTrue(torch.equal(round_tripped, expected))
71+
72+
def test_nested_tensors(self):
    """Tensors held inside a nested dict round-trip with the nesting intact."""
    inner = {
        'tensor1': torch.tensor([1, 2, 3], dtype=torch.int32),
        'tensor2': torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64),
    }

    write_state_dict({'nested': inner}, self.file)
    # Rewind so the reader starts at the beginning of what was written.
    self.file.seek(0)
    restored = read_state_dict(self.file)

    for key, expected in inner.items():
        self.assertTrue(torch.equal(restored['nested'][key], expected))
82+
83+
def test_various_data_types(self):
    """float32, int16 and bool tensors all round-trip in one state_dict."""
    originals = {
        'float32': torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32),
        'int16': torch.tensor([1, 2, 3], dtype=torch.int16),
        'bool': torch.tensor([True, False, True], dtype=torch.bool),
    }

    write_state_dict(originals, self.file)
    # Rewind so the reader starts at the beginning of what was written.
    self.file.seek(0)
    restored = read_state_dict(self.file)

    for key, expected in originals.items():
        self.assertTrue(torch.equal(restored[key], expected))
99+
100+
101+
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)