This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Retire custom AdamWithDecay #2232

Open

@stephenroller

Description

We have a custom Adam optimizer with a TODO saying it can be removed as of pytorch 1.2:

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.optimizer import Optimizer, required


# TODO: deprecate this entire class; it should be subsumed by TA as of pytorch 1.2
class AdamWithDecay(Optimizer):
    """
    Adam with decay; mirrors HF's implementation.

    :param lr:
        learning rate
    :param b1:
        Adam's b1. Default: 0.9
    :param b2:
        Adam's b2. Default: 0.999
    :param e:
        Adam's epsilon. Default: 1e-6
    :param weight_decay:
        Weight decay. Default: 0.01
    :param max_grad_norm:
        Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """

    def __init__(
        self,
        params,
        lr=required,
        b1=0.9,
        b2=0.999,
        e=1e-6,
        weight_decay=0.01,
        max_grad_norm=1.0,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError('Invalid learning rate: {} - should be >= 0.0'.format(lr))
        if not 0.0 <= b1 < 1.0:
            raise ValueError(
                'Invalid b1 parameter: {} - should be in [0.0, 1.0['.format(b1)
            )
        if not 0.0 <= b2 < 1.0:
            raise ValueError(
                'Invalid b2 parameter: {} - should be in [0.0, 1.0['.format(b2)
            )
        if not e >= 0.0:
            raise ValueError('Invalid epsilon value: {} - should be >= 0.0'.format(e))
        defaults = dict(
            lr=lr,
            b1=b1,
            b2=b2,
            e=e,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
        )
        super(AdamWithDecay, self).__init__(params, defaults)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        :param closure:
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please '
                        'consider SparseAdam instead'
                    )
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)
                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']
                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])
                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(1 - beta1, grad)
                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                update = next_m / (next_v.sqrt() + group['e'])
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data
                lr = group['lr']
                update_with_lr = lr * update
                p.data.add_(-update_with_lr)
        return loss

The task is to remove it and upgrade to the official pytorch version.
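For illustration, a minimal sketch of what the swap might look like (the model, hyperparameters, and training-loop wiring here are placeholders, not ParlAI's actual setup):

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(16, 4)  # stand-in for the real model

# Before: AdamWithDecay(model.parameters(), lr=5e-5, weight_decay=0.01, max_grad_norm=1.0)
# After: the official decoupled-weight-decay Adam, available since pytorch 1.2.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.999),  # replaces b1 / b2
    eps=1e-6,            # replaces e
    weight_decay=0.01,
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
# AdamWithDecay clipped gradients inside step(); AdamW does not,
# so clipping has to happen explicitly before stepping.
clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()

One behavioral difference worth flagging: like HF's old BertAdam, AdamWithDecay keeps no step count and never applies bias correction, whereas torch.optim.AdamW does, so the two will not produce bit-identical updates.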

The helper function minimizes the pain required, though I'm wondering if we'll need to upgrade some state dicts, so make sure you run long tests.

def get_bert_optimizer(models, type_optimization, learning_rate, fp16=False):
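On the state dict question: the custom class stores its moments under next_m / next_v, while torch's Adam-family optimizers store exp_avg, exp_avg_sq, and a step count, so loading an old optimizer checkpoint directly will fail. A hypothetical translation sketch (the function name is made up, and the step value is a guess since the old class never tracked one):

def upgrade_opt_state(old_sd):
    """
    Hypothetical translation of an AdamWithDecay state dict into the layout
    torch.optim.AdamW expects; not the actual ParlAI upgrade path.
    """
    new_state = {
        pid: {
            'step': 0,                   # AdamWithDecay never tracked a step count
            'exp_avg': st['next_m'],     # first-moment estimate
            'exp_avg_sq': st['next_v'],  # second-moment estimate
        }
        for pid, st in old_sd['state'].items()
    }
    new_groups = []
    for g in old_sd['param_groups']:
        new_groups.append({
            'params': g['params'],
            'lr': g['lr'],
            'betas': (g['b1'], g['b2']),
            'eps': g['e'],
            'weight_decay': g['weight_decay'],
            'amsgrad': False,
            # max_grad_norm is dropped: clipping now happens outside the optimizer
        })
    return {'state': new_state, 'param_groups': new_groups}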
