44"""
55
66from torch .optim import Optimizer
7- from torch .optim .lr_scheduler import _LRScheduler
7+ from torch .optim .lr_scheduler import _LRScheduler , ReduceLROnPlateau
88
99__all__ = ['VERSION' , 'WarmUpScheduler' ]
1010
11- VERSION = '0.1.1 '
11+ VERSION = '0.1.2 '
1212
1313
1414class WarmUpScheduler (object ):
@@ -44,12 +44,12 @@ def __init__(self, optimizer, lr_scheduler, warmup_steps: int, warmup_start_lr,
4444
4545 # Attach optimizer
4646 if not isinstance (optimizer , Optimizer ):
47- raise TypeError (f'{ type (optimizer ).__name__ } is not an Optimizer' )
47+ raise TypeError (f'{ type (optimizer ).__name__ } is not an Optimizer in pytorch ' )
4848 self .optimizer = optimizer
4949
5050 # Attach lr_scheduler
51- if not isinstance (lr_scheduler , _LRScheduler ):
52- raise TypeError (f'{ type (lr_scheduler ).__name__ } is not a _LRScheduler ' )
51+ if not isinstance (lr_scheduler , ( _LRScheduler , ReduceLROnPlateau ) ):
52+ raise TypeError (f'{ type (lr_scheduler ).__name__ } is not a lr_scheduler in pytorch ' )
5353 self .lr_scheduler = lr_scheduler
5454
5555 # check whether attribute initial_lr in optimizer.param_group
@@ -77,6 +77,7 @@ def __init__(self, optimizer, lr_scheduler, warmup_steps: int, warmup_start_lr,
7777 self ._step_count = 0
7878 self ._last_lr = None
7979 self .__warmup_done = False
80+ self .__is_ReduceLROnPlateau = isinstance (lr_scheduler , ReduceLROnPlateau )
8081 self .verbose = verbose
8182
8283 self .step ()
@@ -136,12 +137,15 @@ def _new_epoch(self):
136137 r"""Return whether is a new epoch started now"""
137138 return self .last_step % self .len_loader == 0
138139
139- def _step (self , epoch ):
140+ def _step (self , epoch , metrics ):
140141 r"""For warmup_scheduler_pytorch and lr_scheduler step once"""
141142 if self .__warmup_done and self ._new_epoch :
142- self .lr_scheduler .step ()
143+ if self .__is_ReduceLROnPlateau :
144+ self .lr_scheduler .step (metrics , epoch )
145+ else :
146+ self .lr_scheduler .step (epoch )
143147
144- elif not self .__warmup_done and self .last_step <= self .warmup_steps :
148+ elif ( not self .__warmup_done ) and ( self .last_step <= self .warmup_steps ) :
145149 values = self .get_warmup_lr ()
146150
147151 if self .last_step >= self .warmup_steps :
@@ -151,29 +155,29 @@ def _step(self, epoch):
151155 param_group ['lr' ] = lr
152156 self .print_lr (self .verbose , idx , lr , epoch )
153157
154- def step (self , step = None , epoch = None ):
158+ def step (self , metrics = None , step = None , epoch = None ):
155159 self ._step_count += 1
156160
157161 if step is None and epoch is None :
158162 self .last_step += 1
159163 if self ._new_epoch :
160164 self .last_epoch += 1
161- self ._step (epoch )
165+ self ._step (epoch , metrics )
162166
163167 elif step is not None and epoch is None :
164168 self .last_step = step
165169 self .last_epoch = step // self .len_loader
166- self ._step (epoch )
170+ self ._step (epoch , metrics )
167171
168172 elif step is None and epoch is not None :
169173 self .last_step = epoch * self .len_loader
170174 self .last_epoch = epoch
171- self ._step (epoch )
175+ self ._step (epoch , metrics )
172176
173177 else : # if step and epoch
174178 # step is relative to epoch only here
175179 self .last_step = step + epoch * self .len_loader
176180 self .last_epoch = epoch
177- self ._step (epoch )
181+ self ._step (epoch , metrics )
178182
179183 self ._last_lr = [group ['lr' ] for group in self .optimizer .param_groups ]
0 commit comments