@@ -129,6 +129,11 @@ def main(args, resume_preempt=False):
129129
130130 # -- OPTIMIZATION
131131 cfgs_opt = args .get ("optimization" )
132+ is_anneal = cfgs_opt .get ("is_anneal" , False )
133+ anneal_ckpt = cfgs_opt .get ("anneal_ckpt" , None )
134+ if is_anneal and anneal_ckpt is None :
135+ raise ValueError ("Must specify anneal_ckpt if is_anneal is True" )
136+ resume_anneal = cfgs_opt .get ("resume_anneal" , False ) or (is_anneal and resume_preempt )
132137 ipe = cfgs_opt .get ("ipe" , None )
133138 ipe_scale = cfgs_opt .get ("ipe_scale" , 1.0 )
134139 wd = float (cfgs_opt .get ("weight_decay" ))
@@ -169,7 +174,14 @@ def main(args, resume_preempt=False):
169174 latest_path = os .path .join (folder , latest_file )
170175 load_path = None
171176 if load_model :
172- load_path = os .path .join (folder , r_file ) if r_file is not None else latest_path
177+ if is_anneal :
178+ if os .path .exists (latest_path ) and resume_anneal :
179+ load_path = latest_path
180+ else :
181+ load_path = anneal_ckpt
182+ resume_anneal = False
183+ else :
184+ load_path = r_file if r_file is not None else latest_path
173185 if not os .path .exists (load_path ):
174186 load_path = None
175187 load_model = False
@@ -261,6 +273,7 @@ def main(args, resume_preempt=False):
261273
262274 # -- init optimizer and scheduler
263275 optimizer , scaler , scheduler , wd_scheduler = init_opt (
276+ is_anneal = is_anneal ,
264277 encoder = encoder ,
265278 predictor = predictor ,
266279 wd = wd ,
@@ -305,12 +318,14 @@ def main(args, resume_preempt=False):
305318 target_encoder = target_encoder ,
306319 opt = optimizer ,
307320 scaler = scaler ,
321+ is_anneal = is_anneal and not resume_anneal ,
308322 )
309- for _ in range (start_epoch * ipe ):
310- scheduler .step ()
311- wd_scheduler .step ()
312- next (momentum_scheduler )
313- mask_collator .step ()
323+ if not is_anneal or resume_anneal :
324+ for _ in range (start_epoch * ipe ):
325+ scheduler .step ()
326+ wd_scheduler .step ()
327+ next (momentum_scheduler )
328+ mask_collator .step ()
314329
315330 def save_checkpoint (epoch , path ):
316331 if rank != 0 :
0 commit comments