Skip to content
This repository was archived by the owner on Jan 6, 2023. It is now read-only.

Commit 1465dfc

Browse files
Kiuk Chung authored and facebook-github-bot committed
imagenet example - add logic to broadcast most recent checkpoint from max_rank (#93)
Summary: Pull Request resolved: #93

Rationale for adding checkpoint broadcasting:
- In our example we don't have access to globally visible storage.
- Each worker with local rank 0 writes the checkpoint.
- When a container/node dies, the replacement container has no checkpoints (since they were lost with the node).
- New nodes therefore start from scratch while surviving nodes are ahead.
- The logic is to find the checkpoint with the max epoch and broadcast that.

Rationale for removing the nnode==1 assertion for the launcher with the --with_etcd option:
- You can run two agents on the same node to simulate a multi-node run.
- You can first start agent #1 by giving it the `--with_etcd` option.
- You can then start agent #2 by copy-pasting the rdzv info (from the logs) and passing `--rdzv_id`, `--rdzv_backend`, `--rdzv_endpoint` from the first launch.

Reviewed By: tierex, drdarshan
Differential Revision: D20956704
fbshipit-source-id: 3170e1bcbedf1a7522f3aeee23f0fc67cd038253
1 parent 08a8187 commit 1465dfc

2 files changed

Lines changed: 166 additions & 46 deletions

File tree

examples/imagenet/main.py

Lines changed: 158 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545

4646
import argparse
4747
import os
48+
import pickle
4849
import shutil
4950
import time
51+
from contextlib import contextmanager
5052
from typing import List, Tuple
5153

5254
import torch
@@ -145,7 +147,6 @@
145147

146148
def main():
147149
args = parser.parse_args()
148-
149150
device_id = int(os.environ["LOCAL_RANK"])
150151
torch.cuda.set_device(device_id)
151152
print(f"=> set cuda device = {device_id}")
@@ -161,18 +162,21 @@ def main():
161162
args.arch, args.lr, args.momentum, args.weight_decay, device_id
162163
)
163164

164-
# resume from checkpoint if one exists
165-
start_epoch, best_acc1 = load_checkpoint(
166-
args.checkpoint_file, device_id, model, optimizer
167-
)
168-
print(f"=> start_epoch: {start_epoch}, best_acc1: {best_acc1}")
169-
170165
train_loader, val_loader = initialize_data_loader(
171166
args.data, args.batch_size, args.workers
172167
)
173168

169+
# resume from checkpoint if one exists;
170+
state = load_checkpoint(
171+
args.checkpoint_file, device_id, args.arch, model, optimizer
172+
)
173+
174+
start_epoch = state.epoch + 1
175+
print(f"=> start_epoch: {start_epoch}, best_acc1: {state.best_acc1}")
176+
174177
print_freq = args.print_freq
175178
for epoch in range(start_epoch, args.epochs):
179+
state.epoch = epoch
176180
train_loader.batch_sampler.sampler.set_epoch(epoch)
177181
adjust_learning_rate(optimizer, epoch, args.lr)
178182

@@ -183,21 +187,64 @@ def main():
183187
acc1 = validate(val_loader, model, criterion, device_id, print_freq)
184188

185189
# remember best acc@1 and save checkpoint
186-
is_best = acc1 > best_acc1
187-
best_acc1 = max(acc1, best_acc1)
190+
is_best = acc1 > state.best_acc1
191+
state.best_acc1 = max(acc1, state.best_acc1)
188192

