Skip to content

Commit 93f4930

Browse files
[optim] Add PolynomialLrSchedule and InvSqrtDecayLrSchedule
Add two new LR schedule types to Levanter's optim config for the LR schedule sweep experiments described in issue #4082. PolynomialLrSchedule wraps optax.polynomial_schedule with configurable power parameter (power=1 linear, power=2 quadratic, power=0.5 sqrt). InvSqrtDecayLrSchedule implements lr/sqrt(1+c*t/T) with configurable decay constant, providing a schedule that never reaches zero. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c16d601 commit 93f4930

2 files changed

Lines changed: 154 additions & 0 deletions

File tree

lib/levanter/src/levanter/optim/config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,49 @@ def build(self, ctx: LrScheduleContext):
7878
return _inv_decay_schedule(ctx.learning_rate, ctx.min_lr, ctx.decay_steps)
7979

8080

81+
@LrSchedule.register_subclass("polynomial")
@dataclass(frozen=True)
class PolynomialLrSchedule(LrSchedule):
    """Polynomial decay schedule: lr * (1 - t/T)^power, hitting min_lr at step T.

    Thin wrapper around ``optax.polynomial_schedule``. The ``power`` knob picks
    the decay shape: 1.0 is plain linear decay, 2.0 is quadratic (front-loads
    the LR drop), and 0.5 is square-root decay (holds the LR high early and
    falls off steeply near the end).
    """

    # Exponent applied to the remaining fraction of training.
    power: float = 2.0

    def build(self, ctx: LrScheduleContext):
        """Return an optax schedule decaying from ctx.learning_rate to ctx.min_lr over ctx.decay_steps."""
        return optax.polynomial_schedule(
            init_value=ctx.learning_rate,
            end_value=ctx.min_lr,
            transition_steps=ctx.decay_steps,
            power=self.power,
        )
100+
101+
102+
@LrSchedule.register_subclass("inv_sqrt_decay")
@dataclass(frozen=True)
class InvSqrtDecayLrSchedule(LrSchedule):
    """Inverse-sqrt decay from the peak LR: lr / sqrt(1 + c * t / T).

    Unlike InvSqrtLrSchedule (whose timescale is fixed relative to warmup),
    the decay speed here is governed by the tunable constant ``c``
    (``decay_constant``), normalized by the decay horizon T. The schedule is
    strictly positive for every step — it approaches but never reaches zero.
    """

    # Larger values drop the LR faster; at t=T the LR is lr / sqrt(1 + c).
    decay_constant: float = 28.6

    def build(self, ctx: LrScheduleContext):
        """Return a callable mapping a step count to a learning rate."""
        peak = ctx.learning_rate
        c = self.decay_constant
        total = ctx.decay_steps

        def schedule(count):
            # Same expression order as documented: lr / sqrt(1 + c * t / T).
            return peak / jnp.sqrt(1.0 + c * count / total)

        return schedule
122+
123+
81124
@LrSchedule.register_subclass("power")
82125
@dataclass(frozen=True)
83126
class PowerLrSchedule(LrSchedule):

lib/levanter/tests/test_optimizer_config.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,117 @@ def test_wsds_schedule_with_cycle_points():
248248
assert sched_fn(971) < 1e-3
249249

250250

251+
def test_polynomial_schedule_quadratic():
    """PolynomialLrSchedule with power=2 follows a (1 - t/T)^2 decay curve."""
    from levanter.optim.config import PolynomialLrSchedule

    cfg = AdamConfig(
        learning_rate=1e-3,
        weight_decay=0.0,
        warmup=0.1,
        min_lr_ratio=0.0,
        lr_schedule=PolynomialLrSchedule(power=2.0),
    )
    lr = cfg.lr_scheduler(1000)

    # 100-step warmup ramps the LR from 0 up to the 1e-3 peak.
    assert np.isclose(lr(0), 0.0)
    assert np.isclose(lr(100), 1e-3)

    # Halfway through the 900-step decay: 1e-3 * (1 - 450/900)^2 = 0.25e-3.
    assert np.isclose(lr(550), 0.25e-3, atol=1e-6)

    # The decay bottoms out at (essentially) zero by the final step.
    assert np.isclose(lr(999), 0.0, atol=1e-5)
