Description & Motivation
Feature Request
Currently, PyTorch Lightning only supports interval: "step" (per optimizer step) and interval: "epoch" (per epoch) for learning rate schedulers. This limitation creates problems when using gradient accumulation with varying accumulation factors during training.
Motivation
Although Lightning's Trainer explicitly supports adaptive accumulate_grad_batches, there is no way to schedule the learning rate or other hyperparameters per batch. This makes runs with different accumulation factors hard to compare, since their schedules should be driven by the tokens/batches seen rather than by the optimizer steps taken. It also complicates investigating training dynamics and batch-size-dependent effects, which requires per-batch rather than per-optimizer-step granularity.
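For reference, a minimal sketch of the setting where this mismatch shows up. With Lightning's GradientAccumulationScheduler callback the accumulation factor changes mid-training, so optimizer steps and batches seen stop being proportional (the scheduling values and optimizer-agnostic details below are arbitrary examples, not part of any real run):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Accumulation factor changes at epoch boundaries: 1 batch per optimizer step
# for epochs 0-3, then 4, then 8 (illustrative values).
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 4, 8: 8})

trainer = pl.Trainer(callbacks=[accumulator])

# With interval="step", the scheduler advances 4x (then 8x) more slowly per
# batch once accumulation kicks in, so two runs that consume the same data
# end up at different points on their schedules.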
Pitch
Proposal: 'batch' interval
Add interval: "batch" as a supported value for the "interval" key of the lr_scheduler configuration returned from configure_optimizers():
def configure_optimizers(self):
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": lr_scheduler,
            "interval": "batch",
            "frequency": 1,
        },
    }

This would then advance the scheduler on every batch instead of on every optimizer step.
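For concreteness, under this proposal a schedule could be expressed directly in batches seen, independent of the accumulation factor. A sketch of what that might look like (the AdamW optimizer, LambdaLR warmup, and 1,000-batch horizon are illustrative choices; interval="batch" does not exist yet and is the feature being requested):

import torch
import pytorch_lightning as pl


class WarmupModule(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
        # Linear warmup over the first 1,000 *batches*; with interval="batch"
        # the warmup length would stay fixed in batches even if the
        # accumulation factor changes during training.
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda batch: min(1.0, (batch + 1) / 1_000)
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "batch",  # proposed, not currently supported
                "frequency": 1,
            },
        }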
Alternatives
Alternative: Inject batch number
The best alternative I have found is to override lr_scheduler_step() in the LightningModule, manually track the batch index on the module, and inject it into the scheduler:
def on_train_batch_start(self, batch, batch_idx):
    self.batch_num = batch_idx

def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(self.batch_num)

This is unintuitive, breaks Lightning's declarative configuration pattern, and requires understanding Lightning's internal stepping mechanism. Lightning should support batch-level stepping as part of automatic optimization instead.
Additional context
Feel free to consult this article for an example of a process where it would have been useful (jump down to the LightningModule Integration section): Article
I would be happy to try implementing this myself if there is interest, but I would like to know the testing requirements and any dependencies first. I have preliminarily identified "TrainEpochLoop" as the class that would need changes.
cc @lantiga