diff --git a/torch_optimizer/adabound.py b/torch_optimizer/adabound.py
index b7e0ef4..93740ec 100644
--- a/torch_optimizer/adabound.py
+++ b/torch_optimizer/adabound.py
@@ -175,3 +175,176 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
 
                 p.data.add_(-step_size)
         return loss
+
+
+class AdaBoundW(Optimizer):
+    r"""Implements the AdaBound algorithm with decoupled weight decay
+    (https://arxiv.org/abs/1711.05101).
+
+    It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of
+    Learning Rate`__.
+
+    Arguments:
+        params: iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr: learning rate (default: 1e-3)
+        betas: coefficients used for computing running averages of gradient
+            and its square (default: (0.9, 0.999))
+        final_lr: final (SGD) learning rate (default: 0.1)
+        gamma: convergence speed of the bound functions
+            (default: 1e-3)
+        eps: term added to the denominator to improve numerical stability
+            (default: 1e-8)
+        weight_decay: weight decay (decoupled from the gradient) (default: 0)
+        amsbound: whether to use the AMSBound variant of this algorithm
+
+    Example:
+        >>> import torch_optimizer as optim
+        >>> optimizer = optim.AdaBoundW(
+        ...     model.parameters(), lr=0.1, weight_decay=0.001
+        ... )
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+
+    __ https://arxiv.org/abs/1902.09843
+
+    Note:
+        Reference code: https://github.com/Luolc/AdaBound
+    """
+
+    def __init__(
+        self,
+        params: Params,
+        lr: float = 1e-3,
+        betas: Betas2 = (0.9, 0.999),
+        final_lr: float = 0.1,
+        gamma: float = 1e-3,
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        amsbound: bool = False,
+    ) -> None:
+        if lr <= 0.0:
+            raise ValueError('Invalid learning rate: {}'.format(lr))
+        if eps < 0.0:
+            raise ValueError('Invalid epsilon value: {}'.format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter at index 0: {}'.format(betas[0])
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter at index 1: {}'.format(betas[1])
+            )
+        if final_lr < 0.0:
+            raise ValueError(
+                'Invalid final learning rate: {}'.format(final_lr)
+            )
+        if not 0.0 <= gamma < 1.0:
+            raise ValueError('Invalid gamma parameter: {}'.format(gamma))
+        if weight_decay < 0:
+            raise ValueError(
+                'Invalid weight_decay value: {}'.format(weight_decay)
+            )
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
+        super(AdaBoundW, self).__init__(params, defaults)
+        self.base_lrs = [group['lr'] for group in self.param_groups]
+
+    def __setstate__(self, state: State) -> None:
+        super(AdaBoundW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsbound', False)
+
+    def step(self, closure: OptLossClosure = None) -> OptFloat:
+        r"""Performs a single optimization step.
+
+        Arguments:
+            closure: A closure that reevaluates the model and returns the
+                loss.
+ """ + loss = None + if closure is not None: + loss = closure() + + for group, base_lr in zip(self.param_groups, self.base_lrs): + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.mul_(1 - base_lr * group['weight_decay']) + + grad = p.grad.data + if grad.is_sparse: + msg = ( + 'AdaBound does not support sparse gradients, ' + 'please consider SparseAdam instead' + ) + raise RuntimeError(msg) + amsbound = group['amsbound'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsbound: + # Maintains max of all exp. moving avg. of + # sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsbound: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # if group['weight_decay'] != 0: + # grad = grad.add(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsbound: + # Maintains the maximum of all 2nd moment running + # avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = ( + group['lr'] + * math.sqrt(bias_correction2) + / bias_correction1 + ) + + # Applies bounds on actual learning rate + # lr_scheduler cannot affect final_lr, this is a workaround + # to apply lr decay + final_lr = group['final_lr'] * group['lr'] / base_lr + lower_bound = final_lr * ( + 1 - 1 / (group['gamma'] * state['step'] + 1) + ) + upper_bound = final_lr * ( + 1 + 1 / (group['gamma'] * state['step']) + ) + step_size = torch.full_like(denom, step_size) + step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_( + exp_avg + ) + + p.data.add_(-step_size) + return loss \ No newline at end of file