 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Generator, List, Tuple, TypeVar, Union, cast
+from typing import Callable, Generator, Optional, TypeVar, Union, cast

 import torch
 from torch.distributed import Work
 from torch.distributed.tensor import DTensor, _DTensorSpec
-from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
+from torch.utils._pytree import (
+    KeyPath,
+    TreeSpec,
+    tree_flatten_with_path,
+    tree_unflatten,
+)

 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.process_group import ProcessGroup
@@ -32,7 +37,7 @@ class _TensorMeta:
     shape: torch.Size
     dtype: torch.dtype
     storage_offset: int
-    stride: Tuple[int, ...]
+    stride: tuple[int, ...]
     nbytes: int


@@ -61,13 +66,15 @@ class _StateDictMeta:
     Args:
         step: the step of the checkpoint to verify consistency
         treespec: the pytree spec of the state dict
+        paths: the path of each leaf in the state dict
         non_tensor_leaves: the metadata for each tensor in the state dict and any
             non-tensor leaves in the state dict
     """

     step: int
     treespec: TreeSpec
-    non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]]
+    paths: list[KeyPath]
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]


 @contextmanager
@@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]:
     logger.info(f"{name} took {dur}s")


-def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
+def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
     return (
         _cast_tensor(tensor, torch.uint8),
         _TensorMeta(
@@ -95,12 +102,16 @@ def _prepare_state_dict(
     state_dict: object,
     step: int,
     device: torch.device,
-) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
-    leaves, treespec = tree_flatten(state_dict)
+) -> tuple[_StateDictMeta, list[torch.Tensor]]:
+    leaves: list[tuple[KeyPath, object]]
+    leaves, treespec = tree_flatten_with_path(state_dict)
+
+    paths: list[KeyPath] = []
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
+    tensors: list[torch.Tensor] = []
+    for key_path, v in leaves:
+        paths.append(key_path)

-    non_tensor_leaves = []
-    tensors = []
-    for v in leaves:
         if isinstance(v, DTensor):
             tensor, tensor_meta = _prepare_tensor(v._local_tensor)

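For context, a small standalone sketch (not part of this diff) of the pytree API the new code relies on: tree_flatten_with_path returns each leaf together with its KeyPath, and those paths are what _StateDictMeta.paths records so the receiver can match an incoming tensor to an existing destination tensor by path. The toy state dict below is purely illustrative.

import torch
from torch.utils._pytree import tree_flatten_with_path, tree_unflatten

# Toy state dict: every leaf comes back paired with its KeyPath.
state_dict = {"model": {"w": torch.zeros(2, 2)}, "step": 10}
leaves, treespec = tree_flatten_with_path(state_dict)

paths = [path for path, _ in leaves]      # what _StateDictMeta.paths stores
values = [value for _, value in leaves]   # leaves in the matching order

# Rebuilding only needs the values (in order) plus the treespec.
rebuilt = tree_unflatten(values, treespec)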
@@ -123,6 +134,7 @@ def _prepare_state_dict(
         _StateDictMeta(
             step=step,
             treespec=treespec,
+            paths=paths,
             non_tensor_leaves=non_tensor_leaves,
         ),
         tensors,
@@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     caveat that the cast tensor may be larger than the original tensor due to
     the differences in striding.
     """
+    assert (
+        type(tensor) is torch.Tensor
+    ), f"can only cast standard tensors not {type(tensor)}"
     storage = tensor.untyped_storage()
     ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
     assert ret.untyped_storage() is storage, "storage should be the same"
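A rough illustrative sketch (not part of the diff; the helper name cast_to_uint8 is hypothetical) of the storage-view trick _cast_tensor uses and the caveat above: the byte view covers the whole underlying storage, so for a strided or offset view it can be larger than the tensor itself. This is also why the new assert rejects tensor subclasses such as DTensor, which would otherwise hit this path with surprising storages.

import torch

def cast_to_uint8(tensor: torch.Tensor) -> torch.Tensor:
    # View the tensor's untyped storage as raw uint8 bytes without copying.
    storage = tensor.untyped_storage()
    ret = torch.tensor(storage, dtype=torch.uint8, device=tensor.device)
    assert ret.untyped_storage() is storage, "storage should be the same"
    return ret

base = torch.arange(16, dtype=torch.float32)
view = base[::4]                 # 4 elements, but shares the 64-byte storage
raw = cast_to_uint8(view)
print(view.nbytes, raw.nbytes)   # 16 vs 64: the cast spans the full storage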
@@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]):
     This is a checkpoint transport that uses the process group to transfer checkpoints.
     This allows for fast recovery of workers by fetching the current weights
     from an existing worker.
+
     Args:
-        state_dict: a callable that returns the state dict to be transferred
+        pg: the process group to use for communication
+        timeout: the timeout for communication
+        device: the device to use for tensors
+        state_dict: if specified this function will be called to do an inplace
+            receive into the returned state_dict. This is much faster than
+            having to allocate new tensors and transferring them to the CPU.
     """

     def __init__(
-        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
+        self,
+        pg: ProcessGroup,
+        timeout: timedelta,
+        device: torch.device,
+        state_dict: Optional[Callable[[], object]] = None,
     ) -> None:
-        self._work: List[Work] = []
+        self._work: list[Work] = []
         self._pg = pg
         self._timeout = timeout
         self._device = device
+        self._state_dict = state_dict

     def metadata(self) -> str:
         return "<n/a>"
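A hedged usage sketch (not from this diff; pg, model, and the timeout value are assumed to exist) showing how the new state_dict argument is wired in so recv_checkpoint can receive directly into the live tensors instead of staging through newly allocated CPU copies:

transport = PGTransport(
    pg=pg,                                  # assumed recovery ProcessGroup
    timeout=timedelta(seconds=30),          # assumed timeout
    device=torch.device("cuda"),
    state_dict=lambda: model.state_dict(),  # enables in-place receives
)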
@@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None:
         pass

     def send_checkpoint(
-        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+        self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
     ) -> None:
         with _timeit("preparing state_dict"):
             meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
@@ -186,20 +212,29 @@ def send_checkpoint(

         with _timeit("send tensors"):
             for i, t in enumerate(tensors):
+                original_device = t.device
                 t = t.to(self._device)
                 for dst_rank in dst_ranks:
                     work.append(self._pg.send([t], dst_rank, tag=3 + i))

-                # allow 3 concurrent transfers at a time to avoid OOMs
-                while len(work) > (3 * len(dst_ranks)):
-                    work.pop(0).wait(timeout)
+                # if we did a copy we should wait for the work to complete so we
+                # can free the memory to avoid OOMs
+                if original_device == torch.device("cpu"):
+                    for w in work:
+                        w.wait(timeout)
+                    work = []

             for w in work:
                 w.wait(timeout)

     def recv_checkpoint(
         self, src_rank: int, metadata: str, step: int, timeout: timedelta
     ) -> T:
+        state_dict = self._state_dict() if self._state_dict else {}
+        state_dict_leaves, _ = tree_flatten_with_path(state_dict)
+
+        dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)
+
         len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
         self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
         length = cast(int, len_t.item())
@@ -213,18 +248,34 @@ def recv_checkpoint(
         assert meta.step == step

         i: int = 0
+        works: list[Work] = []

-        def recv(v: _TensorMeta) -> torch.Tensor:
+        def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
             nonlocal i

-            t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
-            # TODO: parallelize receives
-            self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
+            inplace = dst_tensors.get(path)
+            if (
+                isinstance(inplace, torch.Tensor)
+                and inplace.device.type == self._device.type
+            ):
+                if isinstance(inplace, DTensor):
+                    inplace = inplace._local_tensor
+                t = _cast_tensor(inplace, torch.uint8)
+                assert (
+                    t.nbytes == v.nbytes
+                ), "inplace tensor storage must be the same size"
+            else:
+                t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+
+            work = self._pg.recv([t], src_rank, tag=3 + i)
             i += 1

-            # TODO: allow in place receives to avoid having to copy to cpu to
-            # avoid OOMs
-            t = t.cpu()
+            if inplace is None:
+                # if not inplace we need to copy it to CPU to avoid OOMing
+                work.wait(timeout)
+                t = t.cpu()
+            else:
+                works.append(work)

             return torch.as_strided(
                 t.view(v.dtype),
@@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor:
             )

         values = []
-        for v in meta.non_tensor_leaves:
+        for path, v in zip(meta.paths, meta.non_tensor_leaves):
             if isinstance(v, _TensorMeta):
-                values.append(recv(v))
+                values.append(recv(path, v))
             elif isinstance(v, _DTensorMeta):
-                tensor = recv(v.local)
+                tensor = recv(path, v.local)
                 # pyre-fixme[29]: DTensor is not a function
                 values.append(DTensor(tensor, v.spec, requires_grad=False))
             else:
                 values.append(v)

+        for work in works:
+            work.wait(timeout)
+
         return tree_unflatten(values, meta.treespec)