Commit c043815

fix for edge case lr lambda handling surfaced w/ PT 2.10
1 parent acf79c3 commit c043815

2 files changed (+14, -5 lines)

src/finetuning_scheduler/fts_supporters.py

Lines changed: 7 additions & 2 deletions
@@ -1595,8 +1595,13 @@ def add_optimizer_groups(
                 scheduler.min_lrs.extend([scheduler.min_lrs[0]] * added_pgs)  # type: ignore[attr-defined]
             else:
                 scheduler.base_lrs.extend([orig_lr_factor] * added_pgs)
-                if hasattr(scheduler, "lr_lambdas"):
-                    scheduler.lr_lambdas.extend([scheduler.lr_lambdas[-1]] * added_pgs)
+                if hasattr(scheduler, "lr_lambdas") and scheduler.lr_lambdas:
+                    # due to PyTorch lr scheduler state_dict peculiarities wrt lr_lambdas, lr_lambdas may
+                    # already be pg-aligned (since lr_lambdas are only conditionally saved/restored) see:
+                    # https://bit.ly/lr_lambda_state_dict_special_handling
+                    lambdas_to_sync = max(len(scheduler.base_lrs) - len(scheduler.lr_lambdas), 0)
+                    if lambdas_to_sync:
+                        scheduler.lr_lambdas.extend([scheduler.lr_lambdas[-1]] * lambdas_to_sync)
     else:
         _ = ScheduleImplMixin._add_groups(no_decay, optimizer, module, thawed_pl, phase_lr)
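To make the new padding logic concrete, below is a minimal standalone sketch (not FTS code) of the same idea against a plain LambdaLR. As the diff comment notes, PyTorch's LambdaLR only saves/restores lr_lambdas that are callable objects (plain functions and lambdas are skipped in its state_dict), so a scheduler's lr_lambdas list may already be aligned with the parameter groups; the guard therefore extends it only by the difference between len(base_lrs) and len(lr_lambdas). The model, layer sizes, and lr values here are illustrative assumptions.

# standalone sketch of the defensive lr_lambdas padding (illustrative values)
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

# simulate thawing a new phase: add a parameter group and keep base_lrs aligned
new_layer = torch.nn.Linear(2, 2)
optimizer.add_param_group({"params": new_layer.parameters(), "lr": 0.1})
scheduler.base_lrs.extend([0.1])

# pad lr_lambdas only by the number of groups it is actually missing
if hasattr(scheduler, "lr_lambdas") and scheduler.lr_lambdas:
    lambdas_to_sync = max(len(scheduler.base_lrs) - len(scheduler.lr_lambdas), 0)
    if lambdas_to_sync:
        scheduler.lr_lambdas.extend([scheduler.lr_lambdas[-1]] * lambdas_to_sync)

assert len(scheduler.lr_lambdas) == len(scheduler.base_lrs) == len(optimizer.param_groups)

If lr_lambdas were already aligned with the parameter groups (e.g. after a checkpoint restore that preserved them), lambdas_to_sync is 0 and the list is left untouched, which is the edge case the guard addresses.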

src/finetuning_scheduler/strategy_adapters/base.py

Lines changed: 7 additions & 3 deletions
@@ -224,6 +224,8 @@ def _clean_optim_lr_pgs(trainer: Trainer) -> List:
             lrs_cfg.scheduler.last_epoch = -1  # type: ignore[union-attr]
             if not isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
                 lrs_cfg.scheduler.base_lrs = []
+                # if hasattr(lrs_cfg.scheduler, "lr_lambdas"):
+                #     lrs_cfg.scheduler.lr_lambdas = []
         return orig_num_pgs

     def _reconfigure_optimizer_for_phase0(self, trainer: Trainer) -> None:
@@ -250,21 +252,23 @@ def _reconfigure_lrs_for_phase0(self, trainer: Trainer, orig_num_pgs: List) -> N
         Args:
             trainer (Trainer): The :external+pl:class:`~lightning.pytorch.trainer.trainer.Trainer` object.
             orig_num_pgs (List): A list of the number of parameter groups pruned for each optimizer (since only a single
-                optimizer is currently supported by FTS, this list will have only a single element in this verison.)
+                optimizer is currently supported by FTS, this list will have only a single element in this version.)
         """
         # since we may have added parameter groups (e.g. implementing ``no_decay`` for user), we need to reinitialize
         # certain lr_scheduler variables (including type-dependent ones like ``min_lrs`` and ``lr_lambdas``)
         if trainer.lr_scheduler_configs:
             for lrs_cfg in trainer.lr_scheduler_configs:
+                # if hasattr(lrs_cfg.scheduler, "lr_lambdas"):
+                #     lrs_cfg.scheduler.lr_lambdas = lrs_cfg.scheduler.lr_lambdas[orig_num_pgs[0] :]
                 if not isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
                     lrs_cfg.scheduler._initial_step()
                 lrs_cfg.scheduler._last_lr = [  # type: ignore[union-attr]
                     group["lr"] for group in lrs_cfg.scheduler.optimizer.param_groups
                 ]
                 if isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
                     lrs_cfg.scheduler.min_lrs = lrs_cfg.scheduler.min_lrs[orig_num_pgs[0] :]
-                elif hasattr(lrs_cfg.scheduler, "lr_lambdas"):
-                    lrs_cfg.scheduler.lr_lambdas = lrs_cfg.scheduler.lr_lambdas[orig_num_pgs[0] :]
+                # elif hasattr(lrs_cfg.scheduler, "lr_lambdas"):
+                #     lrs_cfg.scheduler.lr_lambdas = lrs_cfg.scheduler.lr_lambdas[orig_num_pgs[0] :]

     def phase0_optimizer_override(self) -> None:
         """Reconfigure the user-configured optimizer (configured via `configure_optimizers`) to optimize the

0 commit comments
