@@ -395,10 +395,7 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
 
 
 class Adam(Optimizer):
-    r"""The Adam optimizer [1].
-
-    Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    r"""The Adam optimizer [1]. In detail,
 
     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
     optimization. ICLR 2015.
@@ -416,19 +413,23 @@ class Adam(Optimizer):
             gradient and its square. Default: ``(0.9, 0.999)``
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
         self,
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
+        bias_correction: bool = False,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
+        self.bias_correction = bias_correction
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
@@ -441,6 +442,8 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps
+        bias_correction = self.bias_correction
+        step = self.step
 
         m = state["m"]
         v = state["v"]
@@ -449,15 +452,17 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         state["m"] = m
         state["v"] = v
 
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        if bias_correction:
+            numerator = lr / (1 - b1**step) * m
+            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            return parameter - numerator / denominator
+        else:
+            return parameter - lr * m / (mx.sqrt(v) + eps)
 
 
 class AdamW(Adam):
-    r"""The AdamW optimizer [1].
-
-    Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+    (:math:`\lambda`) value:
 
     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
     regularization. ICLR 2019.
@@ -477,6 +482,8 @@ class AdamW(Adam):
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
             Default: ``0``.
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -485,8 +492,14 @@ def __init__(
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
+        bias_correction: bool = False,
     ):
-        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        super().__init__(
+            learning_rate=learning_rate,
+            betas=betas,
+            eps=eps,
+            bias_correction=bias_correction,
+        )
         self.weight_decay = weight_decay
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
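The `if bias_correction:` branch above implements the textbook Adam step with bias-corrected moments, i.e. it divides the first moment by :math:`1 - \beta_1^t` and the second moment by :math:`1 - \beta_2^t` before forming the update. A minimal standalone sketch of that update rule, using only `mlx.core` and reusing the names from the diff (`lr`, `b1`, `b2`, `eps`, `step`, `m`, `v`); the concrete values here are illustrative, not from the source:

```python
import mlx.core as mx

# Hyperparameters, mirroring the defaults in the diff above.
lr, b1, b2, eps, step = 1e-3, 0.9, 0.999, 1e-8, 1

parameter = mx.array([0.5, -0.3])
gradient = mx.array([0.1, -0.2])

# Moment estimates (state["m"], state["v"]) after one step,
# starting from zeros as init_single would produce.
m = b1 * mx.zeros_like(gradient) + (1 - b1) * gradient
v = b2 * mx.zeros_like(gradient) + (1 - b2) * mx.square(gradient)

# Bias-corrected update, matching the `if bias_correction:` branch.
numerator = lr / (1 - b1**step) * m
denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
print(parameter - numerator / denominator)
```

Note that `eps` is added to the bias-corrected denominator, so with `bias_correction=True` the very first steps are roughly `lr`-sized rather than damped by the uncorrected, near-zero moments.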