11from abc import abstractmethod
22from bisect import bisect_right
33from collections import Counter
4- from typing import Any , Sequence
4+ from typing import Any , Literal , Sequence
55
66import lightning .pytorch as pl
77
@@ -53,20 +53,20 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:
5353
5454 def on_train_start (self , trainer : pl .Trainer , _ : Any ) -> None :
5555 if self .verbose :
56- print (f"StepAutoRegressionLengthScheduler : initial length = { trainer .datamodule .n_forward_time_steps } " )
56+ print (f"{ self . __class__ . __name__ } : initial length = { trainer .datamodule .n_forward_time_steps } " )
5757
5858 self .base_length = trainer .datamodule .n_forward_time_steps
5959
6060 def on_train_epoch_start (self , trainer : pl .Trainer , _ : Any ) -> None :
6161 if self .base_length is None :
62- raise RuntimeError ("StepAutoRegressionLengthScheduler : base_length is None" )
62+ raise RuntimeError ("{self.__class__.__name__} : base_length is None! " )
6363
6464 if trainer .current_epoch % self .step_size == 0 :
6565 trainer .datamodule .n_forward_time_steps = self ._get_closed_form_ar_length (trainer .current_epoch )
6666
6767 if self .verbose :
6868 print (
69- f"StepAutoRegressionLengthScheduler : new length = { trainer .datamodule .n_forward_time_steps } "
69+ f"{ self . __class__ . __name__ } : new length = { trainer .datamodule .n_forward_time_steps } "
7070 f" at epoch { trainer .current_epoch } "
7171 )
7272
@@ -100,7 +100,7 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:
100100
101101 def on_train_start (self , trainer : pl .Trainer , _ : Any ) -> None :
102102 if self .verbose :
103- print (f"MultiStepAutoRegressionLengthScheduler : initial length = { trainer .datamodule .n_forward_time_steps } " )
103+ print (f"{ self . __class__ . __name__ } : initial length = { trainer .datamodule .n_forward_time_steps } " )
104104
105105 self .base_length = trainer .datamodule .n_forward_time_steps
106106
@@ -112,6 +112,90 @@ def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:
112112
113113 if self .verbose :
114114 print (
115- f"MultiStepAutoRegressionLengthScheduler : new length = { trainer .datamodule .n_forward_time_steps } "
115+ f"{ self . __class__ . __name__ } : new length = { trainer .datamodule .n_forward_time_steps } "
116116 f" at epoch { trainer .current_epoch } with milestones { list (self .milestones .keys ())} "
117117 )
118+
119+
class IncreaseAutoRegressionLengthOnPlateau(AbstractAutoRegressionLengthScheduler):
    """
    Increases the length of auto-regression by ``factor`` once the monitored quantity stops improving.

    Works like the ReduceLROnPlateau scheduler, but increases the auto-regression
    length (given as int!) instead of decaying the learning rate. Lower metric
    values are treated as better.

    :note: this callback changes the length after validation,
        at the end of the epoch, unlike the others, which do it at the start of the epoch

    Source reference: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
    """

    def __init__(
        self,
        monitor: str,
        patience: int,
        factor: int,
        threshold: float = 1e-4,
        threshold_mode: Literal["abs", "rel"] = "rel",
        max_length: int | None = None,
        verbose: bool = False,
    ):
        """
        :param monitor: quantity to be monitored, given as a key of the callback_metrics dictionary of pl.Trainer
        :param patience: number of epochs with no improvement after which the auto-regression length will be increased
        :param factor: factor by which to increase the auto-regression length. new_length = old_length * factor
        :param threshold: threshold for measuring the new optimum, to only focus on significant changes
        :param threshold_mode: one of {"rel", "abs"}, defaults to "rel"
        :param max_length: maximum auto-regression length; None (default) means no upper bound
        :param verbose: if True, prints the auto-regression length when it is changed
        """
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.factor = factor

        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.max_length = max_length
        self.verbose = verbose

        # Running state: best (lowest) metric value seen so far, and the number
        # of consecutive epochs without significant improvement.
        self.best = float("inf")
        self.num_bad_epochs = 0

    def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
        if self.verbose:
            print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")

    def is_better(self, current: float, best: float) -> bool:
        """Return True if ``current`` improves on ``best`` by more than the threshold (lower is better)."""
        if self.threshold_mode == "rel":
            return current < best * (float(1) - self.threshold)

        else:  # self.threshold_mode == "abs":
            return current < best - self.threshold

    def on_validation_epoch_end(self, trainer: pl.Trainer, _: Any) -> None:
        """Track the monitored metric and grow the auto-regression length on plateau.

        :raises RuntimeError: if ``monitor`` is not present in ``trainer.callback_metrics``
        """
        current = trainer.callback_metrics.get(self.monitor)
        if current is None:
            raise RuntimeError(f"{self.__class__.__name__}: metric {self.monitor} not found in callback_metrics!")

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            new_length = trainer.datamodule.n_forward_time_steps * self.factor

            # BUG FIX: guard against max_length=None (the default). The original
            # compared `new_length > None`, which raises TypeError on Python 3.
            # None means "no upper bound", so the cap check is skipped entirely.
            if self.max_length is not None and new_length > self.max_length:
                if self.verbose:
                    print(f"{self.__class__.__name__}: maximum length reached, not increasing")
                return  # exit function if new length is greater than maximum length

            trainer.datamodule.n_forward_time_steps = new_length
            self.num_bad_epochs = 0

            if self.verbose:
                print(
                    f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
                    f" at epoch {trainer.current_epoch}"
                )
0 commit comments