@@ -128,6 +128,58 @@ def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):
 
     self.assertTrue(new_loss < initial_loss)
 
+
+  @parameterized.product(
+    module_cls=[nnx.Linear, Model],
+    jit_decorator=[lambda f: f, nnx.jit, jax.jit],
+    optimizer=[optax.lbfgs],
+  )
+  def test_jit_linesearch(self, module_cls, jit_decorator, optimizer):
+    x = jax.random.normal(jax.random.key(0), (1, 2))
+    y = jnp.ones((1, 4))
+    model = module_cls(2, 4, rngs=nnx.Rngs(0))
+    tx = optimizer(1e-3)
+    state = nnx.Optimizer(model, tx)
+
+    if jit_decorator == jax.jit:
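+      # jax.jit can't take nnx objects as arguments, so this branch splits
+      # the model into a static graphdef and a pytree state at the boundary.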
+      model_static, model_state = nnx.split(state.model)
+      loss_fn = lambda graphdef, state, x, y: (
+        (nnx.merge(graphdef, state)(x) - y) ** 2
+      ).mean()
+      initial_loss = loss_fn(model_static, model_state, x, y)
+
+      def jax_jit_train_step(graphdef, state, x, y):
+        state = nnx.merge(graphdef, state)
+        model_static, model_state = nnx.split(state.model)
+        grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y)
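+        # optax.lbfgs runs a linesearch, so the update needs the gradient,
+        # the current loss value, and a value_fn to re-evaluate the loss.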
+        state.update(
+          grads,
+          grad=grads,
+          value=initial_loss,
+          value_fn=lambda state: loss_fn(model_static, state, x, y),
+        )
+        return nnx.split(state)
+
+      graphdef, state = jit_decorator(jax_jit_train_step)(
+        *nnx.split(state), x, y
+      )
+      state = nnx.merge(graphdef, state)
+      new_loss = loss_fn(*nnx.split(state.model), x, y)
+
+    else:
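+      # nnx.jit (and the plain, un-jitted case) can take the Optimizer
+      # directly, so no manual split/merge is needed at the call site.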
+      graphdef = nnx.graphdef(model)
+      loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
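+      # value_fn for the linesearch: rebuild the model from a raw State and
+      # evaluate the loss on the same batch.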
+      loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+      initial_loss = loss_fn(state.model, x, y)
+
+      def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):
+        grads = nnx.grad(loss_fn)(optimizer.model, x, y)
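+        # Forward grad/value/value_fn so the lbfgs linesearch can run.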
+        optimizer.update(
+          grads, grad=grads, value=initial_loss, value_fn=loss_fn_split
+        )
+
+      jit_decorator(nnx_jit_train_step)(state, x, y)
+      new_loss = loss_fn(state.model, x, y)
+
+    self.assertTrue(new_loss < initial_loss)
+
   @parameterized.product(
     module_cls=[nnx.Linear, Model],
     optimizer=[optax.sgd, optax.adam],
@@ -203,6 +255,55 @@ def test_wrt_update(self, variable):
         )
       )
 
+  @parameterized.parameters(
+    {'variable': nnx.Param},
+    # {'variable': nnx.LoRAParam},
+    {'variable': (nnx.Param, nnx.LoRAParam)},
+  )
+  def test_wrt_update_linesearch(self, variable):
+    in_features = 4
+    out_features = 10
+    model = nnx.LoRA(
+      in_features=in_features,
+      lora_rank=2,
+      out_features=out_features,
+      base_module=Model(
+        in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0)
+      ),
+      rngs=nnx.Rngs(1),
+    )
+    state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable)
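+    # Snapshot the variables selected by `wrt` and the remaining ones, to
+    # verify below that only the selected ones are updated.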
+    prev_variables, prev_other_variables = nnx.state(model, variable, ...)
+
+    x = jnp.ones((1, 4))
+    y = jnp.ones((1, 10))
+    loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+    grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
+      state.model, x, y
+    )
+    initial_loss = loss_fn(model, x, y)
+    graphdef = nnx.graphdef(model)
+    loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
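+    # As in test_jit_linesearch: lbfgs needs value, grad, and value_fn.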
+    state.update(grads, grad=grads, value_fn=loss_fn_split, value=initial_loss)
+    self.assertTrue(loss_fn(model, x, y) < initial_loss)
+
+    # Make sure only the Variables filtered by `wrt` are changed, and the
+    # others are unchanged.
+    variables, other_variables = nnx.state(model, variable, ...)
+    self.assertTrue(
+      jax.tree.all(
+        jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables)
+      )
+    )
+    if other_variables:
+      self.assertTrue(
+        jax.tree.all(
+          jax.tree.map(
+            lambda x, y: (x == y).all(), prev_other_variables, other_variables
+          )
+        )
+      )
 
 if __name__ == '__main__':
   absltest.main()