
Commit f90a352

Author: LEFTeyes
Commit message: upgrade for ReduceLROnPlateau
Parent: d60d283


src/warmup_scheduler_pytorch/warmup_module.py

Lines changed: 17 additions & 13 deletions
@@ -4,11 +4,11 @@
 """
 
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
 
 __all__ = ['VERSION', 'WarmUpScheduler']
 
-VERSION = '0.1.1'
+VERSION = '0.1.2'
 
 
 class WarmUpScheduler(object):
@@ -44,12 +44,12 @@ def __init__(self, optimizer, lr_scheduler, warmup_steps: int, warmup_start_lr,
 
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
-            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
+            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer in pytorch')
         self.optimizer = optimizer
 
         # Attach lr_scheduler
-        if not isinstance(lr_scheduler, _LRScheduler):
-            raise TypeError(f'{type(lr_scheduler).__name__} is not a _LRScheduler')
+        if not isinstance(lr_scheduler, (_LRScheduler, ReduceLROnPlateau)):
+            raise TypeError(f'{type(lr_scheduler).__name__} is not a lr_scheduler in pytorch')
         self.lr_scheduler = lr_scheduler
 
         # check whether attribute initial_lr in optimizer.param_group
@@ -77,6 +77,7 @@ def __init__(self, optimizer, lr_scheduler, warmup_steps: int, warmup_start_lr,
         self._step_count = 0
         self._last_lr = None
         self.__warmup_done = False
+        self.__is_ReduceLROnPlateau = isinstance(lr_scheduler, ReduceLROnPlateau)
         self.verbose = verbose
 
         self.step()
@@ -136,12 +137,15 @@ def _new_epoch(self):
         r"""Return whether is a new epoch started now"""
         return self.last_step % self.len_loader == 0
 
-    def _step(self, epoch):
+    def _step(self, epoch, metrics):
         r"""For warmup_scheduler_pytorch and lr_scheduler step once"""
         if self.__warmup_done and self._new_epoch:
-            self.lr_scheduler.step()
+            if self.__is_ReduceLROnPlateau:
+                self.lr_scheduler.step(metrics, epoch)
+            else:
+                self.lr_scheduler.step(epoch)
 
-        elif not self.__warmup_done and self.last_step <= self.warmup_steps:
+        elif (not self.__warmup_done) and (self.last_step <= self.warmup_steps):
             values = self.get_warmup_lr()
 
             if self.last_step >= self.warmup_steps:
@@ -151,29 +155,29 @@ def _step(self, epoch):
                 param_group['lr'] = lr
                 self.print_lr(self.verbose, idx, lr, epoch)
 
-    def step(self, step=None, epoch=None):
+    def step(self, metrics=None, step=None, epoch=None):
         self._step_count += 1
 
         if step is None and epoch is None:
             self.last_step += 1
             if self._new_epoch:
                 self.last_epoch += 1
-            self._step(epoch)
+            self._step(epoch, metrics)
 
         elif step is not None and epoch is None:
             self.last_step = step
             self.last_epoch = step // self.len_loader
-            self._step(epoch)
+            self._step(epoch, metrics)
 
         elif step is None and epoch is not None:
             self.last_step = epoch * self.len_loader
             self.last_epoch = epoch
-            self._step(epoch)
+            self._step(epoch, metrics)
 
         else:  # if step and epoch
             # step is relative to epoch only here
             self.last_step = step + epoch * self.len_loader
             self.last_epoch = epoch
-            self._step(epoch)
+            self._step(epoch, metrics)
 
         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
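
After this change, a ReduceLROnPlateau scheduler can be wrapped directly, with the plateau metric passed through step(). Below is a minimal usage sketch, not part of the commit: the model and data are toy stand-ins, the package-level import of WarmUpScheduler is assumed from __all__, and constructor keywords other than warmup_steps and warmup_start_lr (e.g. len_loader, inferred from self.len_loader in the diff) are assumptions.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset

from warmup_scheduler_pytorch import WarmUpScheduler  # import path assumed

# Toy model and data (hypothetical stand-ins)
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)),
                          batch_size=8)

# As of 0.1.2 the wrapped scheduler may be a ReduceLROnPlateau,
# not only a _LRScheduler subclass
plateau = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
warmup_scheduler = WarmUpScheduler(optimizer, plateau,
                                   len_loader=len(train_loader),  # keyword name assumed
                                   warmup_steps=16,
                                   warmup_start_lr=1e-3)

for epoch in range(5):
    for x, y in train_loader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # step() is called once per batch; the metric (here the training batch
        # loss, though a validation loss is the usual choice) is only forwarded
        # to ReduceLROnPlateau.step(metrics, epoch) once warmup is done and a
        # new epoch boundary is reached (see _step above).
        warmup_scheduler.step(metrics=loss.item())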
