4545
4646import argparse
4747import os
48+ import pickle
4849import shutil
4950import time
51+ from contextlib import contextmanager
5052from typing import List , Tuple
5153
5254import torch
145147
146148def main ():
147149 args = parser .parse_args ()
148-
149150 device_id = int (os .environ ["LOCAL_RANK" ])
150151 torch .cuda .set_device (device_id )
151152 print (f"=> set cuda device = { device_id } " )
@@ -161,18 +162,21 @@ def main():
161162 args .arch , args .lr , args .momentum , args .weight_decay , device_id
162163 )
163164
164- # resume from checkpoint if one exists
165- start_epoch , best_acc1 = load_checkpoint (
166- args .checkpoint_file , device_id , model , optimizer
167- )
168- print (f"=> start_epoch: { start_epoch } , best_acc1: { best_acc1 } " )
169-
170165 train_loader , val_loader = initialize_data_loader (
171166 args .data , args .batch_size , args .workers
172167 )
173168
169+ # resume from checkpoint if one exists;
170+ state = load_checkpoint (
171+ args .checkpoint_file , device_id , args .arch , model , optimizer
172+ )
173+
174+ start_epoch = state .epoch + 1
175+ print (f"=> start_epoch: { start_epoch } , best_acc1: { state .best_acc1 } " )
176+
174177 print_freq = args .print_freq
175178 for epoch in range (start_epoch , args .epochs ):
179+ state .epoch = epoch
176180 train_loader .batch_sampler .sampler .set_epoch (epoch )
177181 adjust_learning_rate (optimizer , epoch , args .lr )
178182
@@ -183,21 +187,64 @@ def main():
183187 acc1 = validate (val_loader , model , criterion , device_id , print_freq )
184188
185189 # remember best acc@1 and save checkpoint
186- is_best = acc1 > best_acc1
187- best_acc1 = max (acc1 , best_acc1 )
190+ is_best = acc1 > state . best_acc1
191+ state . best_acc1 = max (acc1 , state . best_acc1 )
188192
189193 if device_id == 0 :
190- save_checkpoint (
191- {
192- "epoch" : epoch + 1 ,
193- "best_acc1" : best_acc1 ,
194- "arch" : args .arch ,
195- "state_dict" : model .state_dict (),
196- "optimizer" : optimizer .state_dict (),
197- },
198- is_best ,
199- args .checkpoint_file ,
200- )
194+ save_checkpoint (state , is_best , args .checkpoint_file )
195+
196+
197+ class State :
198+ """
199+ Container for objects that we want to checkpoint. Represents the
200+ current "state" of the worker. This object is mutable.
201+ """
202+
203+ def __init__ (self , arch , model , optimizer ):
204+ self .epoch = - 1
205+ self .best_acc1 = 0
206+ self .arch = arch
207+ self .model = model
208+ self .optimizer = optimizer
209+
210+ def capture_snapshot (self ):
211+ """
212+ Essentially a ``serialize()`` function, returns the state as an
213+ object compatible with ``torch.save()``. The following should work
214+ ::
215+
216+ snapshot = state_0.capture_snapshot()
217+ state_1.apply_snapshot(snapshot)
218+ assert state_0 == state_1
219+ """
220+ return {
221+ "epoch" : self .epoch ,
222+ "best_acc1" : self .best_acc1 ,
223+ "arch" : self .arch ,
224+ "state_dict" : self .model .state_dict (),
225+ "optimizer" : self .optimizer .state_dict (),
226+ }
227+
228+ def apply_snapshot (self , obj , device_id ):
229+ """
230+ The complimentary function of ``capture_snapshot()``. Applies the
231+ snapshot object that was returned by ``capture_snapshot()``.
232+ This function mutates this state object.
233+ """
234+
235+ self .epoch = obj ["epoch" ]
236+ self .best_acc1 = obj ["best_acc1" ]
237+ self .state_dict = obj ["state_dict" ]
238+ self .model .load_state_dict (obj ["state_dict" ])
239+ self .optimizer .load_state_dict (obj ["optimizer" ])
240+
241+ def save (self , f ):
242+ torch .save (self .capture_snapshot (), f )
243+
244+ def load (self , f , device_id ):
245+ # Map model to be loaded to specified single gpu.
246+ snapshot = torch .load (f , map_location = f"cuda:{ device_id } " )
247+ self .apply_snapshot (snapshot , device_id )
201248
202249
203250def initialize_model (
@@ -273,21 +320,100 @@ def initialize_data_loader(
273320def load_checkpoint (
274321 checkpoint_file : str ,
275322 device_id : int ,
323+ arch : str ,
276324 model : DistributedDataParallel ,
277325 optimizer , # SGD
278- ) -> Tuple [int , float ]:
279- start_epoch = 0
280- best_acc1 = 0
326+ ) -> State :
327+ """
328+ Loads a local checkpoint (if any). Otherwise, checks to see if any of
329+ the neighbors have a non-zero state. If so, restore the state
330+ from the rank that has the most up-to-date checkpoint.
331+
332+ .. note:: when your job has access to a globally visible persistent storage
333+ (e.g. nfs mount, S3) you can simply have all workers load
334+ from the most recent checkpoint from such storage. Since this
335+ example is expected to run on vanilla hosts (with no shared
336+ storage) the checkpoints are written to local disk, hence
337+ we have the extra logic to broadcast the checkpoint from a
338+ surviving node.
339+ """
340+
341+ state = State (arch , model , optimizer )
342+
281343 if os .path .isfile (checkpoint_file ):
282- print (f"=> loading checkpoint: { checkpoint_file } " )
283- # Map model to be loaded to specified single gpu.
284- checkpoint = torch .load (checkpoint_file , map_location = f"cuda:{ device_id } " )
285- start_epoch = checkpoint ["epoch" ]
286- best_acc1 = checkpoint ["best_acc1" ]
287- model .load_state_dict (checkpoint ["state_dict" ])
288- optimizer .load_state_dict (checkpoint ["optimizer" ])
289- print (f"=> loaded checkpoint: { checkpoint_file } " )
290- return start_epoch , best_acc1
344+ print (f"=> loading checkpoint file: { checkpoint_file } " )
345+ state .load (checkpoint_file , device_id )
346+ print (f"=> loaded checkpoint file: { checkpoint_file } " )
347+
348+ # logic below is unnecessary when the checkpoint is visible on all nodes!
349+ # create a temporary cpu pg to broadcast most up-to-date checkpoint
350+ with tmp_process_group (backend = "gloo" ) as pg :
351+ rank = dist .get_rank (group = pg )
352+
353+ # get rank that has the largest state.epoch
354+ epochs = torch .zeros (dist .get_world_size (), dtype = torch .int32 )
355+ epochs [rank ] = state .epoch
356+ dist .all_reduce (epochs , op = dist .ReduceOp .SUM , group = pg )
357+ t_max_epoch , t_max_rank = torch .max (epochs , dim = 0 )
358+ max_epoch = t_max_epoch .item ()
359+ max_rank = t_max_rank .item ()
360+
361+ # max_epoch == -1 means no one has checkpointed return base state
362+ if max_epoch == - 1 :
363+ print (f"=> no workers have checkpoints, starting from epoch 0" )
364+ return state
365+
366+ # broadcast the state from max_rank (which has the most up-to-date state)
367+ # pickle the snapshot, convert it into a byte-blob tensor
368+ # then broadcast it, unpickle it and apply the snapshot
369+ print (f"=> using checkpoint from rank: { max_rank } , max_epoch: { max_epoch } " )
370+ raw_blob = bytearray (pickle .dumps (state .capture_snapshot ()))
371+ blob_len = torch .tensor (len (raw_blob ))
372+ dist .broadcast (blob_len , src = max_rank , group = pg )
373+ print (f"=> checkpoint broadcast size is: { blob_len } " )
374+
375+ if rank != max_rank :
376+ blob = torch .zeros (blob_len .item (), dtype = torch .uint8 )
377+ else :
378+ blob = torch .tensor (raw_blob , dtype = torch .uint8 )
379+
380+ dist .broadcast (blob , src = max_rank , group = pg )
381+ print (f"=> done broadcasting checkpoint" )
382+
383+ if rank != max_rank :
384+ snapshot = pickle .loads (blob .numpy ())
385+ state .apply_snapshot (snapshot , device_id )
386+
387+ # wait till everyone has loaded the checkpoint
388+ dist .barrier (group = pg )
389+
390+ print (f"=> done restoring from previous checkpoint" )
391+ return state
392+
393+
394+ @contextmanager
395+ def tmp_process_group (backend ):
396+ cpu_pg = dist .new_group (backend = backend )
397+ try :
398+ yield cpu_pg
399+ finally :
400+ dist .destroy_process_group (cpu_pg )
401+
402+
403+ def save_checkpoint (state : State , is_best : bool , filename : str ):
404+ checkpoint_dir = os .path .dirname (filename )
405+ os .mkdir (checkpoint_dir )
406+
407+ # save to tmp, then commit by moving the file in case the job
408+ # gets interrupted while writing the checkpoint
409+ tmp_filename = filename + ".tmp"
410+ torch .save (state .capture_snapshot (), tmp_filename )
411+ os .rename (tmp_filename , filename )
412+ print (f"=> saved checkpoint for epoch { state .epoch } at { filename } " )
413+ if is_best :
414+ best = os .path .join (checkpoint_dir , "model_best.pth.tar" )
415+ print (f"=> best model found at epoch { state .epoch } saving to { best } " )
416+ shutil .copyfile (filename , best )
291417
292418
293419def train (
@@ -394,16 +520,6 @@ def validate(
394520 return top1 .avg
395521
396522
397- def save_checkpoint (state , is_best : bool , filename : str ):
398- # save to tmp, then commit by moving the file in case the job
399- # gets interrupted while writing the checkpoint
400- tmp_filename = filename + ".tmp"
401- torch .save (state , tmp_filename )
402- os .rename (tmp_filename , filename )
403- if is_best :
404- shutil .copyfile (filename , "model_best.pth.tar" )
405-
406-
407523class AverageMeter (object ):
408524 """Computes and stores the average and current value"""
409525
0 commit comments