Skip to content

Commit b41470a

Browse files
s-noghabi and The tunix Authors
authored and committed
stream based gradient accumulation
PiperOrigin-RevId: 914055918
1 parent a93fc65 commit b41470a

5 files changed

Lines changed: 120 additions & 26 deletions

File tree

tunix/rl/agentic/agentic_rl_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,9 @@ def train(
768768
self._training_config.max_seq_token_per_tpu,
769769
)
770770
train_data_gen = rl_utils.pack_sequences(
771-
train_data_gen, self._training_config.max_seq_token_per_tpu
771+
train_data_gen,
772+
self._training_config.max_seq_token_per_tpu,
773+
target_queue_items_per_update=grad_acc_steps,
772774
)
773775
micro_batches_since_last_sync = 0
774776
micro_batches_per_full_batch = full_batch_size // train_micro_batch_size

tunix/rl/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class TrainExample:
105105
old_per_token_logps: jax.Array | None
106106
segment_ids: jax.Array | None = None
107107
segment_positions: jax.Array | None = None
108+
is_update_step: jax.Array | None = None
108109

109110

110111
def compute_kl_divergence(

tunix/rl/rl_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,9 @@ def queue_iterator():
724724
self._training_config.max_seq_token_per_tpu,
725725
)
726726
train_data_gen = rl_utils.pack_sequences(
727-
train_data_gen, self._training_config.max_seq_token_per_tpu
727+
train_data_gen,
728+
self._training_config.max_seq_token_per_tpu,
729+
target_queue_items_per_update=grad_acc_steps,
728730
)
729731

730732
curr_eval_ds = None

tunix/rl/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,15 @@ def pack_sequences(
335335
item_iterator: Iterator[list[common.TrainExample]],
336336
max_token_budget: int,
337337
pad_id: int = 0,
338+
target_queue_items_per_update: int | None = None,
338339
) -> Iterator[list[common.TrainExample]]:
339340
"""Packs a stream of TrainExamples into 1D sequences up to a token budget."""
340341
buffer = []
341342
current_tokens = 0
342343
example_cls = common.TrainExample
344+
accumulated_queue_items = 0
343345

344-
def _flush_buffer() -> list[common.TrainExample]:
346+
def _flush_buffer(is_update: bool = False) -> list[common.TrainExample]:
345347
nonlocal buffer, current_tokens
346348
if not buffer:
347349
return []
@@ -429,13 +431,16 @@ def _pad(arr_list, val, length):
429431
if has_policy_version:
430432
kwargs["policy_version"] = buffer[0]["policy_version"]
431433

434+
kwargs["is_update_step"] = jnp.array(is_update, dtype=jnp.bool_)
435+
432436
packed_example = example_cls(**kwargs) # pytype: disable=wrong-keyword-args
433437

434438
buffer.clear()
435439
current_tokens = 0
436440
return [packed_example]
437441

438442
for item_list in item_iterator:
443+
accumulated_queue_items += 1
439444
for example in item_list:
440445
example_cls = type(example)
441446
unpadded_items = unpad_train_example(example)
@@ -453,13 +458,21 @@ def _pad(arr_list, val, length):
453458
continue
454459

455460
if current_tokens + tokens > max_token_budget:
456-
yield _flush_buffer()
461+
# Flush normally. The final batch logic below will trigger is_update=True.
462+
yield _flush_buffer(is_update=False)
457463

458464
buffer.append(item)
459465
current_tokens += tokens
460466

467+
if (
468+
target_queue_items_per_update
469+
and accumulated_queue_items >= target_queue_items_per_update
470+
):
471+
yield _flush_buffer(is_update=True)
472+
accumulated_queue_items = 0
473+
461474
if buffer:
462-
yield _flush_buffer()
475+
yield _flush_buffer(is_update=True)
463476

464477

465478
VERIFY_UPDATE_PARAMS_KEY = "VERIFY_UPDATE_PARAMS_SRC_TO_TGT_MODULE_NAME"

