Skip to content
This repository was archived by the owner on Jan 6, 2023. It is now read-only.

Commit 25abb30

Browse files
Kiuk Chung authored and facebook-github-bot committed
imagenet example - address comments (use makedirs; use io.BytesIO with torch.save/torch.load instead of pickle)
Summary: Addressing comments from D20956704.
1. Use makedirs instead of mkdir.
2. Use io.BytesIO + torch.save/torch.load instead of pickle + bytearray.
3. Set the process group (pg) timeout to 10 seconds to prevent NCCL hang issues (the watchdog kicks in at 10-second intervals).
Reviewed By: tierex
Differential Revision: D20959778
fbshipit-source-id: 1808f7852d65a5404802415fdd5a3fa0f1aa2f37
1 parent 1465dfc commit 25abb30

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

examples/imagenet/main.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -44,13 +44,15 @@
4444
"""
4545

4646
import argparse
47+
import io
4748
import os
48-
import pickle
4949
import shutil
5050
import time
5151
from contextlib import contextmanager
52+
from datetime import timedelta
5253
from typing import List, Tuple
5354

55+
import numpy
5456
import torch
5557
import torch.backends.cudnn as cudnn
5658
import torch.distributed as dist
@@ -156,7 +158,9 @@ def main():
156158
# to enable a watchdog thread that will destroy stale NCCL communicators
157159
os.environ["NCCL_BLOCKING_WAIT"] = "1"
158160

159-
dist.init_process_group(backend=args.dist_backend, init_method="env://")
161+
dist.init_process_group(
162+
backend=args.dist_backend, init_method="env://", timeout=timedelta(seconds=10)
163+
)
160164

161165
model, criterion, optimizer = initialize_model(
162166
args.arch, args.lr, args.momentum, args.weight_decay, device_id
@@ -367,21 +371,26 @@ def load_checkpoint(
367371
# pickle the snapshot, convert it into a byte-blob tensor
368372
# then broadcast it, unpickle it and apply the snapshot
369373
print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")
370-
raw_blob = bytearray(pickle.dumps(state.capture_snapshot()))
374+
375+
with io.BytesIO() as f:
376+
torch.save(state.capture_snapshot(), f)
377+
raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)
378+
371379
blob_len = torch.tensor(len(raw_blob))
372380
dist.broadcast(blob_len, src=max_rank, group=pg)
373381
print(f"=> checkpoint broadcast size is: {blob_len}")
374382

375383
if rank != max_rank:
376384
blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
377385
else:
378-
blob = torch.tensor(raw_blob, dtype=torch.uint8)
386+
blob = torch.as_tensor(raw_blob, dtype=torch.uint8)
379387

380388
dist.broadcast(blob, src=max_rank, group=pg)
381389
print(f"=> done broadcasting checkpoint")
382390

383391
if rank != max_rank:
384-
snapshot = pickle.loads(blob.numpy())
392+
with io.BytesIO(blob.numpy()) as f:
393+
snapshot = torch.load(f)
385394
state.apply_snapshot(snapshot, device_id)
386395

387396
# wait till everyone has loaded the checkpoint
@@ -402,7 +411,7 @@ def tmp_process_group(backend):
402411

403412
def save_checkpoint(state: State, is_best: bool, filename: str):
404413
checkpoint_dir = os.path.dirname(filename)
405-
os.mkdir(checkpoint_dir)
414+
os.makedirs(checkpoint_dir, exist_ok=True)
406415

407416
# save to tmp, then commit by moving the file in case the job
408417
# gets interrupted while writing the checkpoint

0 commit comments

Comments
 (0)