Skip to content

Commit ebbf410

Browse files
authored
Merge pull request #78 from younesbelkada/patch-1
FIX: avoid math errors for edge cases
2 parents 2b826c5 + ad755c6 commit ebbf410

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

lomo_optim/adalomo.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,13 @@ def func(x):
161161
# Normalize the gradient according to its norm (computed in another pass)
162162
grad_fp32.mul_(self.clip_coef)
163163

164-
beta2t = 1.0 - math.pow(self.step_num, self.decay_rate)
164+
# To avoid math errors for edge cases
165+
if self.step_num == 0 and self.decay_rate < 0:
166+
decay_rate = - self.decay_rate
167+
else:
168+
decay_rate = self.decay_rate
169+
170+
beta2t = 1.0 - math.pow(self.step_num, decay_rate)
165171
update = (grad_fp32**2) + self.eps[0]
166172

167173
if len(p.data.shape) > 1:
@@ -243,7 +249,12 @@ def func(x):
243249
# Normalize the gradient according to its norm (computed in another pass)
244250
grad_fp32.mul_(self.clip_coef)
245251

246-
beta2t = 1.0 - math.pow(self.step_num, self.decay_rate)
252+
# To avoid math errors for edge cases
253+
if self.step_num == 0 and self.decay_rate < 0:
254+
decay_rate = - self.decay_rate
255+
else:
256+
decay_rate = self.decay_rate
257+
beta2t = 1.0 - math.pow(self.step_num, decay_rate)
247258
update = (grad_fp32**2) + self.eps[0] # 改成addcmul_
248259

249260
if len(p.ds_shape) > 1:

0 commit comments

Comments
 (0)