tunix/sft/peft_trainer.py

Lines changed: 97 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,35 @@ def _calculate_global_batch_size(train_example: Any) -> int:
164164
)
165165

166166

167+
class AccGrad(nnx.Variable):
168+
pass
169+
170+
171+
class GradientAccumulator(nnx.Module):
172+
"""Accumulates gradients manually."""
173+
174+
def __init__(self, model: nnx.Module, wrt: Any):
175+
state = nnx.state(model, wrt)
176+
self.grads = jax.tree_util.tree_map(
177+
lambda x: AccGrad(jnp.zeros_like(x)), state
178+
)
179+
180+
def add(self, grads: Any):
181+
def _add(acc, g):
182+
acc.value = acc.value + g
183+
184+
jax.tree_util.tree_map(_add, self.grads, grads)
185+
186+
def get(self):
187+
return jax.tree_util.tree_map(lambda x: x.value, self.grads)
188+
189+
def reset(self):
190+
def _reset(acc):
191+
acc.value = jnp.zeros_like(acc.value)
192+
193+
jax.tree_util.tree_map(_reset, self.grads)
194+
195+
167196
class PeftTrainer:
168197
"""PEFT trainer for LoRA. Only LoRA parameters are updated.
169198
@@ -186,7 +215,7 @@ class PeftTrainer:
186215
data_hooks: The data hooks to use.
187216
"""
188217

189-
supports_sequence_packing = False
218+
supports_sequence_packing = True
190219

191220
def __init__(
192221
self,
@@ -209,14 +238,9 @@ def __init__(
209238
self.model = model
210239
self.config = training_config
211240
self._lora_enabled = utils.is_lora_enabled(self.model)
212-
if training_config.gradient_accumulation_steps is not None:
213-
optimizer = optax.MultiSteps(
214-
optimizer, training_config.gradient_accumulation_steps
215-
)
216-
if self._lora_enabled:
217-
self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.LoRAParam)
218-
else:
219-
self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.Param)
241+
wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param
242+
self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=wrt_target)
243+
self.grad_accumulator = GradientAccumulator(self.model, wrt_target)
220244

