Commit 6fe4c8e

PGTransport: add inplace transport (#119)
1 parent 5e65330 commit 6fe4c8e

6 files changed (+253, -81 lines changed)


torchft/checkpointing/pg_transport.py

Lines changed: 81 additions & 27 deletions
@@ -4,12 +4,17 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Generator, List, Tuple, TypeVar, Union, cast
+from typing import Callable, Generator, Optional, TypeVar, Union, cast
 
 import torch
 from torch.distributed import Work
 from torch.distributed.tensor import DTensor, _DTensorSpec
-from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
+from torch.utils._pytree import (
+    KeyPath,
+    TreeSpec,
+    tree_flatten_with_path,
+    tree_unflatten,
+)
 
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.process_group import ProcessGroup
@@ -32,7 +37,7 @@ class _TensorMeta:
     shape: torch.Size
     dtype: torch.dtype
     storage_offset: int
-    stride: Tuple[int, ...]
+    stride: tuple[int, ...]
     nbytes: int
 
 
@@ -61,13 +66,15 @@ class _StateDictMeta:
     Args:
         step: the step of the checkpoint to verify consistency
         treespec: the pytree spec of the state dict
+        paths: the path of each leaf in the state dict
         non_tensor_leaves: the metadata for each tensor in the state dict and any
             non-tensor leaves in the state dict
     """
 
     step: int
     treespec: TreeSpec
-    non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
+    paths: list[KeyPath]
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]
 
 
 @contextmanager
@@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]:
     logger.info(f"{name} took {dur}s")
 
 
-def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
+def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
     return (
         _cast_tensor(tensor, torch.uint8),
         _TensorMeta(
@@ -95,12 +102,16 @@ def _prepare_state_dict(
     state_dict: object,
     step: int,
     device: torch.device,
-) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
-    leaves, treespec = tree_flatten(state_dict)
+) -> tuple[_StateDictMeta, list[torch.Tensor]]:
+    leaves: list[tuple[KeyPath, object]]
+    leaves, treespec = tree_flatten_with_path(state_dict)
+
+    paths: list[KeyPath] = []
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
+    tensors: list[torch.Tensor] = []
+    for key_path, v in leaves:
+        paths.append(key_path)
 
-    non_tensor_leaves = []
-    tensors = []
-    for v in leaves:
         if isinstance(v, DTensor):
             tensor, tensor_meta = _prepare_tensor(v._local_tensor)
 
@@ -123,6 +134,7 @@ def _prepare_state_dict(
         _StateDictMeta(
             step=step,
             treespec=treespec,
+            paths=paths,
            non_tensor_leaves=non_tensor_leaves,
         ),
         tensors,
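
Note (not part of the diff): the switch from tree_flatten to tree_flatten_with_path is what makes the inplace receive addressable. Every leaf now carries a KeyPath that is identical on sender and receiver whenever both flatten the same state dict structure, so the receiver can look up a matching destination tensor per leaf. A minimal standalone sketch of the flattened form:

import torch
from torch.utils._pytree import tree_flatten_with_path, tree_unflatten

# A toy nested state dict, standing in for a model/optimizer state dict.
state_dict = {"layer1": {"weight": torch.zeros(2, 2)}, "step": 3}

# Each leaf comes back as (KeyPath, value). The KeyPath is a tuple of key entries
# describing where the leaf sits in the pytree, and it is usable as a dict key.
leaves, treespec = tree_flatten_with_path(state_dict)
for key_path, value in leaves:
    print(key_path, type(value).__name__)

# The sender packs values in this order; the receiver either rebuilds the same
# structure with tree_unflatten or looks tensors up by KeyPath for the inplace path.
restored = tree_unflatten([v for _, v in leaves], treespec)
assert restored["step"] == 3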
@@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     caveat that the cast tensor may be larger than the original tensor due to
     the differences in striding.
     """
+    assert (
+        type(tensor) is torch.Tensor
+    ), f"can only cast standard tensors not {type(tensor)}"
     storage = tensor.untyped_storage()
     ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
     assert ret.untyped_storage() is storage, "storage should be the same"
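
Note (not part of the diff): _cast_tensor reinterprets a tensor's untyped storage as uint8 without copying, and the new assert rejects tensor subclasses such as DTensor (callers unwrap _local_tensor first). A small standalone sketch of the storage-sharing property the inplace receive relies on, using the contiguous-tensor view() shortcut instead of untyped_storage():

import torch

# Destination tensor: 4 float32 values, i.e. 16 bytes of storage.
dst = torch.zeros(4, dtype=torch.float32)

# Reinterpret the same memory as raw bytes; dst and raw share storage, no copy.
raw = dst.view(torch.uint8)
assert raw.nbytes == dst.nbytes == 16

# Writing into the byte view (a copy_ here, a pg.recv in the transport) lands the
# data directly in dst's storage, so no staging buffer or CPU copy is needed.
raw.copy_(torch.ones(4, dtype=torch.float32).view(torch.uint8))
print(dst)  # tensor([1., 1., 1., 1.])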
@@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]):
     This is a checkpoint transport that uses the process group to transfer checkpoints.
     This allows for fast recovery of workers by fetching the current weights
     from an existing worker.
+
     Args:
-        state_dict: a callable that returns the state dict to be transferred
+        pg: the process group to use for communication
+        timeout: the timeout for communication
+        device: the device to use for tensors
+        state_dict: if specified this function will be called to do an inplace
+            receive into the returned state_dict. This is much faster than
+            having to allocate new tensors and transferring them to the CPU.
     """
 
     def __init__(
-        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
+        self,
+        pg: ProcessGroup,
+        timeout: timedelta,
+        device: torch.device,
+        state_dict: Optional[Callable[[], object]] = None,
     ) -> None:
-        self._work: List[Work] = []
+        self._work: list[Work] = []
         self._pg = pg
         self._timeout = timeout
         self._device = device
+        self._state_dict = state_dict
 
     def metadata(self) -> str:
         return "<n/a>"
@@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None:
         pass
 
     def send_checkpoint(
-        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+        self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
     ) -> None:
         with _timeit("preparing state_dict"):
             meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
@@ -186,20 +212,29 @@ def send_checkpoint(
 
         with _timeit("send tensors"):
             for i, t in enumerate(tensors):
+                original_device = t.device
                 t = t.to(self._device)
                 for dst_rank in dst_ranks:
                     work.append(self._pg.send([t], dst_rank, tag=3 + i))
 
-                # allow 3 concurrent transfers at a time to avoid OOMs
-                while len(work) > (3 * len(dst_ranks)):
-                    work.pop(0).wait(timeout)
+                # if we did a copy we should wait for the work to complete so we
+                # can free the memory to avoid OOMs
+                if original_device == torch.device("cpu"):
+                    for w in work:
+                        w.wait(timeout)
+                    work = []
 
             for w in work:
                 w.wait(timeout)
 
     def recv_checkpoint(
         self, src_rank: int, metadata: str, step: int, timeout: timedelta
     ) -> T:
+        state_dict = self._state_dict() if self._state_dict else {}
+        state_dict_leaves, _ = tree_flatten_with_path(state_dict)
+
+        dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)
+
         len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
         self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
         length = cast(int, len_t.item())
@@ -213,18 +248,34 @@
         assert meta.step == step
 
         i: int = 0
+        works: list[Work] = []
 
-        def recv(v: _TensorMeta) -> torch.Tensor:
+        def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
             nonlocal i
 
-            t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
-            # TODO: parallelize receives
-            self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
+            inplace = dst_tensors.get(path)
+            if (
+                isinstance(inplace, torch.Tensor)
+                and inplace.device.type == self._device.type
+            ):
+                if isinstance(inplace, DTensor):
+                    inplace = inplace._local_tensor
+                t = _cast_tensor(inplace, torch.uint8)
+                assert (
+                    t.nbytes == v.nbytes
+                ), "inplace tensor storage must be the same size"
+            else:
+                t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+
+            work = self._pg.recv([t], src_rank, tag=3 + i)
             i += 1
 
-            # TODO: allow in place receives to avoid having to copy to cpu to
-            # avoid OOMs
-            t = t.cpu()
+            if inplace is None:
+                # if not inplace we need to copy it to CPU to avoid OOMing
+                work.wait(timeout)
+                t = t.cpu()
+            else:
+                works.append(work)
 
             return torch.as_strided(
                 t.view(v.dtype),
@@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor:
             )
 
         values = []
-        for v in meta.non_tensor_leaves:
+        for path, v in zip(meta.paths, meta.non_tensor_leaves):
            if isinstance(v, _TensorMeta):
-                values.append(recv(v))
+                values.append(recv(path, v))
            elif isinstance(v, _DTensorMeta):
-                tensor = recv(v.local)
+                tensor = recv(path, v.local)
                # pyre-fixme[29]: DTensor is not a function
                values.append(DTensor(tensor, v.spec, requires_grad=False))
            else:
                values.append(v)
 
+        for work in works:
+            work.wait(timeout)
+
         return tree_unflatten(values, meta.treespec)
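
Note (not part of the diff): on the receive side each payload is a flat run of uint8 bytes, and torch.as_strided over the payload viewed as the original dtype, together with the shape, stride, and storage_offset recorded in _TensorMeta, recovers the original tensor view. A standalone sketch of that round trip, with a byte view standing in for the received buffer (in the inplace path that buffer is a view of the destination tensor's storage):

import torch

# Sender side: `original` is a transposed view into `base`'s storage, so its
# shape/stride/storage_offset must travel alongside the raw bytes.
base = torch.arange(6, dtype=torch.float32)
original = base.reshape(2, 3).t()
shape, stride, offset = original.shape, original.stride(), original.storage_offset()

# Stand-in for the wire payload: the underlying storage viewed as raw bytes
# (the real code builds this with _cast_tensor and moves it via pg.send/recv).
payload = base.view(torch.uint8)

# Receiver side: reinterpret the bytes as float32, then rebuild the view from
# the recorded metadata.
rebuilt = torch.as_strided(payload.view(torch.float32), shape, stride, offset)
assert torch.equal(rebuilt, original)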
Lines changed: 93 additions & 0 deletions (new file)
@@ -0,0 +1,93 @@
import logging
import sys
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

import torch
import torch.distributed as dist

from torchft.checkpointing.pg_transport import PGTransport, _timeit
from torchft.process_group import ProcessGroupBabyNCCL

logger: logging.Logger = logging.getLogger(__name__)


def main(argv: list[str]) -> None:
    import argparse

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--inplace", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
    args = parser.parse_args(argv)

    CHUNK_SIZE: int = args.chunk_size
    TOTAL_SIZE: int = args.total_size
    INPLACE: bool = args.inplace
    DEVICE: str = args.device

    timeout: timedelta = timedelta(seconds=10)

    store = dist.TCPStore(
        "localhost",
        0,
        is_master=True,
        timeout=timeout,
        wait_for_workers=False,
    )
    store_addr: str = f"localhost:{store.port}"

    def run(rank: int) -> None:
        torch.cuda.set_device(rank)

        device = torch.device(DEVICE)

        with _timeit("init_pg"):
            pg = ProcessGroupBabyNCCL(timeout=timeout)
            pg.configure(store_addr=store_addr, rank=rank, world_size=2)

            t = torch.zeros(10, device=device, dtype=torch.float32)
            pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout)

        with _timeit("create state_dict"):
            state_dict: dict[str, torch.Tensor] = {}
            for i in range(0, TOTAL_SIZE, CHUNK_SIZE):
                state_dict[f"chunk/{i}"] = torch.zeros(
                    CHUNK_SIZE // 4, dtype=torch.float32, device=device
                )

        def get_state_dict() -> object:
            return state_dict

        transport = PGTransport(
            pg=pg,
            timeout=timeout,
            device=device,
            state_dict=get_state_dict if INPLACE else None,
        )
        metadata = transport.metadata()

        if rank == 0:
            with _timeit("send_checkpoint"):
                transport.send_checkpoint(
                    dst_ranks=[1],
                    step=1,
                    state_dict=state_dict,
                    timeout=timedelta(seconds=60),
                )
        elif rank == 1:
            with _timeit("recv_checkpoint"):
                transport.recv_checkpoint(
                    src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60)
                )

    with ThreadPoolExecutor(max_workers=2) as executor:
        results = executor.map(run, range(2))
        list(results)


if __name__ == "__main__":
    main(sys.argv[1:])
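
Note (not part of the diff): main takes argv directly, so the benchmark can be driven from the command line or programmatically. As written it appears to need two CUDA devices, since each rank calls torch.cuda.set_device(rank) and uses a NCCL process group. A hypothetical invocation:

# Equivalent to passing `--device cuda --inplace` etc. on the command line.
main(["--device", "cuda", "--inplace"])  # inplace receive into existing tensors
main(["--device", "cuda"])               # baseline: allocate on device, copy to CPU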

0 commit comments
