@@ -164,6 +164,35 @@ def _calculate_global_batch_size(train_example: Any) -> int:
   )


+class AccGrad(nnx.Variable):
+  pass
+
+
+class GradientAccumulator(nnx.Module):
+  """Accumulates gradients manually."""
+
+  def __init__(self, model: nnx.Module, wrt: Any):
+    state = nnx.state(model, wrt)
+    self.grads = jax.tree_util.tree_map(
+        lambda x: AccGrad(jnp.zeros_like(x)), state
+    )
+
+  def add(self, grads: Any):
+    def _add(acc, g):
+      acc.value = acc.value + g
+
+    jax.tree_util.tree_map(_add, self.grads, grads)
+
+  def get(self):
+    return jax.tree_util.tree_map(lambda x: x.value, self.grads)
+
+  def reset(self):
+    def _reset(acc):
+      acc.value = jnp.zeros_like(acc.value)
+
+    jax.tree_util.tree_map(_reset, self.grads)
+
+
 class PeftTrainer:
   """PEFT trainer for LoRA. Only LoRA parameters are updated.

@@ -186,7 +215,7 @@ class PeftTrainer:
     data_hooks: The data hooks to use.
   """

-  supports_sequence_packing = False
+  supports_sequence_packing = True

   def __init__(
       self,
@@ -209,14 +238,9 @@ def __init__(
     self.model = model
     self.config = training_config
     self._lora_enabled = utils.is_lora_enabled(self.model)
-    if training_config.gradient_accumulation_steps is not None:
-      optimizer = optax.MultiSteps(
-          optimizer, training_config.gradient_accumulation_steps
-      )
-    if self._lora_enabled:
-      self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.LoRAParam)
-    else:
-      self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.Param)
+    wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param
+    self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=wrt_target)
+    self.grad_accumulator = GradientAccumulator(self.model, wrt_target)

     self.loss_fn = _default_loss_fn
     self.eval_loss_fn = _default_loss_fn
@@ -329,14 +353,21 @@ def with_gen_model_input_fn(
     return self

   def _train_step(
-      self, model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any
+      self,
+      model: nnx.Module,
+      optimizer: nnx.Optimizer,
+      grad_accumulator: GradientAccumulator,
+      inputs: Any,
+      is_update_step: jax.Array,
   ) -> Tuple[ArrayLike, Any | None, ArrayLike]:
     """Main body for one train step.

     Args:
       model: The model to train.
       optimizer: The optimizer to use.
+      grad_accumulator: The gradient accumulator to use.
       inputs: The training input.
+      is_update_step: Whether to update the model.

     Returns:
       A tuple containing the loss, auxiliary data (or None if has_aux is False),
@@ -350,8 +381,21 @@ def _train_step(
         has_aux=self._has_aux,
     )
     out, grads = grad_fn(model, **inputs)
-    grad_norm = optax.global_norm(grads)
-    optimizer.update(model, grads)
+
+    grad_accumulator.add(grads)
+
+    def apply_updates():
+      acc_grads = grad_accumulator.get()
+      norm = optax.global_norm(acc_grads)
+      optimizer.update(model, acc_grads)
+      grad_accumulator.reset()
+      return norm
+
+    def skip_updates():
+      return jnp.array(0.0, dtype=jnp.float32)
+
+    grad_norm = jax.lax.cond(is_update_step, apply_updates, skip_updates)
+
     if self._has_aux:
       loss, aux = out
       return loss, aux, grad_norm
@@ -397,6 +441,15 @@ def _shard_optimizer(self, mesh: shd.Mesh) -> None:
     )
     nnx.update(self.optimizer, optimizer_sharded_state)

+    wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param
+    model_state = nnx.state(self.model, wrt_target)
+    model_pspecs = nnx.get_partition_spec(model_state)
+    accumulator_state = nnx.state(self.grad_accumulator, AccGrad)
+    accumulator_sharded_state = jax.lax.with_sharding_constraint(
+        accumulator_state, model_pspecs
+    )
+    nnx.update(self.grad_accumulator, accumulator_sharded_state)
+
   def jit_train_and_eval_step(
       self, skip_jit: bool = False, cache_nnx_graph: bool = False
   ):
@@ -419,7 +472,7 @@ def jit_train_and_eval_step(
     if self._jitted_train_step_fn is None:
       self._shard_optimizer(pxla.thread_resources.env.physical_mesh)
       self._jitted_train_step_fn = nnx.jit(
-          train_step, donate_argnames=("optimizer",)
+          train_step, donate_argnames=("optimizer", "grad_accumulator")
       )
       self._jitted_eval_step_fn = nnx.jit(eval_step)

@@ -431,7 +484,10 @@ def maybe_cache_and_partial(f, *args):
       return functools.partial(f, *args)

     self._jitted_train_step_fn = maybe_cache_and_partial(
-        self._jitted_train_step_fn, self.model, self.optimizer
+        self._jitted_train_step_fn,
+        self.model,
+        self.optimizer,
+        self.grad_accumulator,
     )
     self._jitted_eval_step_fn = maybe_cache_and_partial(
         self._jitted_eval_step_fn, self.model
@@ -695,6 +751,28 @@ def train(
           perf_constants.MINI_BATCH: mini_batch,
       }

+      self._iter_steps += 1
+
+      is_update_step_val = None
+      if (
+          isinstance(train_example, dict)
+          and "is_update_step" in train_example
+      ):
+        val = train_example["is_update_step"]
+        if val is not None:
+          is_update_step_val = bool(np.asarray(val).item())
+      elif hasattr(train_example, "is_update_step"):
+        val = train_example.is_update_step
+        if val is not None:
+          is_update_step_val = bool(np.asarray(val).item())
+
+      if is_update_step_val is None:
+        is_update_step_val = (
+            self._iter_steps
+            % self.config.get_with_default("gradient_accumulation_steps", 1)
+            == 0
+        )
+
       with self._perf_tracer.span(
           "peft_train_step",
           pxla.thread_resources.env.physical_mesh.devices,
@@ -703,7 +781,10 @@ def train(
           pxla.thread_resources.env.physical_mesh.devices,
           tags=tags,
       ) as span_v2:
-        train_loss, aux, grad_norm = train_step(train_example)
+        train_loss, aux, grad_norm = train_step(
+            train_example,
+            is_update_step=jnp.array(is_update_step_val, dtype=jnp.bool_),
+        )
         span.device_end([train_loss])
         span_v2.async_end([train_loss])

@@ -716,13 +797,8 @@ def train(
       )
       # NB: put this after self._buffer_metrics is important
       self._post_process_train_step(aux)
-      self._iter_steps += 1

-      if (
-          self._iter_steps
-          % self.config.get_with_default("gradient_accumulation_steps", 1)
-          == 0
-      ):
+      if is_update_step_val:
         self._train_steps += 1
         self._write_train_metrics()

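For context on the change above: the diff replaces `optax.MultiSteps` with an explicit `GradientAccumulator` plus a `jax.lax.cond` that applies the optimizer update (and resets the accumulator) only on steps flagged as update steps. The sketch below is a minimal, functional restatement of that accumulate-then-conditionally-apply pattern on plain pytrees; `sgd_update`, the parameter shapes, and the driving loop are hypothetical stand-ins, not the trainer's actual nnx-based implementation.

```python
import jax
import jax.numpy as jnp
import optax

ACCUM_STEPS = 4  # assumed accumulation window


def sgd_update(params, grads, lr=1e-2):
  # Hypothetical stand-in for optimizer.update in the trainer.
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


@jax.jit
def accum_train_step(params, acc_grads, grads, is_update_step):
  # Always fold the fresh micro-batch gradients into the accumulator.
  acc_grads = jax.tree_util.tree_map(lambda a, g: a + g, acc_grads, grads)

  def apply_updates(operands):
    params, acc = operands
    grad_norm = optax.global_norm(acc)          # norm of the *accumulated* grads
    new_params = sgd_update(params, acc)        # apply one optimizer step
    new_acc = jax.tree_util.tree_map(jnp.zeros_like, acc)  # reset accumulator
    return new_params, new_acc, grad_norm

  def skip_updates(operands):
    params, acc = operands
    return params, acc, jnp.array(0.0, dtype=jnp.float32)

  # A traced boolean picks the branch; both branches return the same structure.
  return jax.lax.cond(is_update_step, apply_updates, skip_updates, (params, acc_grads))


params = {"w": jnp.ones((2, 2))}
acc = jax.tree_util.tree_map(jnp.zeros_like, params)
for step in range(1, 9):
  grads = {"w": jnp.full((2, 2), 0.5)}
  flag = jnp.array(step % ACCUM_STEPS == 0, dtype=jnp.bool_)
  params, acc, grad_norm = accum_train_step(params, acc, grads, flag)
```

Keeping the conditional inside the jitted step (rather than branching in Python) means the step function is traced once and the per-step boolean only selects a branch at run time, which is presumably why the trainer passes `is_update_step` as a `jnp.bool_` array rather than a Python bool.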
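The host-side resolution of `is_update_step_val` in `train()` can also be read in isolation: a flag supplied on the batch (dict key or attribute) takes precedence, otherwise the flag falls back to `iter_steps % gradient_accumulation_steps == 0`. A small, self-contained sketch of that fallback logic, with a hypothetical fixed accumulation value instead of the trainer's config lookup and illustrative batch keys:

```python
import numpy as np

GRAD_ACCUM_STEPS = 4  # assumed stand-in for config.get_with_default("gradient_accumulation_steps", 1)


def resolve_is_update_step(train_example, iter_steps):
  # Batch-provided flag wins; otherwise derive it from the iteration counter.
  flag = None
  if isinstance(train_example, dict) and "is_update_step" in train_example:
    val = train_example["is_update_step"]
    if val is not None:
      flag = bool(np.asarray(val).item())
  elif hasattr(train_example, "is_update_step"):
    val = train_example.is_update_step
    if val is not None:
      flag = bool(np.asarray(val).item())
  if flag is None:
    flag = iter_steps % GRAD_ACCUM_STEPS == 0
  return flag


assert resolve_is_update_step({"is_update_step": np.array(True)}, iter_steps=1)
assert not resolve_is_update_step({"input_tokens": np.zeros(4)}, iter_steps=3)
assert resolve_is_update_step({"input_tokens": np.zeros(4)}, iter_steps=4)
```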