@@ -43,6 +43,7 @@ def sawtooth_warmup_cosine_decay_schedule(
4343 Phase 2: Both encoder and decoder warmup
4444 Phase 3: Cosine annealing for both
4545 """
46+ assert max_epochs > 0 and steps_per_epoch > 0 , "max_epochs and steps_per_epoch must be greater than 0"
4647 print (f"Using separate warmup: decoder for { decoder_warmup_epochs } epochs, then both for { warmup_epochs } epochs" )
4748
4849 decoder_warmup_steps = int (decoder_warmup_epochs * steps_per_epoch )
@@ -72,16 +73,24 @@ def decoder_phase1_lambda(step):
7273 )
7374
7475
75- def simple_warmup_cosine_decay_schedule (optimizer , warmup_epochs , steps_per_epoch , cosine_period_ratio , max_epochs ):
76+ def simple_warmup_cosine_decay_schedule (
77+ optimizer , warmup_epochs , steps_per_epoch , cosine_period_ratio , max_epochs = - 1 , max_steps = - 1
78+ ):
7679 """
7780 Phase 1: Warmup for both encoder and decoder
7881 Phase 2: Cosine annealing for both
7982 """
80- print (f"Using warmup for { warmup_epochs } epochs" )
83+ assert warmup_epochs >= 0 , "Warmup epochs must be greater than or equal to 0."
84+ assert cosine_period_ratio > 0 , "Cosine period ratio must be greater than 0."
85+ assert steps_per_epoch > 0 , "Steps per epoch must be greater than 0."
86+ assert max_epochs > 0 or max_steps > 0 , "Either max_epochs or max_steps must be greater than 0."
8187
82- total_warmup_steps = int (warmup_epochs * steps_per_epoch )
8388 # cosine_half_period is from max to min
84- cosine_steps = int (cosine_period_ratio * (max_epochs * steps_per_epoch - total_warmup_steps ))
89+ if max_epochs > 0 :
90+ max_steps = max_epochs * steps_per_epoch
91+
92+ total_warmup_steps = int (warmup_epochs * steps_per_epoch )
93+ cosine_steps = int (cosine_period_ratio * (max_steps - total_warmup_steps ))
8594
8695 cosine_scheduler = CosineAnnealingLR (optimizer , T_max = cosine_steps )
8796 warmup_scheduler = LinearLR (
@@ -90,17 +99,25 @@ def simple_warmup_cosine_decay_schedule(optimizer, warmup_epochs, steps_per_epoc
9099 total_iters = total_warmup_steps ,
91100 )
92101
102+ print (f"Using warmup for { warmup_epochs } epochs ({ total_warmup_steps } steps)" )
103+ print (f"Cosine decay for { cosine_steps } steps after warmup" )
104+ assert total_warmup_steps > 0 , "Warmup steps must be greater than 0 for warmup schedule."
105+ assert cosine_steps > 0 , "Cosine steps must be greater than 0 for warmup cosine decay schedule."
106+
93107 return SequentialLR (
94108 optimizer ,
95109 schedulers = [warmup_scheduler , cosine_scheduler ],
96110 milestones = [total_warmup_steps ],
97111 )
98112
99113
def cosine_decay_schedule(optimizer, steps_per_epoch, cosine_period_ratio, max_epochs=-1, max_steps=-1):
    """
    Phase 1: Cosine annealing for both encoder and decoder

    Builds a ``CosineAnnealingLR`` whose ``T_max`` is ``cosine_period_ratio``
    times the training length expressed in optimizer steps.

    Args:
        optimizer: Optimizer handed to ``CosineAnnealingLR``.
        steps_per_epoch: Steps in one epoch. Only used (and only validated)
            when ``max_epochs`` is positive.
        cosine_period_ratio: Fraction of the training length used as the
            cosine half period (from max to min). Must be greater than 0.
        max_epochs: Training length in epochs. When positive it takes
            precedence over ``max_steps``.
        max_steps: Training length in steps, used when ``max_epochs`` is not
            positive.

    Returns:
        The ``CosineAnnealingLR`` scheduler attached to ``optimizer``.

    Raises:
        AssertionError: If the ratio is not positive, if neither
            ``max_epochs`` nor ``max_steps`` is positive, if ``steps_per_epoch``
            is not positive while ``max_epochs`` is used, or if the resulting
            number of cosine steps truncates to 0.
    """
    # Validate every input on its own: checking only the product below would
    # let two negative values (e.g. the -1 default and a negative ratio)
    # cancel out into a positive, meaningless T_max.
    assert cosine_period_ratio > 0, "Cosine period ratio must be greater than 0."
    assert max_epochs > 0 or max_steps > 0, "Either max_epochs or max_steps must be greater than 0."

    if max_epochs > 0:
        assert steps_per_epoch > 0, "Steps per epoch must be greater than 0."
        max_steps = max_epochs * steps_per_epoch

    # cosine_half_period is from max to min
    cosine_steps = int(cosine_period_ratio * max_steps)
    # A small ratio can still truncate to zero steps, which is not a usable period.
    assert cosine_steps > 0, "Cosine steps must be greater than 0 for cosine decay schedule."
    return CosineAnnealingLR(optimizer, T_max=cosine_steps)
0 commit comments