Use streaming transfers
Added new streaming_save/load implementations that don't use an extra
layer of zipfile serialization.

Co-authored-by: Tristan Rice <[email protected]>
Co-authored-by: Krishn Parasar <[email protected]>
d4l3k and Krishn1412 committed Feb 5, 2025
1 parent 0c4ccf9 commit d2a43ac
Showing 4 changed files with 216 additions and 7 deletions.
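To make the new API concrete, here is a minimal round-trip sketch using the `streaming_save`/`streaming_load` functions this commit adds. The `BytesIO` buffer is illustrative; any buffered binary stream, such as the HTTP response body used in `checkpointing.py` below, works the same way:

```python
from io import BytesIO

import torch

from torchft.serialization import streaming_load, streaming_save

state_dict = {"weight": torch.randn(4, 4), "step": 10}

# Serialize directly into a byte stream -- no zipfile container in between.
buf = BytesIO()
streaming_save(state_dict, buf)

# Deserialize from the same stream.
buf.seek(0)
restored = streaming_load(buf)
torch.testing.assert_close(restored["weight"], state_dict["weight"])
```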
12 changes: 5 additions & 7 deletions torchft/checkpointing.py
@@ -20,11 +20,10 @@
 from contextlib import contextmanager
 from datetime import timedelta
 from http.server import BaseHTTPRequestHandler
-from typing import Generator, Generic, List, Optional, TypeVar
-
-import torch
+from typing import Generator, Generic, List, Optional, TypeVar, cast
 
 from torchft.http import _IPv6HTTPServer
+from torchft.serialization import streaming_load, streaming_save
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -161,7 +160,7 @@ def do_GET(self):

                 state_dict = ckpt_server._state_dict
 
-                torch.save(state_dict, self.wfile)
+                streaming_save(state_dict, self.wfile)
             except Exception as e:
                 logger.exception(
                     f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -198,9 +197,8 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()
 
         reader = io.BytesIO(data)
-        # We have to set weights_only to False as there are some non-tensor
-        # states like lr_scheduler.
-        return torch.load(reader, weights_only=False)
+        state_dict = streaming_load(reader)
+        return cast(T, state_dict)
 
     def address(self) -> str:
         """
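For context on the client side, the new `load_from_address` path boils down to an HTTP GET followed by `streaming_load` on the buffered response. A minimal sketch of that flow (the address and timeout are hypothetical placeholders; the real URL comes from `CheckpointServer.address()`):

```python
import io
from urllib import request

from torchft.serialization import streaming_load

# Hypothetical address of a running checkpoint server.
address = "http://localhost:8080/checkpoint/0"

with request.urlopen(address, timeout=10.0) as f:
    data = f.read()  # buffer the full response body

# Deserialize straight from the in-memory buffer -- no zipfile layer.
state_dict = streaming_load(io.BytesIO(data))
```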
5 changes: 5 additions & 0 deletions torchft/checkpointing_test.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import threading
+import unittest
 import urllib.error
 from datetime import timedelta
 from unittest import TestCase
@@ -103,3 +104,7 @@ def test_timed_acquire(self) -> None:
             pass
 
         self.assertTrue(lock.locked())
+
+
+if __name__ == "__main__":
+    unittest.main()
100 changes: 100 additions & 0 deletions torchft/serialization.py
@@ -0,0 +1,100 @@
import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Dict, Tuple

import torch


@dataclass
class _Entry:
    key: str
    dtype: object
    is_storage: bool
    length: int


class _InMemoryStateDict:
    """In-memory stand-in for torch's zip-file reader/writer so records
    can be streamed directly without zipfile framing."""

    def __init__(self) -> None:
        self.records: Dict[str, Tuple[object, int]] = {}

    def write_record(self, key: str, data: object, length: int) -> None:
        self.records[key] = (data, length)

    def write_to(self, f: BufferedIOBase) -> None:
        # Write a pickled table of contents, then the raw payloads in order.
        entries = []
        for key, (data, length) in self.records.items():
            entries.append(
                _Entry(
                    key=key,
                    is_storage=isinstance(data, torch.UntypedStorage),
                    dtype=type(data),
                    length=length,
                )
            )

        pickle.dump(entries, f)

        for key, (data, length) in self.records.items():
            if isinstance(data, bytes):
                f.write(data)
            elif isinstance(data, str):
                f.write(data.encode("utf-8"))
            elif isinstance(data, torch.UntypedStorage):
                data._write_file(f, False, False, 1)
            else:
                raise TypeError(f"unknown type: {type(data)}")

    def read_from(self, f: BufferedIOBase) -> None:
        # Read the table of contents, then consume each payload in order.
        entries = pickle.load(f)

        for entry in entries:
            data = f.read(entry.length)
            if entry.is_storage:
                storage = torch.frombuffer(
                    data,
                    dtype=torch.uint8,
                ).untyped_storage()

                self.records[entry.key] = (
                    storage,
                    entry.length,
                )
            else:
                self.records[entry.key] = (data, entry.length)

    def has_record(self, key: str) -> bool:
        return key in self.records

    def get_record(self, key: str) -> object:
        return self.records[key][0]

    def get_storage_from_record(
        self, key: str, _length: int, _type: int
    ) -> torch.Tensor:
        return torch.tensor(self.records[key][0], dtype=torch.uint8)

    def serialization_id(self) -> str:
        return "torchft"


def streaming_save(obj: object, f: BufferedIOBase) -> None:
    """Save obj to f without the zipfile container that torch.save uses."""
    out = _InMemoryStateDict()
    torch.serialization._save(
        obj,
        zip_file=out,
        pickle_module=pickle,
        pickle_protocol=2,
        _disable_byteorder_record=False,
    )
    out.write_to(f)


def streaming_load(f: BufferedIOBase) -> object:
    """Load an object previously written by streaming_save."""
    out = _InMemoryStateDict()
    out.read_from(f)
    return torch.serialization._load(
        zip_file=out,
        map_location=None,
        pickle_module=pickle,
    )
106 changes: 106 additions & 0 deletions torchft/serialization_test.py
@@ -0,0 +1,106 @@
from io import BytesIO
from typing import cast
from unittest import TestCase

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor

from torchft.serialization import streaming_load, streaming_save


class MyClass:
    def __init__(self, a: int) -> None:
        self.a = a

    def __eq__(self, other: "MyClass") -> bool:
        return self.a == other.a


class TestCheckpointingSerialization(TestCase):
    def test_scalar_tensor(self) -> None:
        tensor = torch.tensor(42, dtype=torch.int32)
        state_dict = {"scalar": tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_strided_tensor(self) -> None:
        base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
        strided_tensor = base_tensor[::2, ::2]
        state_dict = {"strided": strided_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_tensor_with_offset(self) -> None:
        base_tensor = torch.arange(10, dtype=torch.float64)
        offset_tensor = base_tensor[2:]
        state_dict = {"offset": offset_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_nested_tensors(self) -> None:
        tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32)
        tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64)
        state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_various_data_types(self) -> None:
        tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
        tensor_bool = torch.tensor([True, False, True], dtype=torch.bool)
        state_dict = {
            "float32": tensor_float32,
            "int16": tensor_int16,
            "bool": tensor_bool,
        }
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_dtensor(self) -> None:
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )

        device_mesh = DeviceMesh("cpu", 1)
        # Create the tensor on CPU to match the gloo backend and CPU mesh
        # (constructing a CUDA tensor would fail on CPU-only hosts).
        tensor = torch.randn(4, 4, device="cpu")
        dtensor = distribute_tensor(tensor, device_mesh, [])
        state_dict = dtensor
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = cast(DTensor, streaming_load(file))
        torch.testing.assert_close(result.to_local(), state_dict.to_local())

    def test_python_object(self) -> None:
        state_dict = {
            "obj": MyClass(42),
        }

        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        self.assertEqual(result, state_dict)
