Skip to content

Commit 4045edd

Browse files
authored
Merge pull request #41 from cyber-physical-systems-group/training/wandb-logs
[training](chore) Update W&B logs
2 parents bfe28f8 + 0cc559d commit 4045edd

File tree

3 files changed

+99
-14
lines changed

3 files changed

+99
-14
lines changed

pydentification/training/lightning/callbacks.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import abstractmethod
22
from bisect import bisect_right
33
from collections import Counter
4-
from typing import Any, Sequence
4+
from typing import Any, Literal, Sequence
55

66
import lightning.pytorch as pl
77

@@ -53,20 +53,20 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:
5353

5454
def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
5555
if self.verbose:
56-
print(f"StepAutoRegressionLengthScheduler: initial length = {trainer.datamodule.n_forward_time_steps}")
56+
print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")
5757

5858
self.base_length = trainer.datamodule.n_forward_time_steps
5959

6060
def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:
6161
if self.base_length is None:
62-
raise RuntimeError("StepAutoRegressionLengthScheduler: base_length is None")
62+
raise RuntimeError(f"{self.__class__.__name__}: base_length is None!")
6363

6464
if trainer.current_epoch % self.step_size == 0:
6565
trainer.datamodule.n_forward_time_steps = self._get_closed_form_ar_length(trainer.current_epoch)
6666

6767
if self.verbose:
6868
print(
69-
f"StepAutoRegressionLengthScheduler: new length = {trainer.datamodule.n_forward_time_steps}"
69+
f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
7070
f" at epoch {trainer.current_epoch}"
7171
)
7272

@@ -100,7 +100,7 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:
100100

101101
def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
102102
if self.verbose:
103-
print(f"MultiStepAutoRegressionLengthScheduler: initial length = {trainer.datamodule.n_forward_time_steps}")
103+
print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")
104104

105105
self.base_length = trainer.datamodule.n_forward_time_steps
106106

@@ -112,6 +112,90 @@ def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:
112112

113113
if self.verbose:
114114
print(
115-
f"MultiStepAutoRegressionLengthScheduler: new length = {trainer.datamodule.n_forward_time_steps}"
115+
f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
116116
f" at epoch {trainer.current_epoch} with milestones {list(self.milestones.keys())}"
117117
)
118+
119+
120+
class IncreaseAutoRegressionLengthOnPlateau(AbstractAutoRegressionLengthScheduler):
121+
"""
122+
Increases the length of auto-regression by factor once the monitored quantity stops improving.
123+
Works as ReduceLROnPlateau scheduler, but increasing the length (given as int!) instead of decaying learning rate.
124+
125+
:note: this callback changes the length after validation,
126+
at the end of epoch, unlike others, which do it on the start of epoch
127+
128+
Source reference: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
129+
"""
130+
131+
def __init__(
132+
self,
133+
monitor: str,
134+
patience: int,
135+
factor: int,
136+
threshold: float = 1e-4,
137+
threshold_mode: Literal["abs", "rel"] = "rel",
138+
max_length: int | None = None,
139+
verbose: bool = False,
140+
):
141+
"""
142+
:param monitor: quantity to be monitored given as key from callback_metrics dictionary of pl.Trainer
143+
:param patience: number of epochs with no improvement after which auto-regression length will be increased
144+
:param factor: factor by which to increase auto-regression length. new_length = old_length * factor
145+
:param threshold: threshold for measuring the new optimum, to only focus on significant changes
146+
:param threshold_mode: one of {"rel", "abs"}, defaults to "rel"
147+
:param max_length: maximum auto-regression length, defaults to None
148+
:param verbose: if True, prints the auto-regression length when it is changed
149+
"""
150+
super().__init__()
151+
152+
self.monitor = monitor
153+
self.patience = patience
154+
self.factor = factor
155+
156+
self.threshold = threshold
157+
self.threshold_mode = threshold_mode
158+
self.max_length = max_length
159+
self.verbose = verbose
160+
161+
self.best = float("inf")
162+
self.num_bad_epochs = 0
163+
164+
def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
165+
if self.verbose:
166+
print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")
167+
168+
def is_better(self, current: float, best: float) -> bool:
169+
if self.threshold_mode == "rel":
170+
return current < best * (float(1) - self.threshold)
171+
172+
else: # self.threshold_mode == "abs":
173+
return current < best - self.threshold
174+
175+
def on_validation_epoch_end(self, trainer: pl.Trainer, _: Any) -> None:
176+
current = trainer.callback_metrics.get(self.monitor)
177+
if current is None:
178+
raise RuntimeError(f"{self.__class__.__name__}: metric {self.monitor} not found in callback_metrics!")
179+
180+
if self.is_better(current, self.best):
181+
self.best = current
182+
self.num_bad_epochs = 0
183+
else:
184+
self.num_bad_epochs += 1
185+
186+
if self.num_bad_epochs >= self.patience:
187+
new_length = trainer.datamodule.n_forward_time_steps * self.factor
188+
189+
if self.max_length is not None and new_length > self.max_length:
190+
if self.verbose:
191+
print(f"{self.__class__.__name__}: maximum length reached, not increasing")
192+
return  # exit function if new length is greater than maximum length
193+
194+
trainer.datamodule.n_forward_time_steps = new_length
195+
self.num_bad_epochs = 0
196+
197+
if self.verbose:
198+
print(
199+
f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
200+
f" at epoch {trainer.current_epoch}"
201+
)

