Commit 8628a3f

checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints (#110)

1 parent ca8c540 commit 8628a3f

File tree: 5 files changed, +490 -13 lines

torchft/checkpointing/http_transport_test.py (+32 -11)

@@ -6,26 +6,23 @@
 
 import urllib.error
 from datetime import timedelta
-from typing import Any, Dict
-from unittest import TestCase
+from typing import Dict
+from unittest import TestCase, skipUnless
 from unittest.mock import MagicMock
 
 import torch
 from parameterized import parameterized
 
 from torchft.checkpointing.http_transport import HTTPTransport
 from torchft.checkpointing.http_transport_bench import main as bench_main
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.checkpointing.transport_test import (
+    assertStateDictEqual,
+    run_multi_recovery_test,
+)
 
 
 class TestHTTPTransport(TestCase):
-    def assertStateDictEqual(self, a: Dict[str, object], b: Dict[str, object]) -> None:
-        for k, v1 in a.items():
-            v2 = b[k]
-            if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
-                torch.testing.assert_close(v1.cpu(), v2.cpu())
-            else:
-                self.assertEqual(v1, v2)
-
     @parameterized.expand(
         [
            ("no chunks", 0),
@@ -59,7 +56,7 @@ def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
         out = server.recv_checkpoint(
             src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
         )
-        self.assertStateDictEqual(out, expected)
+        assertStateDictEqual(self, out, expected)
 
         # test timeout
         with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
@@ -114,6 +111,30 @@ def test_checkpoint_server_locking(self) -> None:
 
         server.shutdown()
 
+    def test_multi_http_transport_cpu(self) -> None:
+        device = torch.device("cpu")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            return HTTPTransport(
+                timeout=timedelta(seconds=10),
+                num_chunks=0,
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    @skipUnless(torch.cuda.is_available(), "CUDA is not available")
+    def test_multi_http_transport_cuda(self) -> None:
+        device = torch.device("cuda")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            return HTTPTransport(
+                timeout=timedelta(seconds=10),
+                num_chunks=0,
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
     def test_benchmark(self) -> None:
         bench_main(
             [
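The removed assertStateDictEqual method moves into the shared torchft.checkpointing.transport_test module (whose diff is not shown in this view) so that both the HTTP and process-group transport tests can reuse it. Judging from the removed method and the new call site assertStateDictEqual(self, out, expected), the shared helper is presumably a free function along these lines (a sketch, not the verbatim upstream code):

# Sketch (assumption): the shared helper in torchft/checkpointing/transport_test.py,
# reconstructed from the TestHTTPTransport method removed above.
from typing import Dict
from unittest import TestCase

import torch


def assertStateDictEqual(
    self: TestCase, a: Dict[str, object], b: Dict[str, object]
) -> None:
    # Compare tensors value-wise on CPU; fall back to plain equality otherwise.
    for k, v1 in a.items():
        v2 = b[k]
        if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
            torch.testing.assert_close(v1.cpu(), v2.cpu())
        else:
            self.assertEqual(v1, v2)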

torchft/checkpointing/pg_transport.py (+247, new file)

@@ -0,0 +1,247 @@
+import logging
+import pickle
+import time
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import Generator, List, Tuple, TypeVar, Union, cast
+
+import torch
+from torch.distributed import Work
+from torch.distributed.tensor import DTensor, _DTensorSpec
+from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
+
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.process_group import ProcessGroup
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class _TensorMeta:
+    """
+    This is the metadata for a tensor that is used to transfer checkpoints.
+    It contains the shape, the dtype, the storage offset and the stride of the
+    tensor.
+
+    This must be pickleable so that it can be sent over the wire.
+    """
+
+    shape: torch.Size
+    dtype: torch.dtype
+    storage_offset: int
+    stride: Tuple[int, ...]
+    nbytes: int
+
+
+@dataclass
+class _DTensorMeta:
+    """
+    This is the metadata for a DTensor that is used to transfer checkpoints.
+    It contains the metadata for the local tensor and the spec of the DTensor.
+
+    This must be pickleable so that it can be sent over the wire.
+    """
+
+    local: _TensorMeta
+    spec: _DTensorSpec
+
+
+@dataclass
+class _StateDictMeta:
+    """
+    This is the metadata for a state dict that is used to transfer checkpoints.
+    It contains the step, the pytree spec of the state dict and the metadata for
+    each tensor in the state dict.
+
+    This must be pickleable so that it can be sent over the wire.
+
+    Args:
+        step: the step of the checkpoint to verify consistency
+        treespec: the pytree spec of the state dict
+        non_tensor_leaves: the metadata for each tensor in the state dict and any
+            non-tensor leaves in the state dict
+    """
+
+    step: int
+    treespec: TreeSpec
+    non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
+
+
+@contextmanager
+def _timeit(name: str) -> Generator[None, None, None]:
+    start = time.perf_counter()
+    yield
+    dur = time.perf_counter() - start
+    logger.info(f"{name} took {dur}s")
+
+
+def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
+    return (
+        _cast_tensor(tensor, torch.uint8),
+        _TensorMeta(
+            shape=tensor.shape,
+            dtype=tensor.dtype,
+            storage_offset=cast(int, tensor.storage_offset()),
+            stride=tensor.stride(),
+            nbytes=tensor.untyped_storage().nbytes(),
+        ),
+    )
+
+
+def _prepare_state_dict(
+    state_dict: object,
+    step: int,
+    device: torch.device,
+) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
+    leaves, treespec = tree_flatten(state_dict)
+
+    non_tensor_leaves = []
+    tensors = []
+    for v in leaves:
+        if isinstance(v, DTensor):
+            tensor, tensor_meta = _prepare_tensor(v._local_tensor)
+
+            tensors.append(tensor)
+
+            non_tensor_leaves.append(
+                _DTensorMeta(
+                    local=tensor_meta,
+                    spec=v._spec,
+                )
+            )
+        elif isinstance(v, torch.Tensor):
+            tensor, tensor_meta = _prepare_tensor(v)
+            tensors.append(tensor)
+            non_tensor_leaves.append(tensor_meta)
+        else:
+            non_tensor_leaves.append(v)
+
+    return (
+        _StateDictMeta(
+            step=step,
+            treespec=treespec,
+            non_tensor_leaves=non_tensor_leaves,
+        ),
+        tensors,
+    )
+
+
+def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """
+    Casts the underlying storage to a tensor of the given dtype.
+
+    The returned tensor will be of size ``storage.nbytes``.
+
+    This works for all datatypes and supports strided/offset tensors with the
+    caveat that the cast tensor may be larger than the original tensor due to
+    the differences in striding.
+    """
+    storage = tensor.untyped_storage()
+    ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
+    assert ret.untyped_storage() is storage, "storage should be the same"
+    return ret
+
+
+class PGTransport(CheckpointTransport[T]):
+    """
+    This is a checkpoint transport that uses the process group to transfer checkpoints.
+    This allows for fast recovery of workers by fetching the current weights
+    from an existing worker.
+
+    Args:
+        state_dict: a callable that returns the state dict to be transferred
+    """
+
+    def __init__(
+        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
+    ) -> None:
+        self._work: List[Work] = []
+        self._pg = pg
+        self._timeout = timeout
+        self._device = device
+
+    def metadata(self) -> str:
+        return "<n/a>"
+
+    def disallow_checkpoint(self) -> None:
+        pass
+
+    def send_checkpoint(
+        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+    ) -> None:
+        with _timeit("preparing state_dict"):
+            meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
+
+        work = []
+
+        with _timeit("send pickle"):
+            buf = pickle.dumps(meta)
+            len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
+            buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
+            for dst_rank in dst_ranks:
+                work.append(self._pg.send([len_t], dst_rank, tag=1))
+                work.append(self._pg.send([buf_t], dst_rank, tag=2))
+
+        with _timeit("send tensors"):
+            for i, t in enumerate(tensors):
+                t = t.to(self._device)
+                for dst_rank in dst_ranks:
+                    work.append(self._pg.send([t], dst_rank, tag=3 + i))
+
+                # allow 3 concurrent transfers at a time to avoid OOMs
+                while len(work) > (3 * len(dst_ranks)):
+                    work.pop(0).wait(timeout)
+
+            for w in work:
+                w.wait(timeout)
+
+    def recv_checkpoint(
+        self, src_rank: int, metadata: str, step: int, timeout: timedelta
+    ) -> T:
+        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
+        self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
+        length = cast(int, len_t.item())
+
+        assert length > 0, f"invalid metadata length {length=}"
+
+        buf = torch.empty(length, dtype=torch.uint8, device=self._device)
+        self._pg.recv([buf], src_rank, tag=2).wait(timeout)
+
+        meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes())
+        assert meta.step == step
+
+        i: int = 0
+
+        def recv(v: _TensorMeta) -> torch.Tensor:
+            nonlocal i
+
+            t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+            # TODO: parallelize receives
+            self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
+            i += 1
+
+            # TODO: allow in place receives to avoid having to copy to cpu to
+            # avoid OOMs
+            t = t.cpu()
+
+            return torch.as_strided(
+                t.view(v.dtype),
+                size=v.shape,
+                stride=v.stride,
+                storage_offset=v.storage_offset,
+            )
+
+        values = []
+        for v in meta.non_tensor_leaves:
+            if isinstance(v, _TensorMeta):
+                values.append(recv(v))
+            elif isinstance(v, _DTensorMeta):
+                tensor = recv(v.local)
+                # pyre-fixme[29]: DTensor is not a function
+                values.append(DTensor(tensor, v.spec, requires_grad=False))
+            else:
+                values.append(v)
+
+        return tree_unflatten(values, meta.treespec)
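The core trick above is _cast_tensor: it views a tensor's entire untyped storage as uint8, so values of any dtype, including strided and offset views, can be shipped as raw bytes over NCCL/Gloo; recv_checkpoint then reverses this with view plus as_strided using the shape, stride, and storage offset recorded in _TensorMeta. Below is a standalone round-trip sketch of that mechanism (not part of the commit; it assumes a PyTorch version where torch.tensor accepts an UntypedStorage, as _cast_tensor itself does):

# Standalone sketch of PGTransport's byte-view round trip.
import torch

src = torch.randn(4, 4, dtype=torch.bfloat16)[1:, 1:]  # strided, offset view

# Sender: view the whole underlying storage as raw uint8 bytes (cf. _cast_tensor).
storage = src.untyped_storage()
as_bytes = torch.tensor(storage, dtype=torch.uint8, device=src.device)
assert as_bytes.numel() == storage.nbytes()

# Metadata that travels alongside the bytes (cf. _TensorMeta).
shape, dtype = src.shape, src.dtype
stride, offset = src.stride(), src.storage_offset()

# Receiver: reinterpret the bytes, then re-apply shape/stride/offset (cf. recv()).
restored = torch.as_strided(
    as_bytes.view(dtype), size=shape, stride=stride, storage_offset=offset
)
torch.testing.assert_close(restored, src)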
torchft/checkpointing/pg_transport_test.py (+57, new file)

@@ -0,0 +1,57 @@
+from datetime import timedelta
+from typing import Dict
+from unittest import TestCase, skipUnless
+
+import torch
+from torch.distributed import TCPStore
+
+from torchft.checkpointing.pg_transport import PGTransport
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.checkpointing.transport_test import run_multi_recovery_test
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+
+
+class PGTransportTest(TestCase):
+    def test_pg_transport_gloo(self) -> None:
+        store: TCPStore = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device: torch.device = torch.device("cpu")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            pg = ProcessGroupGloo()
+            pg.configure(
+                store_addr=f"localhost:{store.port}/prefix",
+                rank=rank,
+                world_size=world_size,
+            )
+
+            return PGTransport[Dict[str, object]](
+                pg, timeout=timedelta(seconds=10), device=device
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
+    def test_pg_transport_baby_nccl(self) -> None:
+        store: TCPStore = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device: torch.device = torch.device("cuda")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            torch.cuda.set_device(rank)
+
+            pg = ProcessGroupBabyNCCL()
+            pg.configure(
+                store_addr=f"localhost:{store.port}/prefix",
+                rank=rank,
+                world_size=world_size,
+            )
+
+            return PGTransport[Dict[str, object]](
+                pg, timeout=timedelta(seconds=10), device=device
+            )
+
+        run_multi_recovery_test(self, init, device=device)
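For reference, here is a minimal sketch of how the two halves of PGTransport pair up across processes. In torchft the Manager drives these calls during recovery; the store address, ranks, and step below are purely illustrative:

# Hedged sketch: a healthy worker (rank 0) pushes weights to a recovering
# worker (rank 1). Each half runs in its own process; the TCPStore address
# "localhost:29500/prefix" and step=1234 are made-up values.
from datetime import timedelta

import torch

from torchft.checkpointing.pg_transport import PGTransport
from torchft.process_group import ProcessGroupGloo


def make_transport(rank: int, world_size: int) -> PGTransport:
    pg = ProcessGroupGloo()
    pg.configure(store_addr="localhost:29500/prefix", rank=rank, world_size=world_size)
    return PGTransport(pg, timeout=timedelta(seconds=10), device=torch.device("cpu"))


# Process A (healthy worker, rank 0):
#   transport = make_transport(rank=0, world_size=2)
#   transport.send_checkpoint(dst_ranks=[1], step=1234,
#                             state_dict=model.state_dict(),
#                             timeout=timedelta(seconds=10))

# Process B (recovering worker, rank 1):
#   transport = make_transport(rank=1, world_size=2)
#   state_dict = transport.recv_checkpoint(src_rank=0, metadata="<n/a>",
#                                          step=1234, timeout=timedelta(seconds=10))
#   model.load_state_dict(state_dict)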
