Skip to content

Commit f2dcd33

Browse files
tianshub and The tunix Authors
authored and committed
speedup trainer for RL
PiperOrigin-RevId: 875877560
1 parent 31720ac commit f2dcd33

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

tunix/sft/peft_trainer.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections.abc import Iterable
1818
import contextlib
1919
import dataclasses
20+
import functools
2021
import time
2122
from typing import Any, Callable, Concatenate, Dict, List, Optional, ParamSpec, Tuple
2223

@@ -376,13 +377,16 @@ def _shard_optimizer(self, mesh: shd.Mesh) -> None:
376377
)
377378
nnx.update(self.optimizer, optimizer_sharded_state)
378379

379-
def jit_train_and_eval_step(self, skip_jit: bool = False):
380+
def jit_train_and_eval_step(
381+
self, skip_jit: bool = False, cache_nnx_graph: bool = False
382+
):
380383
"""Creates and returns the train and eval step functions.
381384
382385
This function will return the cached ones if available.
383386
384387
Args:
385388
skip_jit: If True, the train and eval step functions will not be JITed.
389+
cache_nnx_graph: If True, the nnx graph will be cached.
386390
387391
Returns:
388392
A tuple of train and eval step functions.
@@ -391,14 +395,28 @@ def jit_train_and_eval_step(self, skip_jit: bool = False):
391395
eval_step = self.create_eval_step_fn()
392396
if skip_jit:
393397
return train_step, eval_step
394-
else:
395-
if self._jitted_train_step_fn is None:
396-
self._shard_optimizer(pxla.thread_resources.env.physical_mesh)
397-
self._jitted_train_step_fn = nnx.jit(
398-
train_step, donate_argnames=("optimizer",)
399-
)
400-
self._jitted_eval_step_fn = nnx.jit(eval_step)
401-
return self._jitted_train_step_fn, self._jitted_eval_step_fn
398+
399+
if self._jitted_train_step_fn is None:
400+
self._shard_optimizer(pxla.thread_resources.env.physical_mesh)
401+
self._jitted_train_step_fn = nnx.jit(
402+
train_step, donate_argnames=("optimizer",)
403+
)
404+
self._jitted_eval_step_fn = nnx.jit(eval_step)
405+
406+
def maybe_cache_and_partial(f, *args):
407+
if cache_nnx_graph:
408+
# wrap with partial so we can access jitted_fn in a consistent way.
409+
return functools.partial(nnx.cached_partial(f, *args))
410+
else:
411+
return functools.partial(f, *args)
412+
413+
self._jitted_train_step_fn = maybe_cache_and_partial(
414+
self._jitted_train_step_fn, self.model, self.optimizer
415+
)
416+
self._jitted_eval_step_fn = maybe_cache_and_partial(
417+
self._jitted_eval_step_fn, self.model
418+
)
419+
return self._jitted_train_step_fn, self._jitted_eval_step_fn
402420

403421
def _prepare_inputs(self, input_data: Any) -> Any:
404422
"""Override this function for additional input preparation."""
@@ -565,41 +583,28 @@ def train(
565583
eval_ds: Iterable[Any] | None = None,
566584
skip_jit: bool = False,
567585
*,
568-
cache_nnx_graph: bool = False,
586+
cache_nnx_graph: bool = True,
569587
) -> None:
570588
"""Training loop."""
571589
logging.log_first_n(
572590
logging.INFO,
573591
f"Training with mesh: {pxla.thread_resources.env.physical_mesh}",
574592
1,
575593
)
576-
train_step, eval_step = self.jit_train_and_eval_step(skip_jit)
594+
train_step, eval_step = self.jit_train_and_eval_step(
595+
skip_jit, cache_nnx_graph
596+
)
577597
if not skip_jit:
578-
cache_size = train_step.jitted_fn._cache_size() # pytype: disable=attribute-error
598+
cache_size = train_step.func.jitted_fn._cache_size() # pytype: disable=attribute-error
579599
logging.log_if(
580600
logging.INFO,
581601
f"Compiled train_step cache size: {cache_size}",
582602
condition=cache_size not in self._jit_cache,
583603
)
584604
self._jit_cache.add(cache_size)
585605

586-
if cache_nnx_graph:
587-
# For performance, cache the nnx graph traversals. However, the training
588-
# loop must _not_ modify the model or optimizer graph in this case. For
589-
# example, the distillation trainer mutates the model graph by adding the
590-
# distillation loss.
591-
partial_train_step = nnx.cached_partial(
592-
train_step, self.model, self.optimizer
593-
)
594-
partial_eval_step = nnx.cached_partial(eval_step, self.model)
595-
else:
596-
partial_train_step = lambda inputs: train_step(
597-
self.model, self.optimizer, inputs
598-
)
599-
partial_eval_step = lambda inputs: eval_step(self.model, inputs)
600-
601606
if eval_ds:
602-
self._run_eval(eval_ds, partial_eval_step)
607+
self._run_eval(eval_ds, eval_step)
603608

604609
if self.config.max_steps is not None and self._pbar is None:
605610
self._pbar = progress_bar.ProgressBar(
@@ -679,7 +684,7 @@ def train(
679684
with self._perf_tracer.span(
680685
"peft_train_step", pxla.thread_resources.env.physical_mesh.devices
681686
) as span:
682-
train_loss, aux = partial_train_step(train_example)
687+
train_loss, aux = train_step(train_example)
683688
span.device_end([train_loss])
684689

685690
current_time = time.perf_counter()
@@ -718,7 +723,7 @@ def train(
718723
eval_ds
719724
and self._train_steps % self.config.eval_every_n_steps == 0
720725
):
721-
self._run_eval(eval_ds, partial_eval_step)
726+
self._run_eval(eval_ds, eval_step)
722727

723728
self._prof.maybe_deactivate(self._iter_steps)
724729

0 commit comments

Comments
 (0)