pydentification/training/lightning/prediction.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
7373
predictions = self.unroll_forward(batch, self.teacher_forcing)
7474

7575
loss = self.loss(predictions, y)
76-
self.log("training/train_loss", loss)
76+
self.log("trainer/train_loss", loss)
7777

7878
return loss
7979

@@ -83,17 +83,18 @@ def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tenso
8383
predictions = self.unroll_forward(batch, teacher_forcing=False) # never use teacher forcing during validation
8484

8585
loss = self.loss(predictions, y)
86-
self.log("training/validation_loss", loss)
86+
self.log("trainer/validation_loss", loss)
8787

8888
return loss
8989

9090
def on_train_epoch_end(self):
91-
self.log("training/lr", self.trainer.optimizers[0].param_groups[0]["lr"])
91+
self.log("trainer/lr", self.trainer.optimizers[0].param_groups[0]["lr"])
92+
self.log("trainer/n_forward_time_steps", self.trainer.datamodule.n_forward_time_steps)
9293

9394
def predict_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, dataloader_idx: int = 0) -> Tensor:
9495
"""Requires using batch of training inputs and targets to know the number of time steps to predict"""
9596
return self.unroll_forward(batch, teacher_forcing=False) # never use teacher forcing during prediction
9697

9798
def configure_optimizers(self) -> dict[str, Any]:
98-
config = {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "training/validation_loss"}
99+
config = {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "trainer/validation_loss"}
99100
return {key: value for key, value in config.items() if value is not None} # remove None values

pydentification/training/lightning/simulation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,20 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
4343
x, y = batch
4444
y_hat = self.module(x) # type: ignore
4545
loss = self.loss(y_hat, y)
46-
self.log("training/train_loss", loss)
46+
self.log("trainer/train_loss", loss)
4747

4848
return loss
4949

5050
def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
5151
x, y = batch
5252
y_hat = self.module(x) # type: ignore
5353
loss = self.loss(y_hat, y)
54-
self.log("training/validation_loss", loss)
54+
self.log("trainer/validation_loss", loss)
5555

5656
return loss
5757

5858
def on_train_epoch_end(self):
59-
self.log("training/lr", self.trainer.optimizers[0].param_groups[0]["lr"])
59+
self.log("trainer/lr", self.trainer.optimizers[0].param_groups[0]["lr"])
6060

6161
def predict_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, _: int = 0) -> Tensor:
6262
"""
@@ -67,5 +67,5 @@ def predict_step(self, batch: tuple[Tensor, Tensor], batch_idx: int, _: int = 0)
6767
return self.module(x) # type: ignore
6868

6969
def configure_optimizers(self) -> dict[str, Any]:
70-
config = {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "training/validation_loss"}
70+
config = {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "trainer/validation_loss"}
7171
return {key: value for key, value in config.items() if value is not None} # remove None values

0 commit comments

Comments
 (0)