Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use streaming transfers #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,99 @@
import urllib.request
from http.server import BaseHTTPRequestHandler
from typing import Callable, Generic, TypeVar
from dataclasses import dataclass
import pickle
from io import BufferedIOBase
from typing import Tuple
import struct

import torch

from torch.utils._pytree import tree_flatten, tree_unflatten
from hashlib import sha256
from torchft.http import _IPv6HTTPServer

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class TensorMetadata:
    """
    Placeholder for one tensor inside the pickled checkpoint header.

    write_state_dict stores one of these in place of every tensor leaf so
    that read_state_dict can rebuild the tensor from the raw storage bytes
    that follow the header.
    """

    # Byte length of the tensor's untyped storage (storage.nbytes()).
    nbytes: int
    # Element type of the tensor.
    dtype: torch.dtype
    # Offset of the tensor's first element within its storage
    # (tensor.storage_offset()).
    storage_offset: int
    # Tensor shape (tensor.size()).
    size: Tuple[int, ...]
    # Per-dimension strides (tensor.stride()).
    stride: Tuple[int, ...]


def write_state_dict(state_dict: object, f: BufferedIOBase) -> None:
    """
    Serialize ``state_dict`` into the file-like object ``f``.

    Layout of the stream, in order:

    1. 8-byte little-endian signed integer: length of the metadata plus the
       length of the checksum.
    2. Pickled ``(leaves, treespec)`` where every tensor leaf has been
       replaced by a ``TensorMetadata`` record.
    3. Hex SHA-256 digest of the pickled metadata, UTF-8 encoded.
    4. The raw bytes of each tensor's untyped storage, in leaf order.

    Args:
        state_dict: arbitrary pytree whose leaves may be tensors.
        f: writable binary stream.
    """
    leaves, treespec = tree_flatten(state_dict)

    tensor_storages = []
    header_entries = []
    for leaf in leaves:
        if not isinstance(leaf, torch.Tensor):
            # Non-tensor leaves travel inside the pickled header unchanged.
            header_entries.append(leaf)
            continue

        backing = leaf.untyped_storage()
        tensor_storages.append(backing)
        # The whole backing storage is written, so the view parameters are
        # recorded to let the reader carve the tensor back out of it.
        header_entries.append(
            TensorMetadata(
                nbytes=backing.nbytes(),
                dtype=leaf.dtype,
                storage_offset=leaf.storage_offset(),
                size=leaf.size(),
                stride=leaf.stride(),
            )
        )

    payload = pickle.dumps((header_entries, treespec))
    digest = sha256(payload).hexdigest()

    f.write(struct.pack("<q", len(payload) + len(digest)))
    f.write(payload)
    f.write(digest.encode("utf-8"))

    for backing in tensor_storages:
        # NOTE(review): private torch API; the positional arguments are
        # presumably (f, is_real_file, save_size, element_size) -- confirm.
        backing._write_file(f, False, False, 1)


def _read_exact(f: BufferedIOBase, nbytes: int, what: str) -> bytearray:
    """Read exactly ``nbytes`` from ``f`` into a writable buffer or raise ValueError."""
    if nbytes < 0:
        raise ValueError(f"Invalid negative length {nbytes} for {what}.")

    buf = bytearray()
    while len(buf) < nbytes:
        # A single read() may legitimately return fewer bytes than requested
        # (e.g. on socket-backed streams), so keep asking for the remainder.
        chunk = f.read(nbytes - len(buf))
        if not chunk:
            raise ValueError(
                f"Truncated stream while reading {what}: "
                f"expected {nbytes} bytes, got {len(buf)}."
            )
        buf += chunk
    return buf


def read_state_dict(f: BufferedIOBase) -> object:
    """
    Read a state_dict from the file-like object.

    Expected layout of the stream, in order:

    1. 8-byte little-endian signed integer: length of the metadata plus the
       length of the checksum.
    2. Pickled ``(non_tensor_values, spec)`` where tensors are represented
       by ``TensorMetadata`` records.
    3. 64-character hex SHA-256 digest of the pickled metadata.
    4. For every ``TensorMetadata`` record, ``nbytes`` of raw storage.

    Args:
        f: readable binary stream positioned at the start of the header.

    Returns:
        The reconstructed pytree. Tensors are backed by freshly allocated,
        writable buffers.

    Raises:
        ValueError: if the stream is truncated, the header length is
            invalid, or the metadata checksum does not match.
    """
    # Length of a hex-encoded SHA-256 digest (two characters per byte).
    checksum_len = sha256().digest_size * 2

    (total_length,) = struct.unpack("<q", _read_exact(f, 8, "length header"))
    if total_length < checksum_len:
        raise ValueError(
            f"Invalid header length {total_length}; "
            f"must be at least {checksum_len}."
        )

    meta_buf = _read_exact(f, total_length - checksum_len, "metadata")
    checksum = bytes(_read_exact(f, checksum_len, "checksum"))

    # Verify checksum
    actual_checksum = sha256(meta_buf).hexdigest().encode("utf-8")
    if checksum != actual_checksum:
        raise ValueError("Checksum mismatch! Data may be corrupted.")

    # SECURITY NOTE(review): pickle.loads can execute arbitrary code. The
    # checksum only detects accidental corruption -- a malicious peer can
    # send a matching digest -- so this must only be fed data from trusted
    # sources. Consider a restricted unpickler.
    non_tensor_values, spec = pickle.loads(meta_buf)

    values = []
    for value in non_tensor_values:
        if not isinstance(value, TensorMetadata):
            values.append(value)
            continue

        # bytearray (not bytes) so the resulting tensor owns writable memory;
        # torch.frombuffer warns on read-only buffers.
        data = _read_exact(f, value.nbytes, "tensor data")
        if value.nbytes == 0:
            # torch.frombuffer rejects zero-length buffers, so empty
            # storages need an explicitly constructed empty tensor.
            flat = torch.empty(0, dtype=value.dtype)
        else:
            flat = torch.frombuffer(data, dtype=value.dtype)

        values.append(
            torch.as_strided(
                flat,
                size=value.size,
                stride=value.stride,
                storage_offset=value.storage_offset,
            )
        )

    return tree_unflatten(values, spec)