274+
275+
276+
def test_polynomial_schedule_linear():
    """PolynomialLrSchedule with power=1 reduces to ordinary linear decay."""
    from levanter.optim.config import PolynomialLrSchedule

    cfg = AdamConfig(
        learning_rate=1e-3,
        weight_decay=0.0,
        warmup=0.0,
        min_lr_ratio=0.0,
        lr_schedule=PolynomialLrSchedule(power=1.0),
    )
    lr = cfg.lr_scheduler(100)

    # No warmup: starts at the peak, falls linearly to zero at step 100.
    assert np.isclose(lr(0), 1e-3)
    assert np.isclose(lr(50), 0.5e-3, atol=1e-6)
    assert np.isclose(lr(100), 0.0, atol=1e-6)
293+
294+
295+
def test_polynomial_schedule_sqrt():
    """PolynomialLrSchedule with power=0.5 keeps the LR high early and drops it late."""
    from levanter.optim.config import PolynomialLrSchedule

    cfg = AdamConfig(
        learning_rate=1e-3,
        weight_decay=0.0,
        warmup=0.0,
        min_lr_ratio=0.0,
        lr_schedule=PolynomialLrSchedule(power=0.5),
    )
    lr = cfg.lr_scheduler(100)

    # Midpoint: 1e-3 * (1 - 0.5)^0.5 ≈ 0.707e-3 — well above linear's 0.5e-3.
    assert np.isclose(lr(50), 1e-3 * 0.5**0.5, atol=1e-5)

    # Still reaches zero at the end of the decay.
    assert np.isclose(lr(100), 0.0, atol=1e-6)
312+
313+
314+
def test_polynomial_schedule_with_min_lr():
    """PolynomialLrSchedule respects a nonzero floor set via min_lr_ratio."""
    from levanter.optim.config import PolynomialLrSchedule

    cfg = AdamConfig(
        learning_rate=1e-3,
        weight_decay=0.0,
        warmup=0.0,
        min_lr_ratio=0.05,
        lr_schedule=PolynomialLrSchedule(power=2.0),
    )
    lr = cfg.lr_scheduler(100)

    # The decay terminates at min_lr = 0.05 * 1e-3 rather than zero.
    assert np.isclose(lr(100), 0.05e-3, atol=1e-6)
330+
331+
332+
def test_inv_sqrt_decay_lr_schedule():
    """InvSqrtDecayLrSchedule follows lr / sqrt(1 + c * t / T) and never hits zero."""
    from levanter.optim.config import InvSqrtDecayLrSchedule

    cfg = AdamConfig(
        learning_rate=1e-3,
        weight_decay=0.0,
        warmup=0.0,
        min_lr_ratio=0.0,
        lr_schedule=InvSqrtDecayLrSchedule(decay_constant=28.6),
    )
    lr = cfg.lr_scheduler(1000)

    # No warmup, so step 0 sits at the peak: 1e-3 / sqrt(1).
    assert np.isclose(lr(0), 1e-3)

    # The schedule is strictly decreasing across the run.
    assert lr(100) < lr(0)
    assert lr(500) < lr(100)
    assert lr(999) < lr(500)

    # Final step: 1e-3 / sqrt(1 + 28.6) ≈ 1.84e-4 ...
    assert np.isclose(lr(1000), 1e-3 / np.sqrt(1 + 28.6), atol=1e-6)

    # ... which is still strictly positive — this schedule never reaches zero.
    assert lr(1000) > 0
360+
361+
251362
def test_warmup_longer_than_run_does_not_jump():
252363
optimizer = AdamConfig(
253364
learning_rate=3e-3,

0 commit comments

Comments
 (0)