File tree Expand file tree Collapse file tree 1 file changed +13
-2
lines changed
Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -161,7 +161,13 @@ def func(x):
161161 # Normalize the gradient according to its norm (computed in another pass)
162162 grad_fp32 .mul_ (self .clip_coef )
163163
164- beta2t = 1.0 - math .pow (self .step_num , self .decay_rate )
164+ # math.pow(0, negative exponent) raises ValueError, so at step 0 flip the exponent's sign (pow then yields 0 and beta2t becomes 1.0)
165+ if self .step_num == 0 and self .decay_rate < 0 :
166+ decay_rate = - self .decay_rate
167+ else :
168+ decay_rate = self .decay_rate
169+
170+ beta2t = 1.0 - math .pow (self .step_num , decay_rate )
165171 update = (grad_fp32 ** 2 ) + self .eps [0 ]
166172
167173 if len (p .data .shape ) > 1 :
@@ -243,7 +249,12 @@ def func(x):
243249 # Normalize the gradient according to its norm (computed in another pass)
244250 grad_fp32 .mul_ (self .clip_coef )
245251
246- beta2t = 1.0 - math .pow (self .step_num , self .decay_rate )
252+ # math.pow(0, negative exponent) raises ValueError, so at step 0 flip the exponent's sign (pow then yields 0 and beta2t becomes 1.0)
253+ if self .step_num == 0 and self .decay_rate < 0 :
254+ decay_rate = - self .decay_rate
255+ else :
256+ decay_rate = self .decay_rate
257+ beta2t = 1.0 - math .pow (self .step_num , decay_rate )
247258 update = (grad_fp32 ** 2 ) + self .eps [0 ] # TODO: change this to addcmul_
248259
249260 if len (p .ds_shape ) > 1 :
You can’t perform that action at this time.
0 commit comments