189193
if device_id == 0:
190-
save_checkpoint(
191-
{
192-
"epoch": epoch + 1,
193-
"best_acc1": best_acc1,
194-
"arch": args.arch,
195-
"state_dict": model.state_dict(),
196-
"optimizer": optimizer.state_dict(),
197-
},
198-
is_best,
199-
args.checkpoint_file,
200-
)
194+
save_checkpoint(state, is_best, args.checkpoint_file)
195+
196+
197+
class State:
    """
    Container for objects that we want to checkpoint. Represents the
    current "state" of the worker. This object is mutable.

    ``epoch`` starts at ``-1`` and ``best_acc1`` at ``0`` until a snapshot
    is applied or the caller updates them. ``model`` and ``optimizer`` are
    held by reference; only their ``state_dict()`` is ever serialized.
    """

    def __init__(self, arch, model, optimizer):
        self.epoch = -1
        self.best_acc1 = 0
        self.arch = arch
        self.model = model
        self.optimizer = optimizer

    def capture_snapshot(self):
        """
        Essentially a ``serialize()`` function, returns the state as an
        object compatible with ``torch.save()``. The following should work
        ::

         snapshot = state_0.capture_snapshot()
         state_1.apply_snapshot(snapshot, device_id)
         # state_1 now carries state_0's epoch, best_acc1, model weights
         # and optimizer state
        """
        return {
            "epoch": self.epoch,
            "best_acc1": self.best_acc1,
            "arch": self.arch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def apply_snapshot(self, obj, device_id):
        """
        The complementary function of ``capture_snapshot()``. Applies the
        snapshot object that was returned by ``capture_snapshot()``.
        This function mutates this state object.

        ``device_id`` is accepted for interface compatibility but is not
        used here. ``arch`` is not restored from the snapshot.
        """
        self.epoch = obj["epoch"]
        self.best_acc1 = obj["best_acc1"]
        # Do not keep a reference to obj["state_dict"] on self: the values
        # are copied into the model below, and holding on to the snapshot
        # would keep a second copy of the weights alive for as long as
        # this State exists.
        self.model.load_state_dict(obj["state_dict"])
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        """Writes ``capture_snapshot()`` to ``f`` using ``torch.save()``."""
        torch.save(self.capture_snapshot(), f)

    def load(self, f, device_id):
        """Reads a snapshot from ``f`` and applies it to this state."""
        # Map model to be loaded to specified single gpu.
        snapshot = torch.load(f, map_location=f"cuda:{device_id}")
        self.apply_snapshot(snapshot, device_id)
201248

202249

203250
def initialize_model(
@@ -273,21 +320,100 @@ def initialize_data_loader(
273320
def load_checkpoint(
    checkpoint_file: str,
    device_id: int,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restore the state
    from the rank that has the most up-to-date checkpoint.

    Returns the ``State`` wrapping ``model`` and ``optimizer``; its ``epoch``
    is ``-1`` when no worker had a checkpoint.

    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.
    """

    state = State(arch, model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file, device_id)
        print(f"=> loaded checkpoint file: {checkpoint_file}")

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch; every rank fills in
        # only its own slot, so a SUM all-reduce yields the full vector of
        # per-rank epochs on every rank
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no one has checkpointed return base state
        if max_epoch == -1:
            print(f"=> no workers have checkpoints, starting from epoch 0")
            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor
        # then broadcast it, unpickle it and apply the snapshot
        print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")
        is_source = rank == max_rank

        # only the source rank's bytes are ever sent, so only that rank
        # pays for serializing its snapshot; receivers just need a
        # placeholder of matching shape/dtype for the length broadcast
        if is_source:
            raw_blob = bytearray(pickle.dumps(state.capture_snapshot()))
            blob_len = torch.tensor(len(raw_blob))
        else:
            blob_len = torch.tensor(0)
        dist.broadcast(blob_len, src=max_rank, group=pg)
        print(f"=> checkpoint broadcast size is: {blob_len}")

        if is_source:
            blob = torch.tensor(raw_blob, dtype=torch.uint8)
        else:
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        print(f"=> done broadcasting checkpoint")

        if not is_source:
            # NOTE: pickle.loads executes whatever the payload describes;
            # this is acceptable only because the bytes come from a peer
            # worker of the same job, never from an untrusted source.
            snapshot = pickle.loads(blob.numpy())
            state.apply_snapshot(snapshot, device_id)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    print(f"=> done restoring from previous checkpoint")
    return state
392+
393+
394+
@contextmanager
def tmp_process_group(backend):
    """
    Context manager that creates a new process group on the given
    ``backend``, yields it, and destroys it again on exit (also when the
    body raises).
    """
    group = dist.new_group(backend=backend)
    try:
        yield group
    finally:
        dist.destroy_process_group(group)
401+
402+
403+
def save_checkpoint(state: State, is_best: bool, filename: str):
    """
    Writes a snapshot of ``state`` to ``filename``.

    When ``is_best`` is true the written checkpoint is additionally copied
    to ``model_best.pth.tar`` in the same directory as ``filename``.
    """
    checkpoint_dir = os.path.dirname(filename)
    # This runs once per epoch, so the directory usually exists already;
    # a plain mkdir would raise FileExistsError on the second save.
    # dirname is "" for a bare filename, which makedirs would reject.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
    tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), tmp_filename)
    # os.replace overwrites an existing destination on every platform
    os.replace(tmp_filename, filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    if is_best:
        best = os.path.join(checkpoint_dir, "model_best.pth.tar")
        print(f"=> best model found at epoch {state.epoch} saving to {best}")
        shutil.copyfile(filename, best)
291417

292418

293419
def train(
@@ -394,16 +520,6 @@ def validate(
394520
return top1.avg
395521

396522

397-
def save_checkpoint(state, is_best: bool, filename: str):
398-
# save to tmp, then commit by moving the file in case the job
399-
# gets interrupted while writing the checkpoint
400-
tmp_filename = filename + ".tmp"
401-
torch.save(state, tmp_filename)
402-
os.rename(tmp_filename, filename)
403-
if is_best:
404-
shutil.copyfile(filename, "model_best.pth.tar")
405-
406-
407523
class AverageMeter(object):
408524
"""Computes and stores the average and current value"""
409525

torchelastic/distributed/launch.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,15 +431,19 @@ def main(args=None):
431431
assert args.max_restarts > 0
432432

433433
if args.with_etcd:
434-
assert (
435-
min_nodes == max_nodes == 1
436-
), "--with_etcd can only be used with --nodes=1"
437-
438434
etcd_server = EtcdServer()
439435
etcd_server.start()
440436
args.rdzv_backend = "etcd"
441437
args.rdzv_endpoint = etcd_server.get_endpoint()
442438
args.rdzv_id = str(uuid.uuid4())
439+
log.info(
440+
f"\n**************************************\n"
441+
f"Rendezvous info:\n"
442+
f"--rdzv_backend={args.rdzv_backend} "
443+
f"--rdzv_endpoint={args.rdzv_endpoint} "
444+
f"--rdzv_id={args.rdzv_id}\n"
445+
f"**************************************\n"
446+
)
443447

444448
rdzv_parameters = parameters.RendezvousParameters(
445449
args.rdzv_backend,

0 commit comments

Comments
 (0)