
Retire custom AdamWithDecay #2232

Open
stephenroller opened this issue Dec 2, 2019 · 1 comment
@stephenroller
Contributor

We have a custom Adam optimizer whose TODO says it can be removed as of PyTorch 1.2:

```python
# TODO: deprecate this entire class; it should be subsumed by TA as of pytorch 1.2
class AdamWithDecay(Optimizer):
    """
    Adam with decay; mirror's HF's implementation.

    :param lr:
        learning rate
    :param b1:
        Adams b1. Default: 0.9
    :param b2:
        Adams b2. Default: 0.999
    :param e:
        Adams epsilon. Default: 1e-6
    :param weight_decay:
        Weight decay. Default: 0.01
    :param max_grad_norm:
        Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """

    def __init__(
        self,
        params,
        lr=required,
        b1=0.9,
        b2=0.999,
        e=1e-6,
        weight_decay=0.01,
        max_grad_norm=1.0,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError('Invalid learning rate: {} - should be >= 0.0'.format(lr))
        if not 0.0 <= b1 < 1.0:
            raise ValueError(
                'Invalid b1 parameter: {} - should be in [0.0, 1.0['.format(b1)
            )
        if not 0.0 <= b2 < 1.0:
            raise ValueError(
                'Invalid b2 parameter: {} - should be in [0.0, 1.0['.format(b2)
            )
        if not e >= 0.0:
            raise ValueError('Invalid epsilon value: {} - should be >= 0.0'.format(e))
        defaults = dict(
            lr=lr,
            b1=b1,
            b2=b2,
            e=e,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
        )
        super(AdamWithDecay, self).__init__(params, defaults)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        :param closure:
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please '
                        'consider SparseAdam instead'
                    )
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)
                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']
                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])
                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(1 - beta1, grad)
                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                update = next_m / (next_v.sqrt() + group['e'])
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data
                lr = group['lr']
                update_with_lr = lr * update
                p.data.add_(-update_with_lr)
        return loss
```

The task is to remove it and upgrade to the official PyTorch version.
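A minimal sketch of what the swap could look like, assuming the replacement is `torch.optim.AdamW` (available since PyTorch 1.2); the model and hyperparameter values are placeholders, not the actual ParlAI wiring:

```python
import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(16, 4)  # stand-in for the real model

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.999),  # replaces b1/b2
    eps=1e-6,            # replaces e
    weight_decay=0.01,   # decoupled decay, same idea as the custom class
)

# AdamW does not clip gradients itself, so the old max_grad_norm behaviour
# has to move into the training loop:
loss = model(torch.randn(2, 16)).sum()
loss.backward()
clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
```

Two behavioural differences worth keeping in mind: AdamW applies bias correction to the moment estimates, which the custom class (like HF's implementation it mirrors) omits, and the custom class clips each parameter's gradient norm separately rather than over all parameters at once. So the swap will not be bit-for-bit identical, which is part of why the long tests below matter.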

The helper function below minimizes the pain required, though I'm wondering whether we'll need to upgrade some optimizer state dicts (a possible conversion is sketched after the signature below), so make sure to run long tests.

def get_bert_optimizer(models, type_optimization, learning_rate, fp16=False):

@github-actions

github-actions bot commented Jun 2, 2020

This issue has not had activity in 30 days. Marking as stale.

@github-actions github-actions bot closed this as completed Jun 2, 2020
@stephenroller stephenroller added the donotreap label (avoid automatically marking as stale) and removed the stale-issue label Jun 2, 2020
@stephenroller stephenroller reopened this Jun 2, 2020