class CheckpointServer(Generic[T]):
"""
This is an HTTP server that can be used to transfer checkpoints
Expand Down Expand Up @@ -69,7 +152,7 @@ def do_GET(self):

sd = state_dict()

torch.save(sd, self.wfile)
write_state_dict(sd, self.wfile)

def err(self, msg: str) -> None:
logger.error(msg)
Expand Down Expand Up @@ -100,7 +183,8 @@ def load_from_address(cls, address: str) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
state_dict = read_state_dict(reader)
return state_dict

def address(self) -> str:
"""
Expand Down
71 changes: 69 additions & 2 deletions torchft/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import urllib.error
from unittest import TestCase
from unittest.mock import MagicMock

from torchft.checkpointing import CheckpointServer
from io import BytesIO
import torch
from typing import Tuple
from checkpointing import CheckpointServer, TensorMetadata, write_state_dict, read_state_dict
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from checkpointing import CheckpointServer, TensorMetadata, write_state_dict, read_state_dict
from torchft.checkpointing import CheckpointServer, TensorMetadata, write_state_dict, read_state_dict

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GitHub CI doesn't seem too happy with this import, judging by the failing unit test.



class TestCheckpointing(TestCase):
Expand All @@ -33,3 +36,67 @@ def test_checkpoint_server(self) -> None:
CheckpointServer.load_from_address(addr)

server.shutdown()

def setUp(self):
self.file = BytesIO()

def test_scalar_tensor(self) -> None:
    """A zero-dimensional int32 tensor survives a write/read round trip."""
    expected = torch.tensor(42, dtype=torch.int32)

    write_state_dict({'scalar': expected}, self.file)
    self.file.seek(0)  # rewind so the reader starts at the header
    loaded = read_state_dict(self.file)

    self.assertTrue(torch.equal(loaded['scalar'], expected))

def test_strided_tensor(self) -> None:
    """A non-contiguous view (every other row and column) round-trips."""
    grid = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    expected = grid[::2, ::2]

    write_state_dict({'strided': expected}, self.file)
    self.file.seek(0)  # rewind so the reader starts at the header
    loaded = read_state_dict(self.file)

    self.assertTrue(torch.equal(loaded['strided'], expected))

def test_tensor_with_offset(self) -> None:
    """A slice that starts part-way into its base tensor round-trips."""
    source = torch.arange(10, dtype=torch.float64)
    expected = source[2:]

    write_state_dict({'offset': expected}, self.file)
    self.file.seek(0)  # rewind so the reader starts at the header
    loaded = read_state_dict(self.file)

    self.assertTrue(torch.equal(loaded['offset'], expected))

def test_nested_tensors(self) -> None:
    """Tensors held in a nested dict round-trip with structure intact."""
    inner = {
        'tensor1': torch.tensor([1, 2, 3], dtype=torch.int32),
        'tensor2': torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64),
    }

    write_state_dict({'nested': inner}, self.file)
    self.file.seek(0)  # rewind so the reader starts at the header
    loaded = read_state_dict(self.file)

    for key, expected in inner.items():
        self.assertTrue(torch.equal(loaded['nested'][key], expected))

def test_various_data_types(self) -> None:
    """float32, int16 and bool tensors all round-trip in one state_dict."""
    originals = {
        'float32': torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32),
        'int16': torch.tensor([1, 2, 3], dtype=torch.int16),
        'bool': torch.tensor([True, False, True], dtype=torch.bool),
    }

    write_state_dict(originals, self.file)
    self.file.seek(0)  # rewind so the reader starts at the header
    loaded = read_state_dict(self.file)

    for key, expected in originals.items():
        self.assertTrue(torch.equal(loaded[key], expected))


if __name__ == '__main__':
unittest.main()
Loading