@@ -271,38 +271,49 @@ class CheckpointSaver(Callback):
 
     Args:
         save_dir (str): path to folder where to save the model
-        save_name (str): name of the saved model. can additionally
+        save_name (str): name of the saved model. Can additionally
             add epoch and metric to model save name
+        monitor (str): quantity to monitor. Implicitly prefers validation metrics over train.
+            One of `loss` or the name of any metric passed to the runner.
         mode (str): one of "min" or "max". Whether to save based on
             minimizing or maximizing the monitored value
-        include_optimizer (bool): if True would also save `optimizers` state_dict.
+        include_optimizer (bool): if True, also saves the optimizer state_dict.
             This roughly doubles the checkpoint size.
+        verbose (bool): if True, reports each time a new best is found
     """
 
     def __init__(
-        self, save_dir, save_name="model_{ep}_{metric:.2f}.chpn", mode="min", include_optimizer=False
+        self,
+        save_dir,
+        save_name="model_{ep}_{metric:.2f}.chpn",
+        monitor="loss",
+        mode="min",
+        include_optimizer=False,
+        verbose=True,
     ):
         super().__init__()
         self.save_dir = save_dir
         self.save_name = save_name
-        self.mode = ReduceMode(mode)
-        self.best = float("inf") if self.mode == ReduceMode.MIN else -float("inf")
+        self.monitor = monitor
+        mode = ReduceMode(mode)
+        if mode == ReduceMode.MIN:
+            self.best = np.inf
+            self.monitor_op = np.less
+        elif mode == ReduceMode.MAX:
+            self.best = -np.inf
+            self.monitor_op = np.greater
         self.include_optimizer = include_optimizer
+        self.verbose = verbose
 
     def on_begin(self):
         os.makedirs(self.save_dir, exist_ok=True)
 
     def on_epoch_end(self):
-        # TODO zakirov(1.11.19) Add support for saving based on metric
-        if self.state.val_loss is not None:
-            current = self.state.val_loss.avg
-        else:
-            current = self.state.train_loss.avg
-        if (self.mode == ReduceMode.MIN and current < self.best) or (
-            self.mode == ReduceMode.MAX and current > self.best
-        ):
-            ep = self.state.epoch
-            # print(f"Epoch {ep}: best loss improved from {self.best:.4f} to {current:.4f}")
+        current = self.get_monitor_value()
+        if self.monitor_op(current, self.best):
+            ep = self.state.epoch_log
+            if self.verbose:
+                print(f"Epoch {ep:2d}: best {self.monitor} improved from {self.best:.4f} to {current:.4f}")
             self.best = current
             save_name = os.path.join(self.save_dir, self.save_name.format(ep=ep, metric=current))
             self._save_checkpoint(save_name)
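With `monitor` in place the saver can track any runner metric, not just loss. A minimal usage sketch; the metric name `"accuracy"` and the exact runner wiring are assumptions, not part of this diff:

```python
# Hypothetical usage: keep the checkpoint with the best validation accuracy.
# "accuracy" must match the name of a metric meter registered with the runner.
saver = CheckpointSaver(
    save_dir="weights/",
    save_name="model_{ep}_{metric:.2f}.chpn",
    monitor="accuracy",
    mode="max",  # best starts at -np.inf and np.greater is used for comparison
)
```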
@@ -317,6 +328,18 @@ def _save_checkpoint(self, path):
             save_dict["optimizer"] = self.state.optimizer.state_dict()
         torch.save(save_dict, path)
 
+    def get_monitor_value(self):
+        value = None
+        if self.monitor == "loss":
+            value = self.state.loss_meter.avg
+        else:
+            for metric_meter in self.state.metric_meters:
+                if metric_meter.name == self.monitor:
+                    value = metric_meter.avg
+        if value is None:
+            raise ValueError(f"CheckpointSaver can't find {self.monitor} value to monitor")
+        return value
+
 
 class TensorBoard(Callback):
     """
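The lookup above can be exercised in isolation. A toy sketch, where `Meter` is a stand-in for the repo's metric-meter objects (assumed to expose `.name` and `.avg`):

```python
from collections import namedtuple

# Stand-in for the repo's metric meters (an assumption, not the real class).
Meter = namedtuple("Meter", ["name", "avg"])
meters = [Meter("accuracy", 0.91), Meter("f1", 0.88)]

def find_monitor_value(monitor, meters, loss_avg=0.35):
    if monitor == "loss":
        return loss_avg
    for m in meters:
        if m.name == monitor:
            return m.avg
    raise ValueError(f"can't find {monitor} value to monitor")

assert find_monitor_value("f1", meters) == 0.88
```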
@@ -407,7 +430,7 @@ def on_batch_end(self):
 
     def on_loader_end(self):
         super().on_loader_end()
-        f = plot_confusion_matrix(self.cmap, self.class_names, show=False)
+        f = plot_confusion_matrix(self.cmap, self.class_names, normalize=True, show=False)
         cm_img = render_figure_to_tensor(f)
         if self.state.is_train:
             self.train_cm_img = cm_img
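`normalize=True` presumably row-normalizes the matrix before plotting, so each cell reads as a recall-like fraction instead of a raw count. A sketch of that normalization (the actual `plot_confusion_matrix` lives elsewhere in the repo):

```python
import numpy as np

# Row-normalize a confusion matrix: cell (i, j) becomes the fraction of
# true-class-i samples that were predicted as class j.
cm = np.array([[50, 10], [5, 35]], dtype=np.float64)
cm_normalized = cm / cm.sum(axis=1, keepdims=True)
# -> [[0.833, 0.167],
#     [0.125, 0.875]]
```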
@@ -527,10 +550,11 @@ def mixup(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
         prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
+        perm = torch.randperm(data.size(0)).cuda()
         c = self.tb.sample()
-        md = c * data + (1 - c) * prev_data
-        mt = c * target_one_hot + (1 - c) * prev_target
+        md = c * data + (1 - c) * prev_data[perm]
+        mt = c * target_one_hot + (1 - c) * prev_target[perm]
         return md, mt
 
 
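The permutation is the substance of this change: without it, every sample was always mixed with the sample at the same index of the cached batch. A self-contained sketch of the mixing step (the beta sampling and runner state are simplified away; `device=data.device` stands in for the hard-coded `.cuda()` to keep the sketch portable):

```python
import torch

def mixup_batch(data, target_one_hot, prev_data, prev_target, c):
    # Shuffle the cached batch so pairings vary from sample to sample.
    perm = torch.randperm(data.size(0), device=data.device)
    mixed_data = c * data + (1 - c) * prev_data[perm]
    mixed_target = c * target_one_hot + (1 - c) * prev_target[perm]
    return mixed_data, mixed_target
```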
@@ -570,16 +594,17 @@ def cutmix(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
         prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
         # prev_data shape can differ from the current one, so take the min
         H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
         # real lambda may differ from the sampled one, adjust for it
         lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        mixed_target = (1 - lam) * target_one_hot + lam * prev_target
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        mixed_target = (1 - lam) * target_one_hot + lam * prev_target[perm]
         return data, mixed_target
 
     @staticmethod
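`rand_bbox` itself is not shown in this hunk. A sketch of what it is assumed to do, inferred from the call site: cut a random patch whose area is roughly `lam * H * W`. Clipping at the image borders can shrink the real area, which is exactly why `lam` is recomputed from the final box above.

```python
import numpy as np

def rand_bbox(H, W, lam):
    # Side lengths chosen so the un-clipped patch area is ~ lam * H * W.
    cut_h, cut_w = int(H * np.sqrt(lam)), int(W * np.sqrt(lam))
    # Random center; boxes near the border get clipped and lose area.
    cy, cx = np.random.randint(H), np.random.randint(W)
    bbh1, bbh2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    bbw1, bbw2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    return bbh1, bbw1, bbh2, bbw2
```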
@@ -609,11 +634,32 @@ def cutmix(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target
         prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target
+        self.prev_input = data.clone(), target.clone()
         H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[perm, :, bbh1:bbh2, bbw1:bbw2]
         return data, target
+
+
+class ScheduledDropout(Callback):
+    def __init__(self, drop_rate=0.1, epochs=30, attr_name="dropout.p"):
+        """
+        Gradually increases the dropout rate addressed by `attr_name` every epoch.
+        Ref: https://arxiv.org/abs/1703.06229
+        Args:
+            drop_rate (float): max dropout rate
+            epochs (int): number of epochs it takes for the dropout rate to reach `drop_rate`
+            attr_name (str): dotted path to the dropout probability inside the model
+        """
+        super().__init__()
+        self.drop_rate = drop_rate
+        self.epochs = epochs
+        self.attr_name = attr_name
+
+    def on_epoch_end(self):
+        current_rate = self.drop_rate * min(1, self.state.epoch / self.epochs)
+        # attr_name may be a dotted path (e.g. "dropout.p") which plain setattr
+        # does not resolve, so walk to the parent object first
+        *path, attr = self.attr_name.split(".")
+        obj = self.state.model
+        for part in path:
+            obj = getattr(obj, part)
+        setattr(obj, attr, current_rate)
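Hypothetical usage of the new callback; the model layout is an assumption (a top-level `nn.Dropout` module named `dropout`, so `"dropout.p"` addresses its rate):

```python
# Dropout ramps linearly from 0 to 0.2 over the first 20 epochs, then stays flat.
sched_dropout = ScheduledDropout(drop_rate=0.2, epochs=20, attr_name="dropout.p")
```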