1717from collections .abc import Iterable
1818import contextlib
1919import dataclasses
20+ import functools
2021import time
2122from typing import Any , Callable , Concatenate , Dict , List , Optional , ParamSpec , Tuple
2223
@@ -376,13 +377,16 @@ def _shard_optimizer(self, mesh: shd.Mesh) -> None:
376377 )
377378 nnx .update (self .optimizer , optimizer_sharded_state )
378379
379- def jit_train_and_eval_step (self , skip_jit : bool = False ):
380+ def jit_train_and_eval_step (
381+ self , skip_jit : bool = False , cache_nnx_graph : bool = False
382+ ):
380383 """Creates and returns the train and eval step functions.
381384
382385 This function will return the cached ones if available.
383386
384387 Args:
385388 skip_jit: If True, the train and eval step functions will not be JITed.
389+ cache_nnx_graph: If True, the nnx graph will be cached.
386390
387391 Returns:
388392 A tuple of train and eval step functions.
@@ -391,14 +395,28 @@ def jit_train_and_eval_step(self, skip_jit: bool = False):
391395 eval_step = self .create_eval_step_fn ()
392396 if skip_jit :
393397 return train_step , eval_step
394- else :
395- if self ._jitted_train_step_fn is None :
396- self ._shard_optimizer (pxla .thread_resources .env .physical_mesh )
397- self ._jitted_train_step_fn = nnx .jit (
398- train_step , donate_argnames = ("optimizer" ,)
399- )
400- self ._jitted_eval_step_fn = nnx .jit (eval_step )
401- return self ._jitted_train_step_fn , self ._jitted_eval_step_fn
398+
399+ if self ._jitted_train_step_fn is None :
400+ self ._shard_optimizer (pxla .thread_resources .env .physical_mesh )
401+ self ._jitted_train_step_fn = nnx .jit (
402+ train_step , donate_argnames = ("optimizer" ,)
403+ )
404+ self ._jitted_eval_step_fn = nnx .jit (eval_step )
405+
406+ def maybe_cache_and_partial (f , * args ):
407+ if cache_nnx_graph :
408+ # wrap with partial so we can access jitted_fn in a consistent way.
409+ return functools .partial (nnx .cached_partial (f , * args ))
410+ else :
411+ return functools .partial (f , * args )
412+
413+ self ._jitted_train_step_fn = maybe_cache_and_partial (
414+ self ._jitted_train_step_fn , self .model , self .optimizer
415+ )
416+ self ._jitted_eval_step_fn = maybe_cache_and_partial (
417+ self ._jitted_eval_step_fn , self .model
418+ )
419+ return self ._jitted_train_step_fn , self ._jitted_eval_step_fn
402420
403421 def _prepare_inputs (self , input_data : Any ) -> Any :
404422 """Override this function for additional input preparation."""
@@ -565,41 +583,28 @@ def train(
565583 eval_ds : Iterable [Any ] | None = None ,
566584 skip_jit : bool = False ,
567585 * ,
568- cache_nnx_graph : bool = False ,
586+ cache_nnx_graph : bool = True ,
569587 ) -> None :
570588 """Training loop."""
571589 logging .log_first_n (
572590 logging .INFO ,
573591 f"Training with mesh: { pxla .thread_resources .env .physical_mesh } " ,
574592 1 ,
575593 )
576- train_step , eval_step = self .jit_train_and_eval_step (skip_jit )
594+ train_step , eval_step = self .jit_train_and_eval_step (
595+ skip_jit , cache_nnx_graph
596+ )
577597 if not skip_jit :
578- cache_size = train_step .jitted_fn ._cache_size () # pytype: disable=attribute-error
598+ cache_size = train_step .func . jitted_fn ._cache_size () # pytype: disable=attribute-error
579599 logging .log_if (
580600 logging .INFO ,
581601 f"Compiled train_step cache size: { cache_size } " ,
582602 condition = cache_size not in self ._jit_cache ,
583603 )
584604 self ._jit_cache .add (cache_size )
585605
586- if cache_nnx_graph :
587- # For performance, cache the nnx graph traversals. However, the training
588- # loop must _not_ modify the model or optimizer graph in this case. For
589- # example, the distillation trainer mutates the model graph by adding the
590- # distillation loss.
591- partial_train_step = nnx .cached_partial (
592- train_step , self .model , self .optimizer
593- )
594- partial_eval_step = nnx .cached_partial (eval_step , self .model )
595- else :
596- partial_train_step = lambda inputs : train_step (
597- self .model , self .optimizer , inputs
598- )
599- partial_eval_step = lambda inputs : eval_step (self .model , inputs )
600-
601606 if eval_ds :
602- self ._run_eval (eval_ds , partial_eval_step )
607+ self ._run_eval (eval_ds , eval_step )
603608
604609 if self .config .max_steps is not None and self ._pbar is None :
605610 self ._pbar = progress_bar .ProgressBar (
@@ -679,7 +684,7 @@ def train(
679684 with self ._perf_tracer .span (
680685 "peft_train_step" , pxla .thread_resources .env .physical_mesh .devices
681686 ) as span :
682- train_loss , aux = partial_train_step (train_example )
687+ train_loss , aux = train_step (train_example )
683688 span .device_end ([train_loss ])
684689
685690 current_time = time .perf_counter ()
@@ -718,7 +723,7 @@ def train(
718723 eval_ds
719724 and self ._train_steps % self .config .eval_every_n_steps == 0
720725 ):
721- self ._run_eval (eval_ds , partial_eval_step )
726+ self ._run_eval (eval_ds , eval_step )
722727
723728 self ._prof .maybe_deactivate (self ._iter_steps )
724729
0 commit comments