Commit d8b1a92

Author: Flax Authors (committed)

Merge pull request #4351 from jlperla:lbfgs_support

PiperOrigin-RevId: 692253971
2 parents 591cd40 + 342adde

2 files changed: +106 -3 lines

flax/nnx/training/optimizer.py

Lines changed: 5 additions & 3 deletions
@@ -193,7 +193,7 @@ def __init__(
     self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
     self.wrt = wrt

-  def update(self, grads):
+  def update(self, grads, **kwargs):
     """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.

     The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
     gradients are with respect to the same :class:`Variable` types as defined in
@@ -249,14 +249,16 @@ def update(self, grads):

     Args:
       grads: the gradients derived from ``nnx.grad``.
+      **kwargs: additional keyword arguments passed to the tx.update, to support
+        ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``.
     """
     params = nnx.state(self.model, self.wrt)
     opt_state = _opt_state_variables_to_state(self.opt_state)

-    updates, new_opt_state = self.tx.update(grads, opt_state, params)
+    updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs)
     new_params = optax.apply_updates(params, updates)
     assert isinstance(new_params, nnx.State)

     self.step.value += 1
     nnx.update(self.model, new_params)
-    _update_opt_state(self.opt_state, new_opt_state)
+    _update_opt_state(self.opt_state, new_opt_state)
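
For context, a minimal usage sketch of the new ``**kwargs`` path, closely mirroring the tests added below. It is illustrative only (the model, loss, and variable names are placeholders, not part of this commit): ``optax.lbfgs`` is a ``GradientTransformationExtraArgs``, so its linesearch needs ``grad``, ``value`` and ``value_fn``, which ``Optimizer.update`` now forwards to ``tx.update``.

import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 4, rngs=nnx.Rngs(0))       # placeholder model
optimizer = nnx.Optimizer(model, optax.lbfgs())  # uses the new **kwargs path

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jnp.ones((1, 4))

loss_fn = lambda model: ((model(x) - y) ** 2).mean()

graphdef = nnx.graphdef(model)
# The linesearch re-evaluates the loss from a params state, hence value_fn.
value_fn = lambda state: loss_fn(nnx.merge(graphdef, state))

grads = nnx.grad(loss_fn)(optimizer.model)
value = loss_fn(optimizer.model)
# Extra keyword arguments flow straight through to optax.lbfgs.update.
optimizer.update(grads, grad=grads, value=value, value_fn=value_fn)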

tests/nnx/optimizer_test.py

Lines changed: 101 additions & 0 deletions
@@ -128,6 +128,58 @@ def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):

     self.assertTrue(new_loss < initial_loss)

+
+  @parameterized.product(
+    module_cls=[nnx.Linear, Model],
+    jit_decorator=[lambda f: f, nnx.jit, jax.jit],
+    optimizer=[optax.lbfgs],
+  )
+  def test_jit_linesearch(self, module_cls, jit_decorator, optimizer):
+    x = jax.random.normal(jax.random.key(0), (1, 2))
+    y = jnp.ones((1, 4))
+    model = module_cls(2, 4, rngs=nnx.Rngs(0))
+    tx = optimizer(
+      1e-3
+    )
+    state = nnx.Optimizer(model, tx)
+
+    if jit_decorator == jax.jit:
+      model_static, model_state = nnx.split(state.model)
+      loss_fn = lambda graphdef, state, x, y: (
+        (nnx.merge(graphdef, state)(x) - y) ** 2
+      ).mean()
+      initial_loss = loss_fn(model_static, model_state, x, y)
+
+      def jax_jit_train_step(graphdef, state, x, y):
+        state = nnx.merge(graphdef, state)
+        model_static, model_state = nnx.split(state.model)
+        grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y)
+        state.update(grads, grad=grads, value=initial_loss, value_fn=lambda state: loss_fn(model_static, state, x, y))
+        return nnx.split(state)
+
+      graphdef, state = jit_decorator(jax_jit_train_step)(
+        *nnx.split(state), x, y
+      )
+      state = nnx.merge(graphdef, state)
+      new_loss = loss_fn(*nnx.split(state.model), x, y)
+
+    else:
+      graphdef = nnx.graphdef(model)
+      loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+      loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+      initial_loss = loss_fn(state.model, x, y)
+
+      def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):
+        grads = nnx.grad(loss_fn)(optimizer.model, x, y)
+        optimizer.update(grads, grad=grads, value=initial_loss, value_fn=loss_fn_split)
+
+      jit_decorator(nnx_jit_train_step)(state, x, y)
+      new_loss = loss_fn(state.model, x, y)
+
+    self.assertTrue(new_loss < initial_loss)
+
   @parameterized.product(
     module_cls=[nnx.Linear, Model],
     optimizer=[optax.sgd, optax.adam],
@@ -203,6 +255,55 @@ def test_wrt_update(self, variable):
       )
     )

+  @parameterized.parameters(
+    {'variable': nnx.Param},
+    #{'variable': nnx.LoRAParam},
+    {'variable': (nnx.Param, nnx.LoRAParam)},
+  )
+  def test_wrt_update_linesearch(self, variable):
+    in_features = 4
+    out_features = 10
+    model = nnx.LoRA(
+      in_features=in_features,
+      lora_rank=2,
+      out_features=out_features,
+      base_module=Model(
+        in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0)
+      ),
+      rngs=nnx.Rngs(1),
+    )
+    state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable)
+    prev_variables, prev_other_variables = nnx.state(model, variable, ...)
+
+    x = jnp.ones((1, 4))
+    y = jnp.ones((1, 10))
+    loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+    grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
+      state.model, x, y
+    )
+    initial_loss = loss_fn(model, x, y)
+    graphdef = nnx.graphdef(model)
+    loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+    state.update(grads, grad=grads, value_fn=loss_fn_split, value=initial_loss)
+    self.assertTrue(loss_fn(model, x, y) < initial_loss)
+
+    # make sure only the Variables filtered in `wrt` are changed, and the others are unchanged
+    variables, other_variables = nnx.state(model, variable, ...)
+    self.assertTrue(
+      jax.tree.all(
+        jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables)
+      )
+    )
+    if other_variables:
+      self.assertTrue(
+        jax.tree.all(
+          jax.tree.map(
+            lambda x, y: (x == y).all(), prev_other_variables, other_variables
+          )
+        )
+      )

 if __name__ == '__main__':
   absltest.main()
