diff --git a/tests/sft/peft_trainer_test.py b/tests/sft/peft_trainer_test.py index 831cd101b..99992edcb 100644 --- a/tests/sft/peft_trainer_test.py +++ b/tests/sft/peft_trainer_test.py @@ -42,6 +42,8 @@ # CPU environment setup to simulate multi device env. os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' +# Set Precision to highest for numeric stability across different hardware. +jax.config.update('jax_default_matmul_precision', 'highest') def create_sharded_model(model_ctor, rngs, mesh): @nnx.jit(static_argnums=(0,)) @@ -108,10 +110,14 @@ def setUp(self): def test_compile_once(self): class CountCompiledTimesTrainer(peft_trainer.PeftTrainer): - def _train_step(self, model, optimizer, inputs): + def _train_step( + self, model, optimizer, grad_accumulator, inputs, is_update_step + ): global global_counter global_counter += 1 - return super()._train_step(model, optimizer, inputs) + return super()._train_step( + model, optimizer, grad_accumulator, inputs, is_update_step + ) config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100) rngs = nnx.Rngs(0) @@ -653,5 +659,514 @@ def test_injected_params(self): ) +def _unwrap(state): + """Unwrap a `State` of `Variable` leaves to raw arrays for numeric checks.""" + return jax.tree_util.tree_map( + lambda v: v[...] if isinstance(v, nnx.Variable) else v, + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + +class GradientAccumulatorTest(parameterized.TestCase): + """Unit tests for the GradientAccumulator module. + + Covers the unified `add(grads, denom=None)` contract: + + * default (`denom=None`): each call contributes 1.0 to the denominator, + so `get()` returns the mean of the per-micro-step gradients — the + `optax.MultiSteps` semantics expected by callers using a per-batch + scalar (mean) loss. + * explicit `denom`: caller supplies the unreduced-loss denominator + (e.g. token count). `get()` returns `Σg / Σd`, which is the gradient + of a single step on the concatenated batch — required when + micro-batches have varying effective batch sizes (sequence packing). + """ + + def _make_accumulator(self): + rngs = nnx.Rngs(0) + model = nnx.Linear(in_features=4, out_features=2, rngs=rngs) + return model, peft_trainer.GradientAccumulator(model, nnx.Param) + + def _ones_like_params(self, model, scale: float = 1.0): + # `nnx.state(model, nnx.Param)` returns a `State` whose leaves are + # `Param` Variables on flax >= 0.12; `tree_map` (no `is_leaf`) + # descends past them and yields the inner arrays. + return jax.tree_util.tree_map( + lambda x: jnp.asarray(scale, dtype=x.dtype) * jnp.ones_like(x), + nnx.state(model, nnx.Param), + ) + + def test_default_mode_averages_grads(self): + """Default add() returns the mean of micro-step grads. + + Matches ``optax.MultiSteps`` semantics: K micro-steps of size B/K are + equivalent to a single step on a batch of size B when the loss + function returns a per-batch scalar (mean) value. ``get()`` returns + ``(Σ_i grads_i) / max(Σ_i 1, 1)``; here K=2 and the per-step grads + have scale 1.0 and 2.0, so the mean is 1.5. 
+ """ + model, acc = self._make_accumulator() + acc.add(self._ones_like_params(model, scale=1.0)) + acc.add(self._ones_like_params(model, scale=2.0)) + out = _unwrap(acc.get()) + jax.tree_util.tree_map( + lambda v: np.testing.assert_allclose(v, 1.5 * jnp.ones_like(v)), + out, + ) + + @parameterized.named_parameters( + dict(testcase_name='equal_denoms', denoms=(4.0, 4.0, 4.0, 4.0)), + dict(testcase_name='varying_denoms', denoms=(1.0, 7.0, 3.0, 5.0)), + dict(testcase_name='extreme_variance', denoms=(1.0, 1.0, 100.0, 1.0)), + ) + def test_explicit_denom_matches_single_step_baseline(self, denoms): + """Passing explicit denom matches the equivalent single-step batch. + + Setup: K micro-batches with denominator d_i and unreduced-sum + gradient g_i. The accumulator computes ``Σ_i g_i / Σ_i d_i``, which + is ``grad(Σ_i loss_unreduced_i) / Σ_i d_i`` — i.e., a single step on + the concatenated batch — for any choice of d_i. The "pre-scale grads + by 1/d_i then mean over K" pattern fails this equality when d_i are + unequal; this test guards against that regression. + """ + model, acc = self._make_accumulator() + + keys = jax.random.split(jax.random.PRNGKey(0), len(denoms)) + grads = [ + jax.tree_util.tree_map( + lambda x, k=k: jax.random.normal(k, x.shape, dtype=x.dtype), + nnx.state(model, nnx.Param), + ) + for k in keys + ] + + for g_i, d_i in zip(grads, denoms): + acc.add(g_i, denom=jnp.asarray(d_i, dtype=jnp.float32)) + accumulated = _unwrap(acc.get()) + + total_denom = sum(denoms) + expected = jax.tree_util.tree_map(lambda *gs: sum(gs) / total_denom, *grads) + jax.tree_util.tree_map( + lambda a, e: np.testing.assert_allclose(a, e, rtol=1e-6, atol=1e-6), + accumulated, + expected, + ) + + if len(set(denoms)) > 1: + naive_mean = jax.tree_util.tree_map( + lambda *gs: sum(g / d for g, d in zip(gs, denoms)) / len(gs), + *grads, + ) + diff_tree = jax.tree_util.tree_map( + lambda a, b: jnp.max(jnp.abs(a - b)), accumulated, naive_mean + ) + max_naive_diff = jax.tree_util.tree_reduce( + jnp.maximum, diff_tree, initializer=jnp.asarray(0.0, jnp.float32) + ) + self.assertGreater( + float(max_naive_diff), + 1e-3, + msg=( + 'naive pre-scale-then-mean and Sigma g / Sigma d should ' + 'disagree when denominators vary; if they agree the test setup ' + 'is degenerate.' + ), + ) + + def test_reset_clears_denom(self): + model, acc = self._make_accumulator() + acc.add(self._ones_like_params(model), denom=jnp.asarray(7.0, jnp.float32)) + acc.reset() + self.assertEqual(float(acc.denom[...]), 0.0) + jax.tree_util.tree_map( + lambda v: np.testing.assert_array_equal(v[...], jnp.zeros_like(v[...])), + acc.grads, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + # --------------------------------------------------------------------- + # End-to-end numerical equivalence tests against `nnx.value_and_grad`. + # + # The tests above exercise the accumulator with hand-rolled arrays; the + # tests below thread the *real* differentiation path (`nnx.value_and_grad` + # on a small model) so the assertions hold for the exact pytree shape / + # Variable wrappers the production `_train_step` produces. 
+ # --------------------------------------------------------------------- + + def _make_model_and_data(self, total_examples: int, seed: int = 42): + rngs = nnx.Rngs(seed) + model = nnx.Linear(in_features=4, out_features=2, rngs=rngs) + keys = jax.random.split(jax.random.PRNGKey(seed), 2) + x = jax.random.normal(keys[0], (total_examples, 4)) + y = jax.random.normal(keys[1], (total_examples, 2)) + return model, x, y + + @staticmethod + def _loss_mean(model, x, y): + # Mean over the batch / sequence axis only (sum over feature axis) + # so `sum_loss == batch_size * mean_loss`. The full-tensor `jnp.mean` + # would divide by `batch_size * feature_dim`, which would only match + # the denom-aware path if `denom` were passed as `size * feature_dim` + # — pinning the contract to a model-architecture quirk we don't want + # the test to rely on. + per_example = jnp.sum((model(x) - y) ** 2, axis=-1) + return jnp.mean(per_example) + + @staticmethod + def _loss_sum(model, x, y): + # Matches the reduction of `_loss_mean` modulo division by batch size: + # sum over both batch and feature axes. + return jnp.sum((model(x) - y) ** 2) + + @parameterized.named_parameters( + dict(testcase_name='K1', K=1), + dict(testcase_name='K2', K=2), + dict(testcase_name='K4', K=4), + dict(testcase_name='K8', K=8), + ) + def test_default_mode_K_micro_batches_match_full_batch(self, K): + """Default mode: K equal-size micro-batches ≡ one full batch. + + Mean-of-means equals mean-over-all when the micro-batches partition + the full batch into equal-size chunks. This is the + `optax.MultiSteps`-equivalent contract the unpacked grad-accumulation + path relies on. + """ + B = 16 + self.assertEqual(B % K, 0) + micro = B // K + model, x, y = self._make_model_and_data(B) + + grad_fn = nnx.value_and_grad(self._loss_mean) + _, expected = grad_fn(model, x, y) + + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + for i in range(K): + _, g = grad_fn( + model, x[i * micro : (i + 1) * micro], y[i * micro : (i + 1) * micro] + ) + acc.add(g) + + accumulated = _unwrap(acc.get()) + expected_arrays = _unwrap(expected) + jax.tree_util.tree_map( + lambda a, e: np.testing.assert_allclose(a, e, rtol=1e-6, atol=1e-6), + accumulated, + expected_arrays, + ) + + def test_default_mode_K_micro_batches_match_concatenated_baseline_under_jit( + self, + ): + """Same as above but with the accumulator's mutations under `nnx.jit`. + + The unpacked `_train_step` calls `acc.add()` from inside a jit; this + test exercises the same trace path so any nnx.Variable / pytree + breakage in jitted mutation surfaces here (rather than only at the + full trainer integration level). 
+ """ + B = 12 + K = 3 + micro = B // K + model, x, y = self._make_model_and_data(B, seed=7) + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + + @nnx.jit + def _add_step(model, acc, x_b, y_b): + _, g = nnx.value_and_grad(self._loss_mean)(model, x_b, y_b) + acc.add(g) + + for i in range(K): + _add_step( + model, + acc, + x[i * micro : (i + 1) * micro], + y[i * micro : (i + 1) * micro], + ) + + accumulated = _unwrap(acc.get()) + _, expected = nnx.value_and_grad(self._loss_mean)(model, x, y) + expected_arrays = _unwrap(expected) + jax.tree_util.tree_map( + lambda a, e: np.testing.assert_allclose(a, e, rtol=1e-6, atol=1e-6), + accumulated, + expected_arrays, + ) + + @parameterized.named_parameters( + dict(testcase_name='small_pack', sizes=(3, 5, 1, 7)), + dict(testcase_name='single_dominant_pack', sizes=(1, 1, 28, 2)), + dict(testcase_name='single_pack', sizes=(8,)), + dict(testcase_name='many_small_packs', sizes=(1, 1, 1, 1, 1, 1, 1, 1)), + ) + def test_explicit_denom_packed_micro_batches_match_full_batch(self, sizes): + """Sequence packing: varying-size micro-batches with denom=size. + + Under sequence packing each yielded micro-batch carries a different + number of training examples (varying pack sizes). The denom-aware + path computes Σ_i grad(sum_loss_i) / Σ_i size_i, which is the + gradient of mean(loss_over_all_examples) for *any* partition. Tests + span uniform, dominantly-one-pack, single-pack, and + many-small-packs partitions to catch regressions where the divisor + drifts off-by-one. + """ + total = sum(sizes) + model, x, y = self._make_model_and_data(total, seed=13) + + _, expected = nnx.value_and_grad(self._loss_mean)(model, x, y) + + grad_sum = nnx.value_and_grad(self._loss_sum) + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + start = 0 + for size in sizes: + end = start + size + _, g = grad_sum(model, x[start:end], y[start:end]) + acc.add(g, denom=jnp.asarray(float(size))) + start = end + + accumulated = _unwrap(acc.get()) + expected_arrays = _unwrap(expected) + jax.tree_util.tree_map( + lambda a, e: np.testing.assert_allclose(a, e, rtol=1e-6, atol=1e-6), + accumulated, + expected_arrays, + ) + + def test_explicit_denom_packed_matches_unpacked_concatenation_under_jit(self): + """Packed + denom-aware path under `nnx.jit`, against unpacked baseline. + + Mirrors the production sequence-packing flow: each "pack" is a + micro-batch of varying size, fed through a jitted grad-sum step and + accumulated with `denom=size`. The expected value is computed *on + the same model* via the mean-loss path, so any mismatch isolates the + accumulation math (not data setup). 
+ """ + sizes = (2, 4, 1, 3, 6) + total = sum(sizes) + model, x, y = self._make_model_and_data(total, seed=21) + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + + @nnx.jit + def _packed_add(model, acc, x_b, y_b, denom): + _, g = nnx.value_and_grad(self._loss_sum)(model, x_b, y_b) + acc.add(g, denom=denom) + + start = 0 + for size in sizes: + end = start + size + _packed_add( + model, + acc, + x[start:end], + y[start:end], + jnp.asarray(float(size), jnp.float32), + ) + start = end + + accumulated = _unwrap(acc.get()) + _, expected = nnx.value_and_grad(self._loss_mean)(model, x, y) + expected_arrays = _unwrap(expected) + jax.tree_util.tree_map( + lambda a, e: np.testing.assert_allclose(a, e, rtol=1e-6, atol=1e-6), + accumulated, + expected_arrays, + ) + + def test_default_and_explicit_denom_agree_when_micro_batches_uniform(self): + """Sanity bridge: explicit denom with uniform sizes ≡ default mode. + + When every micro-batch has the same size, the default (mean-of-means) + path and the denom-aware (sum-of-sums / sum-of-sizes) path must + produce the same gradient. This sanity-checks that the unification + of `count` and `denom` into a single field hasn't introduced a + silent off-by-N (e.g. summing K vs K+1 in one of the branches). + """ + sizes = (4, 4, 4, 4) + total = sum(sizes) + model, x, y = self._make_model_and_data(total, seed=99) + + # Default (mean) path. + acc_default = peft_trainer.GradientAccumulator(model, nnx.Param) + grad_mean = nnx.value_and_grad(self._loss_mean) + for i, size in enumerate(sizes): + s, e = i * size, (i + 1) * size + _, g = grad_mean(model, x[s:e], y[s:e]) + acc_default.add(g) + default_out = _unwrap(acc_default.get()) + + # Explicit-denom path with uniform sizes. + acc_denom = peft_trainer.GradientAccumulator(model, nnx.Param) + grad_sum = nnx.value_and_grad(self._loss_sum) + start = 0 + for size in sizes: + end = start + size + _, g = grad_sum(model, x[start:end], y[start:end]) + acc_denom.add(g, denom=jnp.asarray(float(size))) + start = end + denom_out = _unwrap(acc_denom.get()) + + jax.tree_util.tree_map( + lambda a, b: np.testing.assert_allclose(a, b, rtol=1e-6, atol=1e-6), + default_out, + denom_out, + ) + + def test_reset_then_reuse_does_not_leak_state(self): + """After `reset()`, a second accumulation cycle must match a fresh acc. + + Guards against state leaking across reset boundaries — e.g. the + denom counter not zeroing, or `grads` keeping a residual that would + silently bias subsequent updates. + """ + sizes = (4, 4) + total = sum(sizes) + model, x, y = self._make_model_and_data(total, seed=33) + grad_mean = nnx.value_and_grad(self._loss_mean) + + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + # First cycle on unrelated data — must be erased by reset. + junk_x = jax.random.normal(jax.random.PRNGKey(101), (8, 4)) + junk_y = jax.random.normal(jax.random.PRNGKey(102), (8, 2)) + for i in range(2): + _, g = grad_mean( + model, junk_x[i * 4 : (i + 1) * 4], junk_y[i * 4 : (i + 1) * 4] + ) + acc.add(g) + acc.reset() + + # Second cycle on the real data after reset. + for i, size in enumerate(sizes): + s, e = i * size, (i + 1) * size + _, g = grad_mean(model, x[s:e], y[s:e]) + acc.add(g) + after_reset = _unwrap(acc.get()) + + # Reference: fresh accumulator on the same real data. 
+ acc_fresh = peft_trainer.GradientAccumulator(model, nnx.Param) + for i, size in enumerate(sizes): + s, e = i * size, (i + 1) * size + _, g = grad_mean(model, x[s:e], y[s:e]) + acc_fresh.add(g) + fresh = _unwrap(acc_fresh.get()) + + jax.tree_util.tree_map( + lambda a, b: np.testing.assert_allclose(a, b, rtol=1e-6, atol=1e-6), + after_reset, + fresh, + ) + + @parameterized.named_parameters( + dict(testcase_name='bfloat16', dtype=jnp.bfloat16), + dict(testcase_name='float16', dtype=jnp.float16), + dict(testcase_name='float32', dtype=jnp.float32), + ) + def test_get_preserves_grad_dtype(self, dtype: jnp.dtype): + rngs = nnx.Rngs(0) + model = nnx.Linear( + in_features=4, out_features=2, rngs=rngs, param_dtype=dtype + ) + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + + grads = jax.tree_util.tree_map( + lambda v: type(v)(jnp.ones_like(v[...])), + nnx.state(model, nnx.Param), + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + acc.add(grads, denom=jnp.asarray(3.0, dtype=jnp.float32)) + out = acc.get() + + jax.tree_util.tree_map( + lambda v: self.assertEqual(v[...].dtype, dtype), + out, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + def test_cond_apply_vs_skip_branches_have_matching_dtypes_in_bf16(self): + rngs = nnx.Rngs(0) + model = nnx.Linear( + in_features=4, out_features=2, rngs=rngs, param_dtype=jnp.bfloat16 + ) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + acc = peft_trainer.GradientAccumulator(model, nnx.Param) + + x = jnp.ones((2, 4), dtype=jnp.bfloat16) + y = jnp.ones((2, 2), dtype=jnp.bfloat16) + _, grads = nnx.value_and_grad(lambda m, x, y: jnp.sum((m(x) - y) ** 2))( + model, x, y + ) + acc.add(grads, denom=jnp.asarray(1.0, dtype=jnp.float32)) + + def apply_updates(model, optimizer, acc): + acc_grads = acc.get() + optimizer.update(model, acc_grads) + acc.reset() + return jnp.asarray(0.0, dtype=jnp.float32) + + def skip_updates(model, optimizer, acc): + return jnp.asarray(0.0, dtype=jnp.float32) + + @nnx.jit + def step(model, optimizer, acc, is_update_step): + return nnx.cond( + is_update_step, apply_updates, skip_updates, model, optimizer, acc + ) + + step(model, optimizer, acc, jnp.asarray(False)) + step(model, optimizer, acc, jnp.asarray(True)) + + opt_state_dtypes = jax.tree_util.tree_leaves( + jax.tree_util.tree_map( + lambda v: v[...].dtype, + nnx.state(optimizer, nnx.optimizer.OptState), + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + ) + float_dtypes = [ + d for d in opt_state_dtypes if jnp.issubdtype(d, jnp.floating) + ] + self.assertNotEmpty(float_dtypes) + for d in float_dtypes: + self.assertEqual(d, jnp.bfloat16) + + def test_peft_trainer_promotes_bf16_opt_state_floats_to_float32(self): + """`PeftTrainer.__init__` casts float opt_state leaves to float32. + + `optax.adam` / `optax.adamw` promote their floating-point moments + (`mu`, `nu`) to float32 inside `update` whenever the learning rate is + a float32 tracer (as produced by `optax.inject_hyperparams`). This test + verifies that the trainer casts these to float32 in-place during init. 
+ """ + rngs = nnx.Rngs(0) + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=rngs) + bf16_state = jax.tree.map( + lambda x: x.astype(jnp.bfloat16) + if jnp.issubdtype(x.dtype, jnp.floating) + else x, + nnx.state(model, nnx.Param), + ) + nnx.update(model, bf16_state) + + tx = optax.inject_hyperparams(optax.adamw, hyperparam_dtype=jnp.float32)( + learning_rate=1e-3 + ) + config = peft_trainer.TrainingConfig(eval_every_n_steps=100, max_steps=1) + trainer = peft_trainer.PeftTrainer(model, tx, config) + + opt_state_dtypes = jax.tree_util.tree_leaves( + jax.tree_util.tree_map( + lambda v: v[...].dtype, + nnx.state(trainer.optimizer, nnx.optimizer.OptState), + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + ) + float_dtypes = [ + d for d in opt_state_dtypes if jnp.issubdtype(d, jnp.floating) + ] + self.assertNotEmpty(float_dtypes) + for d in float_dtypes: + self.assertEqual(d, jnp.float32) + + if __name__ == '__main__': absltest.main() diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py index b9bc3f65e..05b1bdddd 100644 --- a/tunix/rl/agentic/agentic_rl_learner.py +++ b/tunix/rl/agentic/agentic_rl_learner.py @@ -269,6 +269,7 @@ def _validate_rollout_config(self): "True for AgenticRLLearner if using vLLM engine. Please set this " "before initializing RLCluster." ) + def _compute_rewards( self, prompts: List[str], @@ -764,16 +765,20 @@ def train( train_data_gen = self._data_consumer_batch_generator( train_data_queue, train_micro_batch_size ) - if self._training_config.max_seq_token_per_tpu is not None: + is_packed = self._training_config.max_seq_token_per_tpu is not None + if is_packed: logging.info( "Using sequence packing with max_seq_token_per_tpu: %d", self._training_config.max_seq_token_per_tpu, ) train_data_gen = rl_utils.pack_sequences( - train_data_gen, self._training_config.max_seq_token_per_tpu + train_data_gen, + self._training_config.max_seq_token_per_tpu, + target_items_per_update=grad_acc_steps, ) - micro_batches_since_last_sync = 0 - micro_batches_per_full_batch = full_batch_size // train_micro_batch_size + update_steps_since_last_sync = 0 + update_steps_per_full_batch = full_batch_size // mini_batch_size + unpacked_micro_step_counter = 0 for train_micro_batch in train_data_gen: if ( self._training_config.max_steps @@ -807,6 +812,7 @@ def train( # GRPO returns a list with a single TrainExample. merged_train_micro_batch = train_examples[0] else: + # TODO(b/491970038): handle seq packing case differently merged_train_micro_batch = jax.tree.map( lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch ) @@ -860,8 +866,20 @@ async def _eval_runner_async(current_eval_orchestrator): ) # --- Weight Sync Logic --- - micro_batches_since_last_sync += 1 - if micro_batches_since_last_sync == micro_batches_per_full_batch: + if is_packed: + # `merged_train_micro_batch.is_update_step` is a 0-d jax scalar set + # by `pack_sequences`; pull the host-side value before deciding. + is_update = bool( + np.asarray(merged_train_micro_batch.is_update_step).item() + ) + else: + # Mirror `peft_trainer._train_step`'s derivation: + # `is_update_step` flips True every `grad_acc_steps` micro-batches. 
+ unpacked_micro_step_counter += 1 + is_update = unpacked_micro_step_counter % grad_acc_steps == 0 + if is_update: + update_steps_since_last_sync += 1 + if update_steps_since_last_sync == update_steps_per_full_batch: global_step_time = time.time() - self._global_step_start_time logging.info( f"Global step {self.rl_cluster.global_steps} completed in" @@ -922,7 +940,7 @@ async def _eval_runner_async(current_eval_orchestrator): self.rl_cluster.perf_v2.export(), mode=rl_cluster_lib.Mode.TRAIN, ) - micro_batches_since_last_sync = 0 + update_steps_since_last_sync = 0 self._global_step_start_time = time.time() _ = producer_future.result() diff --git a/tunix/rl/common.py b/tunix/rl/common.py index a8d507848..f6279d5cc 100644 --- a/tunix/rl/common.py +++ b/tunix/rl/common.py @@ -105,6 +105,7 @@ class TrainExample: old_per_token_logps: jax.Array | None segment_ids: jax.Array | None = None segment_positions: jax.Array | None = None + is_update_step: jax.Array | None = None def compute_kl_divergence( diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index 0c0bfc7f8..299705642 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -724,7 +724,9 @@ def queue_iterator(): self._training_config.max_seq_token_per_tpu, ) train_data_gen = rl_utils.pack_sequences( - train_data_gen, self._training_config.max_seq_token_per_tpu + train_data_gen, + self._training_config.max_seq_token_per_tpu, + target_items_per_update=grad_acc_steps, ) curr_eval_ds = None diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py index 66035fbb6..8c4da7860 100644 --- a/tunix/rl/utils.py +++ b/tunix/rl/utils.py @@ -335,13 +335,16 @@ def pack_sequences( item_iterator: Iterator[list[common.TrainExample]], max_token_budget: int, pad_id: int = 0, + target_items_per_update: int | None = None, ) -> Iterator[list[common.TrainExample]]: """Packs a stream of TrainExamples into 1D sequences up to a token budget.""" buffer = [] current_tokens = 0 example_cls = common.TrainExample + accumulated_items = 0 - def _flush_buffer() -> list[common.TrainExample]: + def _flush_buffer(is_update: bool = False) -> list[common.TrainExample]: + """Flushes the buffer into a list of TrainExamples.""" nonlocal buffer, current_tokens if not buffer: return [] @@ -429,6 +432,8 @@ def _pad(arr_list, val, length): if has_policy_version: kwargs["policy_version"] = buffer[0]["policy_version"] + kwargs["is_update_step"] = jnp.array(is_update, dtype=jnp.bool_) + packed_example = example_cls(**kwargs) # pytype: disable=wrong-keyword-args buffer.clear() @@ -436,6 +441,7 @@ def _pad(arr_list, val, length): return [packed_example] for item_list in item_iterator: + accumulated_items += 1 for example in item_list: example_cls = type(example) unpadded_items = unpad_train_example(example) @@ -453,13 +459,18 @@ def _pad(arr_list, val, length): continue if current_tokens + tokens > max_token_budget: - yield _flush_buffer() + # Flush normally. The final batch logic below will trigger is_update=True. 
+ yield _flush_buffer(is_update=False) buffer.append(item) current_tokens += tokens + if target_items_per_update and accumulated_items >= target_items_per_update: + yield _flush_buffer(is_update=True) + accumulated_items = 0 + if buffer: - yield _flush_buffer() + yield _flush_buffer(is_update=True) VERIFY_UPDATE_PARAMS_KEY = "VERIFY_UPDATE_PARAMS_SRC_TO_TGT_MODULE_NAME" diff --git a/tunix/sft/peft_trainer.py b/tunix/sft/peft_trainer.py index 501e5739c..112f1ac76 100644 --- a/tunix/sft/peft_trainer.py +++ b/tunix/sft/peft_trainer.py @@ -164,6 +164,98 @@ def _calculate_global_batch_size(train_example: Any) -> int: ) +class GradientAccumulator(nnx.Module): + """Accumulates gradients over multiple micro-steps. + + Unifies standard (unweighted) micro-batch averaging with sequence packing + (weighted, denom-aware) accumulation. + + Averaging behavior (optax.MultiSteps semantics): + When `add(grads)` is called without a denom, each micro-step implicitly + adds 1.0 to the denominator. `get()` computes `Σ_grads / Σ_1`, which + is the exact mean of the micro-step gradients. This is mathematically + equivalent to a single optimization step on a batch of size `B = + micro_batch_size * grad_acc_steps` when the loss is a mean-reduction + (e.g., standard cross-entropy). + + Packing-aware behavior (Sum of Grads / Sum of Sizes): + Under sequence packing, each yielded micro-batch contains a varying + number of valid target tokens or training examples. The loss is + computed as an *unreduced sum* over the packed batch. Callers pass the + true size of the pack via `add(grads, denom=size)`. `get()` computes + `Σ_grad(sum_loss_i) / Σ_size_i`, recovering the true global mean + gradient across all items in the accumulated batch, avoiding the bias + introduced by averaging pre-scaled micro-batch gradients of unequal + sizes. + """ + + def __init__(self, model: nnx.Module, wrt: type[nnx.Variable]): + state = nnx.state(model, wrt) + self.grads = nnx.data(jax.tree_util.tree_map(jnp.zeros_like, state)) + self.denom = nnx.Variable(jnp.zeros((), dtype=jnp.float32)) + + def add(self, grads: Any, denom: jax.Array | None = None): + def _add(acc_var, g_var): + g = g_var[...] if isinstance(g_var, nnx.Variable) else g_var + acc_var[...] = acc_var[...] + g + + jax.tree_util.tree_map( + _add, + self.grads, + grads, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + if denom is None: + denom_val = jnp.asarray(1.0, dtype=jnp.float32) + else: + denom_val = denom.astype(jnp.float32) + self.denom[...] = self.denom[...] + denom_val + + def get(self): + scale = 1.0 / jnp.maximum(self.denom[...], jnp.asarray(1.0, jnp.float32)) + + return jax.tree_util.tree_map( + lambda v: type(v)(v[...] * scale.astype(v[...].dtype)), + self.grads, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + def reset(self): + def _zero_in_place(v): + v[...] = jnp.zeros_like(v[...]) + + jax.tree_util.tree_map( + _zero_in_place, + self.grads, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + self.denom[...] = jnp.zeros_like(self.denom[...]) + + +def _promote_opt_state_floats_to_float32(optimizer: nnx.Optimizer) -> None: + """Cast the optimizer state's floating-point leaves to float32 in-place. + + Args: + optimizer: The nnx.Optimizer instance whose state will be modified. 
+ """ + + def _cast(v): + if isinstance(v, nnx.Variable): + val = v.value + if ( + hasattr(val, "dtype") + and jnp.issubdtype(val.dtype, jnp.floating) + and val.dtype != jnp.float32 + ): + v.value = val.astype(jnp.float32) + + opt_state = nnx.state(optimizer, nnx.optimizer.OptState) + jax.tree_util.tree_map( + _cast, opt_state, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + + class PeftTrainer: """PEFT trainer for LoRA. Only LoRA parameters are updated. @@ -174,6 +266,8 @@ class PeftTrainer: use `optax.schedules.inject_hyperparams` to inject learning rate as a hyperparameter. For example: ``optimizer = optax.schedules.inject_hyperparams(optax.sgd)(learning_rate=learning_rate_schedule)`` + grad_accumulator: The gradient accumulator to use for accumulating gradients + over multiple micro-steps. loss_fn: The loss function to use. eval_loss_fn: The loss function to use for evaluation. gen_model_input_fn: The function to generate model input from training @@ -186,7 +280,7 @@ class PeftTrainer: data_hooks: The data hooks to use. """ - supports_sequence_packing = False + supports_sequence_packing = True def __init__( self, @@ -209,14 +303,13 @@ def __init__( self.model = model self.config = training_config self._lora_enabled = utils.is_lora_enabled(self.model) - if training_config.gradient_accumulation_steps is not None: - optimizer = optax.MultiSteps( - optimizer, training_config.gradient_accumulation_steps - ) - if self._lora_enabled: - self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.LoRAParam) - else: - self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=nnx.Param) + wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param + self.optimizer = nnx.Optimizer(self.model, optimizer, wrt=wrt_target) + # Promote floating-point leaves to float32 in-place to match the dtype of + # the optimizer update function branch (which is float32 due to + # `optax.inject_hyperparams`). + _promote_opt_state_floats_to_float32(self.optimizer) + self.grad_accumulator = GradientAccumulator(self.model, wrt_target) self.loss_fn = _default_loss_fn self.eval_loss_fn = _default_loss_fn @@ -329,14 +422,21 @@ def with_gen_model_input_fn( return self def _train_step( - self, model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any + self, + model: nnx.Module, + optimizer: nnx.Optimizer, + grad_accumulator: GradientAccumulator, + inputs: Any, + is_update_step: jax.Array, ) -> Tuple[ArrayLike, Any | None, ArrayLike]: """Main body for one train step. Args: model: The model to train. optimizer: The optimizer to use. + grad_accumulator: The gradient accumulator to use. inputs: The training input. + is_update_step: Whether to update the model. Returns: A tuple containing the loss, auxiliary data (or None if has_aux is False), @@ -350,8 +450,44 @@ def _train_step( has_aux=self._has_aux, ) out, grads = grad_fn(model, **inputs) - grad_norm = optax.global_norm(grads) - optimizer.update(model, grads) + + # TODO(b/491970038): update denom for sequence packing. + grad_accumulator.add(grads, denom=jnp.asarray(1.0, dtype=jnp.float32)) + + def apply_updates(model, optimizer, grad_accumulator): + acc_grads = grad_accumulator.get() + # Compute the norm in float32 to 1) match `skip_updates()` return type and + # meet the requirement of `nnx.cond` that both branches return the same + # dtype, 2) for production-size models the sum-of-squares over bf16 grads + # quickly exhausts bf16 and float32 is needed for numerical stability. 
+ norm = optax.global_norm( + jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), acc_grads) + ) + optimizer.update(model, acc_grads) + grad_accumulator.reset() + return norm + + def skip_updates(model, optimizer, grad_accumulator): + return jnp.array(0.0, dtype=jnp.float32) + + # If the mesh is not empty, then we need to replicate the is_update_step + # across all devices to avoid deadlock so that all devices see the same + # update step. + mesh = pxla.thread_resources.env.physical_mesh + if not mesh.empty: + is_update_step = jax.lax.with_sharding_constraint( + is_update_step, jax.sharding.PartitionSpec() + ) + + grad_norm = nnx.cond( + is_update_step, + apply_updates, + skip_updates, + model, + optimizer, + grad_accumulator, + ) + if self._has_aux: loss, aux = out return loss, aux, grad_norm @@ -397,6 +533,21 @@ def _shard_optimizer(self, mesh: shd.Mesh) -> None: ) nnx.update(self.optimizer, optimizer_sharded_state) + wrt_target = nnx.LoRAParam if self._lora_enabled else nnx.Param + model_state = nnx.state(self.model, wrt_target) + model_pspecs = nnx.get_partition_spec(model_state) + + # Partition Gradients similar to the model + grads_sharded = jax.lax.with_sharding_constraint( + self.grad_accumulator.grads, model_pspecs + ) + self.grad_accumulator.grads = grads_sharded + + # Denominator is a scalar — replicate across all devices + self.grad_accumulator.denom[...] = jax.lax.with_sharding_constraint( + self.grad_accumulator.denom[...], jax.sharding.PartitionSpec() + ) + def jit_train_and_eval_step( self, skip_jit: bool = False, cache_nnx_graph: bool = False ): @@ -419,7 +570,7 @@ def jit_train_and_eval_step( if self._jitted_train_step_fn is None: self._shard_optimizer(pxla.thread_resources.env.physical_mesh) self._jitted_train_step_fn = nnx.jit( - train_step, donate_argnames=("optimizer",) + train_step, donate_argnames=("optimizer", "grad_accumulator") ) self._jitted_eval_step_fn = nnx.jit(eval_step) @@ -431,7 +582,10 @@ def maybe_cache_and_partial(f, *args): return functools.partial(f, *args) self._jitted_train_step_fn = maybe_cache_and_partial( - self._jitted_train_step_fn, self.model, self.optimizer + self._jitted_train_step_fn, + self.model, + self.optimizer, + self.grad_accumulator, ) self._jitted_eval_step_fn = maybe_cache_and_partial( self._jitted_eval_step_fn, self.model @@ -696,6 +850,28 @@ def train( perf_constants.MINI_BATCH: mini_batch, } + self._iter_steps += 1 + + is_update_step_val = None + if ( + isinstance(train_example, dict) + and "is_update_step" in train_example + ): + val = train_example["is_update_step"] + if val is not None: + is_update_step_val = bool(np.asarray(val).item()) + elif hasattr(train_example, "is_update_step"): + val = train_example.is_update_step + if val is not None: + is_update_step_val = bool(np.asarray(val).item()) + + if is_update_step_val is None: + is_update_step_val = ( + self._iter_steps + % self.config.get_with_default("gradient_accumulation_steps", 1) + == 0 + ) + with self._perf_tracer.span( "peft_train_step", pxla.thread_resources.env.physical_mesh.devices, @@ -704,7 +880,10 @@ def train( pxla.thread_resources.env.physical_mesh.devices, tags=tags, ) as span_v2: - train_loss, aux, grad_norm = train_step(train_example) + train_loss, aux, grad_norm = train_step( + train_example, + is_update_step=jnp.array(is_update_step_val, dtype=jnp.bool_), + ) span.device_end([train_loss]) span_v2.async_end([train_loss]) @@ -717,13 +896,8 @@ def train( ) # NB: put this after self._buffer_metrics is important self._post_process_train_step(aux) - 
self._iter_steps += 1 - if ( - self._iter_steps - % self.config.get_with_default("gradient_accumulation_steps", 1) - == 0 - ): + if is_update_step_val: self._train_steps += 1 self._write_train_metrics()
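
For reference, a minimal sketch (not part of the diff above) of the accumulation contract that the `GradientAccumulator` docstring and the packed-sequence tests describe: summing unreduced-loss gradients and dividing by the summed denominators reproduces the gradient of a single mean-loss step on the concatenated batch, even when the micro-batch sizes differ. The model, data, and helper names below are hypothetical, not taken from the tunix code.

# Illustrative sketch only; assumes a plain linear model, not the tunix classes above.
import jax
import jax.numpy as jnp

w = jnp.ones((4, 2))
x = jax.random.normal(jax.random.PRNGKey(0), (12, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (12, 2))

def loss_sum(w, xb, yb):   # unreduced (sum) loss over one micro-batch
  return jnp.sum((xb @ w - yb) ** 2)

def loss_mean(w, xb, yb):  # mean loss over the full batch
  return jnp.mean(jnp.sum((xb @ w - yb) ** 2, axis=-1))

# Reference: one step on the concatenated batch.
expected = jax.grad(loss_mean)(w, x, y)

# Denom-aware accumulation over unequal micro-batches: sum(g_i) / sum(d_i).
acc, denom, start = jnp.zeros_like(w), 0.0, 0
for size in (3, 5, 4):
  g = jax.grad(loss_sum)(w, x[start:start + size], y[start:start + size])
  acc, denom, start = acc + g, denom + size, start + size

assert jnp.allclose(acc / denom, expected, atol=1e-4)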