
Commit 0bdb91f

Fix relationship between optimizer and LR scheduler.

- Always initialize the LR scheduler before the first step.
- Correctly call LR scheduler only once per iteration. See pytorch/pytorch#20124.

Note: there is still a problem if you set update_interval as a number of Epochs.
1 parent 4260d9e commit 0bdb91f
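The commit references pytorch/pytorch#20124, which concerns the call order between optimizer.step() and lr_scheduler.step(): constructing the scheduler patches optimizer.step (the `patch_track_step_called` hook mentioned in the diff below) so PyTorch can warn if the scheduler steps first. A minimal sketch of the expected order, independent of this repository's Trainer (the model, optimizer, and scheduler below are placeholders):

import torch

# Minimal sketch of the optimizer/scheduler contract this commit enforces:
# construct the scheduler before the training loop, and call scheduler.step()
# only after optimizer.step(), once per iteration.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)  # wraps optimizer.step for order tracking

for _ in range(30):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()   # optimizer first ...
    scheduler.step()   # ... then the LR scheduler, once per iteration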

File tree

2 files changed, 9 insertions(+), 5 deletions(-)

src/refiners/training_utils/clock.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def __init__(
         self,
         training_duration: TimeValue,
         gradient_accumulation: Step,
-        lr_scheduler_interval: TimeValue,
+        lr_scheduler_interval: Iteration | Epoch,
         verbose: bool = True,
     ) -> None:
         self.training_duration = training_duration

src/refiners/training_utils/trainer.py

Lines changed: 8 additions & 4 deletions

@@ -130,6 +130,10 @@ def __init__(self, config: ConfigType) -> None:
         self._load_models()
         self._call_callbacks(event_name="on_init_end")
 
+        # Ensure the lr_scheduler is initialized before calling `step` on the optimizer.
+        # See `patch_track_step_called` in LRScheduler constructor.
+        assert self.lr_scheduler
+
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
@@ -299,10 +303,10 @@ def backward(self) -> None:
             self.optimizer.step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
-        if self.clock.is_due(self.config.lr_scheduler.update_interval):
-            self._call_callbacks(event_name="on_lr_scheduler_step_begin")
-            self.lr_scheduler.step()
-            self._call_callbacks(event_name="on_lr_scheduler_step_end")
+            if self.clock.is_due(self.config.lr_scheduler.update_interval):
+                self._call_callbacks(event_name="on_lr_scheduler_step_begin")
+                self.lr_scheduler.step()
+                self._call_callbacks(event_name="on_lr_scheduler_step_end")
 
     def step(self, batch: Batch) -> None:
         """Perform a single training step."""
