This repository was archived by the owner on Nov 3, 2023. It is now read-only.
Retire custom AdamWithDecay #2232
Open
Labels: Agents, Code Quality, Help Wanted, Minor, donotreap (avoid automatically marking as stale)
Description
We have a custom Adam optimizer whose TODO says it can be removed as of PyTorch 1.2:
ParlAI/parlai/agents/bert_ranker/helpers.py
Lines 242 to 350 in 3a5de86
```python
# TODO: deprecate this entire class; it should be subsumed by TA as of pytorch 1.2
class AdamWithDecay(Optimizer):
    """
    Adam with decay; mirrors HF's implementation.

    :param lr:
        learning rate
    :param b1:
        Adams b1. Default: 0.9
    :param b2:
        Adams b2. Default: 0.999
    :param e:
        Adams epsilon. Default: 1e-6
    :param weight_decay:
        Weight decay. Default: 0.01
    :param max_grad_norm:
        Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """

    def __init__(
        self,
        params,
        lr=required,
        b1=0.9,
        b2=0.999,
        e=1e-6,
        weight_decay=0.01,
        max_grad_norm=1.0,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError('Invalid learning rate: {} - should be >= 0.0'.format(lr))
        if not 0.0 <= b1 < 1.0:
            raise ValueError(
                'Invalid b1 parameter: {} - should be in [0.0, 1.0['.format(b1)
            )
        if not 0.0 <= b2 < 1.0:
            raise ValueError(
                'Invalid b2 parameter: {} - should be in [0.0, 1.0['.format(b2)
            )
        if not e >= 0.0:
            raise ValueError('Invalid epsilon value: {} - should be >= 0.0'.format(e))
        defaults = dict(
            lr=lr,
            b1=b1,
            b2=b2,
            e=e,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
        )
        super(AdamWithDecay, self).__init__(params, defaults)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        :param closure:
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please '
                        'consider SparseAdam instead'
                    )
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)
                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']
                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])
                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(1 - beta1, grad)
                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                update = next_m / (next_v.sqrt() + group['e'])
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data
                lr = group['lr']
                update_with_lr = lr * update
                p.data.add_(-update_with_lr)
        return loss
```
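For comparison, here is a minimal sketch of the stock optimizer called with the same default hyperparameters, assuming `torch.optim.AdamW` (added in PyTorch 1.2) is the intended replacement. Two behavioural differences are worth keeping in mind: AdamW applies bias correction while AdamWithDecay does not, and AdamWithDecay clips each parameter's gradient inside `step()`, whereas the usual pattern with the stock optimizer clips the global norm once before `step()`. The toy model and lr value below are placeholders, not anything from the codebase.

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

# Toy model purely for illustration.
model = nn.Linear(8, 2)

# Stock AdamW with AdamWithDecay's default hyperparameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,             # AdamWithDecay requires an explicit lr; this value is a placeholder
    betas=(0.9, 0.999),  # b1, b2
    eps=1e-6,            # e
    weight_decay=0.01,   # decoupled weight decay, as in the custom class
)

# One training step: AdamWithDecay clipped per-parameter inside step();
# with the stock optimizer, clip once over all parameters before step().
x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # replaces max_grad_norm=1.0
optimizer.step()
optimizer.zero_grad()
```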
The task is to remove AdamWithDecay and switch to the official PyTorch implementation (i.e. torch.optim.AdamW, available since PyTorch 1.2).
The helper function referenced below should minimize the pain required, though I'm wondering whether we'll need to upgrade some optimizer state dicts (one possible remap is sketched after this paragraph), so make sure you run long tests.
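On the state-dict question: the custom class stores its moments under `next_m`/`next_v` and never tracks a step count, while `torch.optim.AdamW` expects `exp_avg`/`exp_avg_sq`/`step`. A hypothetical remap, not part of the codebase and only approximate because the missing step count means bias correction can't be reproduced exactly, could look like this:

```python
def upgrade_adam_with_decay_state(old_sd, new_optimizer):
    """Hypothetical helper: copy AdamWithDecay state into a fresh torch.optim.AdamW.

    Assumes the new optimizer was built over the same parameters in the same
    order, so the integer parameter ids line up. 'step' has no faithful value
    because AdamWithDecay never counted steps; a large count keeps AdamW's
    bias-correction factor close to 1.
    """
    new_sd = new_optimizer.state_dict()
    new_sd['state'] = {
        param_id: {
            'step': 10000,  # arbitrary; see docstring
            'exp_avg': s['next_m'],
            'exp_avg_sq': s['next_v'],
        }
        for param_id, s in old_sd['state'].items()
    }
    new_optimizer.load_state_dict(new_sd)
```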
ParlAI/parlai/agents/bert_ranker/helpers.py
Line 205 in 3a5de86
```python
def get_bert_optimizer(models, type_optimization, learning_rate, fp16=False):
```
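Only the signature above is visible here, so the following is a rough sketch of what the swap might look like inside the helper; the parameter-group selection driven by `type_optimization` and the fp16 wrapping in the real function are omitted, and every body line is an assumption rather than the actual ParlAI code.

```python
import torch

def get_bert_optimizer(models, type_optimization, learning_rate, fp16=False):
    # Hypothetical body: in the real helper, type_optimization picks which
    # submodules get trained and fp16 wraps the optimizer; both are elided here.
    parameters = [
        p for model in models for p in model.parameters() if p.requires_grad
    ]
    # Same defaults AdamWithDecay used: eps=1e-6, decoupled weight_decay=0.01.
    return torch.optim.AdamW(
        parameters, lr=learning_rate, eps=1e-6, weight_decay=0.01
    )
```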