221245
self.loss_fn = _default_loss_fn
222246
self.eval_loss_fn = _default_loss_fn
@@ -329,14 +353,21 @@ def with_gen_model_input_fn(
329353
return self
330354

331355
def _train_step(
332-
self, model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any
356+
self,
357+
model: nnx.Module,
358+
optimizer: nnx.Optimizer,
359+
grad_accumulator: GradientAccumulator,
360+
inputs: Any,
361+
is_update_step: jax.Array,
333362
) -> Tuple[ArrayLike, Any | None, ArrayLike]:
334363
"""Main body for one train step.
335364
336365
Args:
337366
model: The model to train.
338367
optimizer: The optimizer to use.
368+
grad_accumulator: The gradient accumulator to use.
339369
inputs: The training input.
370+
is_update_step: Whether to update the model.
340371
341372
Returns:
342373
A tuple containing the loss, auxiliary data (or None if has_aux is False),
@@ -350,8 +381,21 @@ def _train_step(
350381
has_aux=self._has_aux,
351382
)
352383
out, grads = grad_fn(model, **inputs)
353-
grad_norm = optax.global_norm(grads)
354-
optimizer.update(model, grads)
384+
385+
grad_accumulator.add(grads)
386+
387+
def apply_updates():
388+
acc_grads = grad_accumulator.get()
389+
norm = optax.global_norm(acc_grads)
390+
optimizer.update(model, acc_grads)
391+
grad_accumulator.reset()
392+
return norm
393+
394+
def skip_updates():
395+
return jnp.array(0.0, dtype=jnp.float32)
396+
397+
grad_norm = jax.lax.cond(is_update_step, apply_updates, skip_updates)
398+
355399
if self._has_aux:
356400
loss, aux = out
357401
return loss, aux, grad_norm
@@ -397,6 +441,15 @@ def _shard_optimizer(self, mesh: shd.Mesh) -> None:
397441
)
398442
nnx.update(self.optimizer, optimizer_sharded_state)
399443

444+
wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param
445+
model_state = nnx.state(self.model, wrt_target)
446+
model_pspecs = nnx.get_partition_spec(model_state)
447+
accumulator_state = nnx.state(self.grad_accumulator, AccGrad)
448+
accumulator_sharded_state = jax.lax.with_sharding_constraint(
449+
accumulator_state, model_pspecs
450+
)
451+
nnx.update(self.grad_accumulator, accumulator_sharded_state)
452+
400453
def jit_train_and_eval_step(
401454
self, skip_jit: bool = False, cache_nnx_graph: bool = False
402455
):
@@ -419,7 +472,7 @@ def jit_train_and_eval_step(
419472
if self._jitted_train_step_fn is None:
420473
self._shard_optimizer(pxla.thread_resources.env.physical_mesh)
421474
self._jitted_train_step_fn = nnx.jit(
422-
train_step, donate_argnames=("optimizer",)
475+
train_step, donate_argnames=("optimizer", "grad_accumulator")
423476
)
424477
self._jitted_eval_step_fn = nnx.jit(eval_step)
425478

@@ -431,7 +484,10 @@ def maybe_cache_and_partial(f, *args):
431484
return functools.partial(f, *args)
432485

433486
self._jitted_train_step_fn = maybe_cache_and_partial(
434-
self._jitted_train_step_fn, self.model, self.optimizer
487+
self._jitted_train_step_fn,
488+
self.model,
489+
self.optimizer,
490+
self.grad_accumulator,
435491
)
436492
self._jitted_eval_step_fn = maybe_cache_and_partial(
437493
self._jitted_eval_step_fn, self.model
@@ -695,6 +751,28 @@ def train(
695751
perf_constants.MINI_BATCH: mini_batch,
696752
}
697753

754+
self._iter_steps += 1
755+
756+
is_update_step_val = None
757+
if (
758+
isinstance(train_example, dict)
759+
and "is_update_step" in train_example
760+
):
761+
val = train_example["is_update_step"]
762+
if val is not None:
763+
is_update_step_val = bool(np.asarray(val).item())
764+
elif hasattr(train_example, "is_update_step"):
765+
val = train_example.is_update_step
766+
if val is not None:
767+
is_update_step_val = bool(np.asarray(val).item())
768+
769+
if is_update_step_val is None:
770+
is_update_step_val = (
771+
self._iter_steps
772+
% self.config.get_with_default("gradient_accumulation_steps", 1)
773+
== 0
774+
)
775+
698776
with self._perf_tracer.span(
699777
"peft_train_step",
700778
pxla.thread_resources.env.physical_mesh.devices,
@@ -703,7 +781,10 @@ def train(
703781
pxla.thread_resources.env.physical_mesh.devices,
704782
tags=tags,
705783
) as span_v2:
706-
train_loss, aux, grad_norm = train_step(train_example)
784+
train_loss, aux, grad_norm = train_step(
785+
train_example,
786+
is_update_step=jnp.array(is_update_step_val, dtype=jnp.bool_),
787+
)
707788
span.device_end([train_loss])
708789
span_v2.async_end([train_loss])
709790

@@ -716,13 +797,8 @@ def train(
716797
)
717798
# NB: put this after self._buffer_metrics is important
718799
self._post_process_train_step(aux)
719-
self._iter_steps += 1
720800

721-
if (
722-
self._iter_steps
723-
% self.config.get_with_default("gradient_accumulation_steps", 1)
724-
== 0
725-
):
801+
if is_update_step_val:
726802
self._train_steps += 1
727803
self._write_train_metrics()
728804

0 commit comments

Comments (0)