diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_grpo_learner_test.py
index 6dcced541..8beeef183 100644
--- a/tests/rl/agentic/agentic_grpo_learner_test.py
+++ b/tests/rl/agentic/agentic_grpo_learner_test.py
@@ -450,13 +450,15 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(
         algo_config.policy_loss_fn
     )
-    loss, aux = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=MockModel(rngs=nnx.Rngs(0)),
         train_example=train_example,
         algo_config=algo_config,
         pad_id=0,
         eos_id=2,
     )
+    loss = loss_output.primary_loss.compute()
+    aux = loss_output.aux_metrics
     chex.assert_shape(loss, ())
     self.assertIn("kl", aux)

@@ -535,7 +537,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(config.policy_loss_fn)
     model = MockModel(rngs=nnx.Rngs(0))

-    loss, _ = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=model,
         train_example=train_example,
         algo_config=config,
@@ -567,6 +569,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     else:
       expected_loss = float(jnp.mean(per_sequence_loss))

+    loss = loss_output.primary_loss.compute()
     np.testing.assert_allclose(loss, expected_loss, rtol=1e-6, atol=1e-6)

   def test_process_results_extracts_assistant_text(self):
diff --git a/tests/rl/common_test.py b/tests/rl/common_test.py
index 448405c7d..d019345c2 100644
--- a/tests/rl/common_test.py
+++ b/tests/rl/common_test.py
@@ -24,6 +24,11 @@
 jax.config.update("jax_threefry_partitionable", False)


+def _compute_loss(*args, **kwargs):
+  out = common.aggregate_loss(*args, **kwargs)
+  return out.compute()
+
+
 class CommonTest(parameterized.TestCase):

   @parameterized.named_parameters(
@@ -446,7 +451,9 @@ def test_pad_to_length(self):
           expected_loss=(0.1 + 0.2) / 4.0 / 1.0,
       ),
       dict(
-          testcase_name="sequence_mean_token_sum_norm_partial_zero_mask_default",
+          testcase_name=(
+              "sequence_mean_token_sum_norm_partial_zero_mask_default"
+          ),
           loss_agg_mode="sequence-mean-token-sum-norm",
           per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
           completion_mask_list=[[1, 1], [0, 0]],
@@ -496,7 +503,7 @@ def test_aggregate_loss_values(
   ):
     per_token_loss = jnp.array(per_token_loss_list)
     completion_mask = jnp.array(completion_mask_list)
-    actual_loss = common.aggregate_loss(
+    actual_loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode, **kwargs
     )
     np.testing.assert_allclose(actual_loss, expected_loss, rtol=1e-6, atol=1e-6)

@@ -505,7 +512,7 @@ def test_invalid_mode(self):
     with self.assertRaisesRegex(
         ValueError, "Unsupported loss aggregation mode"
     ):
-      common.aggregate_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")
+      _compute_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")

   @parameterized.named_parameters(
       dict(
@@ -541,7 +548,7 @@ def test_invalid_mode(self):
   )
   def test_invalid_norm(self, norm_val, loss_agg_mode):
     with self.assertRaisesRegex(ValueError, "Invalid 'norm' value"):
-      common.aggregate_loss(
+      _compute_loss(
          jnp.ones((2, 2)),
          jnp.ones((2, 2)),
          loss_agg_mode,
@@ -567,7 +574,7 @@ def test_aggregate_loss_bf16(self):
     per_token_loss = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
     completion_mask = jnp.array([1, 1, 0], dtype=jnp.int32)

-    loss = common.aggregate_loss(
+    loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode="token-mean"
     )
     self.assertEqual(loss.dtype, jnp.float32)
diff --git a/tests/rl/grpo/dapo_learner_test.py b/tests/rl/grpo/dapo_learner_test.py
index 71667f690..a20b1ff44 100644
--- a/tests/rl/grpo/dapo_learner_test.py
+++ b/tests/rl/grpo/dapo_learner_test.py
@@ -90,19 +90,23 @@ def test_diff_loss(self):
     )

     # Call DAPO loss function
-    dapo_loss, dapo_aux = dapo_loss_fn_impl(
+    dapo_loss_output = dapo_loss_fn_impl(
         model, train_example, dapo_config, pad_id, eos_id
     )
+    dapo_loss = dapo_loss_output.primary_loss.compute()
+    dapo_aux = dapo_loss_output.aux_metrics

     # Call GRPO loss function
-    grpo_loss, grpo_aux = grpo_loss_fn_impl(
+    grpo_loss_output = grpo_loss_fn_impl(
         model, train_example, grpo_config, pad_id, eos_id
     )
+    grpo_loss = grpo_loss_output.primary_loss.compute()
+    grpo_aux = grpo_loss_output.aux_metrics

     # Assert that the loss values are different
     self.assertNotEqual(
-        dapo_loss.item(),
-        grpo_loss.item(),
+        dapo_loss,
+        grpo_loss,
         msg=(
             "DAPO and GRPO loss values should be different for the same input"
             " due to different loss aggregation logics."
@@ -111,7 +115,9 @@ def test_diff_loss(self):
     self.assertIn("kl", dapo_aux)
     self.assertIn("kl", grpo_aux)
-    self.assertEqual(dapo_aux["kl"], 0.0)  # DAPO does not have KL term.
+    self.assertEqual(
+        dapo_aux["kl"].compute(), 0.0
+    )  # DAPO does not have a KL term.


 class TestDAPOConfigPostInit(parameterized.TestCase):
diff --git a/tests/rl/grpo/drgrpo_learner_test.py b/tests/rl/grpo/drgrpo_learner_test.py
index 8146f216a..d092b38d1 100644
--- a/tests/rl/grpo/drgrpo_learner_test.py
+++ b/tests/rl/grpo/drgrpo_learner_test.py
@@ -124,9 +124,11 @@ def test_drgrpo_loss_fn(self):
     )

     # Call DrGRPO loss function
-    drgrpo_loss, drgrpo_aux = drgrpo_loss_fn_impl(
+    drgrpo_loss_output = drgrpo_loss_fn_impl(
         model, train_example, drgrpo_config, pad_id, eos_id
     )
+    drgrpo_loss = drgrpo_loss_output.primary_loss.compute()
+    drgrpo_aux = drgrpo_loss_output.aux_metrics

     self.assertIn("kl", drgrpo_aux)
     self.assertTrue(jnp.isfinite(drgrpo_loss).all())
diff --git a/tests/sft/dpo/dpo_trainer_test.py b/tests/sft/dpo/dpo_trainer_test.py
index f926df854..6947e5a1f 100644
--- a/tests/sft/dpo/dpo_trainer_test.py
+++ b/tests/sft/dpo/dpo_trainer_test.py
@@ -270,14 +270,16 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
           model, train_example, beta=0.1, label_smoothing=0
       )
+      loss = loss_output.primary_loss.compute()
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)

-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
           model, train_example, beta=0.1, label_smoothing=0.3
       )
+      loss = loss_output.primary_loss.compute()
       np.testing.assert_allclose(loss, 0.925447, atol=1e-5)

   def test_dpo_prepare_inputs_for_strings(self):
diff --git a/tests/sft/dpo/orpo_trainer_test.py b/tests/sft/dpo/orpo_trainer_test.py
index 0eaa59509..c2df64446 100644
--- a/tests/sft/dpo/orpo_trainer_test.py
+++ b/tests/sft/dpo/orpo_trainer_test.py
@@ -253,13 +253,16 @@ def test_orpo_loss_fn(self):
         "compute_logps",
         return_value=(jnp.array(chosen_logps), jnp.array(rejected_logps)),
     ):
-      loss, aux = orpo_lib.dpo_loss_fn(
+      loss_output = orpo_lib.dpo_loss_fn(
          model,
          train_example,
          algorithm="orpo",
          lambda_orpo=0.1,
          label_smoothing=0,
      )
+      loss = loss_output.primary_loss.compute()
+      aux = loss_output.aux_metrics
+
       # Loss should be a scalar and finite
       self.assertEqual(loss.shape, ())
       self.assertTrue(jnp.isfinite(loss))
diff --git a/tests/sft/peft_trainer_test.py b/tests/sft/peft_trainer_test.py
index 831cd101b..3b4d2ebfb 100644
--- a/tests/sft/peft_trainer_test.py
+++ b/tests/sft/peft_trainer_test.py
@@ -34,6 +34,7 @@
 from tunix.sft import hooks
 from tunix.sft import peft_trainer
 from tunix.sft import profiler
+from tunix.sft import utils
 from tunix.tests import test_common as tc
 from tunix.utils import compat

@@ -634,7 +635,67 @@ def _post_process_eval_step(self, aux):
     self.assertEqual(train_invoke, {'foo': 2, 'bar': 4})
     self.assertEqual(eval_invoke, {'foo': 1, 'bar': 16})

+  def test_loss_output_format(self):
+
+    def custom_loss_fn(
+        model: nnx.Module,
+        input_tokens: jax.Array,
+        input_mask: jax.Array,
+        positions: jax.Array,
+        attention_mask: jax.Array,
+        images: jax.Array | None = None,
+    ) -> utils.LossOutput:
+      del model, input_tokens, input_mask, positions, attention_mask, images
+      return utils.LossOutput(
+          primary_loss=utils.WeightedMetric(
+              jnp.array(2.0, dtype=jnp.float32),
+              jnp.array(2.0, dtype=jnp.float32),
+          ),
+          aux_metrics={
+              'foo': utils.WeightedMetric(
+                  jnp.array(10.0, dtype=jnp.float32),
+                  jnp.array(5.0, dtype=jnp.float32),
+              ),
+              'bar': utils.WeightedMetric(
+                  jnp.array(6.0, dtype=jnp.float32),
+                  jnp.array(2.0, dtype=jnp.float32),
+              ),
+          },
+      )
+
+    train_invoke = {'foo': 0.0, 'bar': 0.0}
+    eval_invoke = {'foo': 0.0, 'bar': 0.0}
+
+    class CustomTrainer(peft_trainer.PeftTrainer):
+
+      def _post_process_train_step(self, aux):
+        train_invoke['foo'] += aux['foo']
+        train_invoke['bar'] += aux['bar']
+
+      def _post_process_eval_step(self, aux):
+        eval_invoke['foo'] += aux['foo']
+        eval_invoke['bar'] += aux['bar']
+
+    config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
+
+    trainer = CustomTrainer(model, optax.sgd(1e-3), config)
+    trainer = trainer.with_gen_model_input_fn(
+        dummy_gen_model_input_fn
+    ).with_loss_fn(
+        custom_loss_fn
+    )  # Note: has_aux defaults to False, but a LossOutput return carries aux
+    # metrics natively.
+
+    trainer.train(self.train_ds, self.eval_ds)
+    # The dataset provides 2 training steps.
+    # foo = 10.0 / 5.0 = 2.0 per step.
+    # bar = 6.0 / 2.0 = 3.0 per step.
+    self.assertEqual(train_invoke, {'foo': 4.0, 'bar': 6.0})
+
+    # Since eval_ds is length 2, it evaluates at step 2.
+    self.assertEqual(eval_invoke, {'foo': 8.0, 'bar': 12.0})
+
   def test_injected_params(self):
+    config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
     model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
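For orientation, here is the new loss-function contract the test above exercises, written out as a standalone sketch. This is illustrative only and not part of the patch; it assumes nothing beyond the `LossOutput` / `WeightedMetric` API added in `tunix/sft/utils.py` below. The model and data are stand-ins.

import jax.numpy as jnp
from flax import nnx
from tunix.sft import utils


def masked_mse_loss_fn(
    model: nnx.Module, x: jnp.ndarray, y: jnp.ndarray, mask: jnp.ndarray
) -> utils.LossOutput:
  per_token = jnp.square(model(x) - y) * mask
  return utils.LossOutput(
      # No division here: the trainer computes the scalar loss and rescales
      # gradients itself from (unreduced_sum, denominator).
      primary_loss=utils.WeightedMetric(per_token.sum(), mask.sum(), eps=1e-8),
      aux_metrics={},
  )

A loss function that returns a `LossOutput` needs no `has_aux=True` flag; the trainer detects the type and extracts aux metrics from it directly.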
loss_algo: "grpo" or "gspo-token". @@ -251,8 +251,7 @@ def __init__( }) self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ lambda: "kl" - if self.algo_config.force_compute_kl - or self.algo_config.beta != 0.0 + if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0 else None, ]) @@ -594,9 +593,7 @@ def grpo_loss_fn( else epsilon ) epsilon_c = ( - algo_config.epsilon_c - if hasattr(algo_config, "epsilon_c") - else 3.0 + algo_config.epsilon_c if hasattr(algo_config, "epsilon_c") else 3.0 ) loss_aggregation_mode = algo_config.loss_agg_mode @@ -633,7 +630,8 @@ def grpo_loss_fn( seq_importance_ratio = per_token_logps - old_per_token_logps # Record KL divergence before clipping. - ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask) + unreduced_ppo_kl = jnp.sum(-seq_importance_ratio * completion_mask) + token_denom = completion_mask.sum() seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0) @@ -661,9 +659,7 @@ def grpo_loss_fn( per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32) - clipped_fraction = ppo_helpers.masked_mean( - jnp.greater(pg_loss_2, pg_loss_1), completion_mask - ) + unreduced_clip_frac = jnp.sum(jnp.greater(pg_loss_2, pg_loss_1) * completion_mask) # dual-clip ppo loss pg_loss_3 = -epsilon_c * adv @@ -671,25 +667,30 @@ def grpo_loss_fn( # pg_clipfrac_lower measures how often dual-clip ppo kicks in. # It kicks in when the standard clipped loss is larger than pg_loss_3 # for instances with negative advantages. - unreduced_pg_clipfrac_lower = ( + per_token_pg_clipfrac_lower = ( (per_token_loss > pg_loss_3) & (adv < 0.0) ).astype(jnp.float32) - pg_clipfrac_lower = common.aggregate_loss( - unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode + unreduced_pg_clipfrac_lower = common.aggregate_loss( + per_token_pg_clipfrac_lower, completion_mask, loss_aggregation_mode ) pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss) per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss) - loss = common.aggregate_loss( + weighted_loss = common.aggregate_loss( per_token_loss, completion_mask, loss_aggregation_mode ) + aux = { - "kl": 0.0, - "kl_loss": 0.0, - "pg_loss": loss, - "pg_clipfrac": clipped_fraction, - "ppo_kl": ppo_kl, - "pg_clipfrac_lower": pg_clipfrac_lower, + "kl": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)), + "kl_loss": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)), + "pg_loss": weighted_loss, + "pg_clipfrac": sft_utils.WeightedMetric( + unreduced_clip_frac, token_denom, min_denom=1.0 + ), + "ppo_kl": sft_utils.WeightedMetric( + unreduced_ppo_kl, token_denom, min_denom=1.0 + ), + "pg_clipfrac_lower": unreduced_pg_clipfrac_lower, } # We do not alwayscompute KL divergence (e.g. when beta is 0.0 unless # force_compute_kl is True). @@ -699,17 +700,19 @@ def grpo_loss_fn( train_example.ref_per_token_logps, algo_config.kl_loss_mode, ) - # Log mean KL. 
- aux["kl"] = jnp.astype( - (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1), - jnp.float32, - ) - kl_loss = common.aggregate_loss( - kl, completion_mask, loss_aggregation_mode + unreduced_kl = jnp.astype(jnp.sum(kl * completion_mask), jnp.float32) + aux["kl"] = sft_utils.WeightedMetric( + unreduced_kl, token_denom, min_denom=1.0 ) + kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode) aux["kl_loss"] = kl_loss if beta is not None and beta != 0.0: - loss = loss + beta * kl_loss + weighted_loss = sft_utils.WeightedMetric( + weighted_loss.unreduced_sum + beta * kl_loss.unreduced_sum, + weighted_loss.denominator, + eps=weighted_loss.eps, + min_denom=weighted_loss.min_denom, + ) token_entropy = ppo_helpers.compute_entropy_from_logits(logits) entropy_loss = common.aggregate_loss( @@ -717,7 +720,7 @@ def grpo_loss_fn( ) aux["entropy"] = entropy_loss - return loss, aux + return sft_utils.LossOutput(primary_loss=weighted_loss, aux_metrics=aux) @function_registry.register_advantage_estimator("agentic_grpo") diff --git a/tunix/rl/common.py b/tunix/rl/common.py index a8d507848..755a70410 100644 --- a/tunix/rl/common.py +++ b/tunix/rl/common.py @@ -489,7 +489,7 @@ def aggregate_loss( completion_mask: jax.Array, loss_agg_mode: str, **kwargs: Any, -) -> jax.Array: +) -> utils.WeightedMetric: """Aggregate loss based on the loss aggregation mode. Args: @@ -508,15 +508,17 @@ def aggregate_loss( if loss_agg_mode == "token-mean": # sum all the token loss, and average by total number of completion tokens # in the batch - loss = (per_token_loss * completion_mask).sum() / ( - jnp.clip(completion_mask.sum(), min=1) - ) + unreduced_sum = (per_token_loss * completion_mask).sum() + denominator = completion_mask.sum() + min_denom = 1.0 elif loss_agg_mode == "sequence-mean-token-mean": seq_mask = completion_mask.sum(axis=-1) # per-sequence token count seq_loss = ((per_token_loss * completion_mask).sum(axis=-1)) / jnp.clip( - seq_mask, min=1 + seq_mask, min=1.0 ) - loss = seq_loss.sum() / non_zero_rows + unreduced_sum = seq_loss.sum() + denominator = non_zero_rows + min_denom = 1.0 elif loss_agg_mode == "sequence-mean-token-scale": # Look up custom normalization factor, default to max response length. norm = _check_get_norm(kwargs, per_token_loss.shape[-1]) @@ -525,21 +527,24 @@ def aggregate_loss( seq_loss = (per_token_loss * completion_mask).sum(axis=-1) / jnp.clip( norm, min=1e-6 ) - loss = seq_loss.sum() / non_zero_rows + unreduced_sum = seq_loss.sum() + denominator = non_zero_rows + min_denom = 1.0 elif loss_agg_mode == "seq-mean-token-sum": # 1) sum token losses within each sequence # 2) average only across sequences that have at least one valid token seq_loss = (per_token_loss * completion_mask).sum(axis=-1) seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32) - loss = (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6) + unreduced_sum = (seq_loss * seq_mask).sum() + denominator = seq_mask.sum() + min_denom = 1e-6 elif loss_agg_mode == "sequence-mean-token-sum-norm": # Get custom normalization factor from kwargs, default to number of # non-empty rows. norm = _check_get_norm(kwargs, non_zero_rows) - - # Sum the per-sequence sums and normalize - # TODO(sizhi): Experiment with loss in precision if loss is fp16. 
diff --git a/tunix/rl/common.py b/tunix/rl/common.py
index a8d507848..755a70410 100644
--- a/tunix/rl/common.py
+++ b/tunix/rl/common.py
@@ -489,7 +489,7 @@ def aggregate_loss(
     completion_mask: jax.Array,
     loss_agg_mode: str,
     **kwargs: Any,
-) -> jax.Array:
+) -> utils.WeightedMetric:
   """Aggregate loss based on the loss aggregation mode.

   Args:
@@ -508,15 +508,17 @@ def aggregate_loss(
   if loss_agg_mode == "token-mean":
     # sum all the token loss, and average by total number of completion tokens
     # in the batch
-    loss = (per_token_loss * completion_mask).sum() / (
-        jnp.clip(completion_mask.sum(), min=1)
-    )
+    unreduced_sum = (per_token_loss * completion_mask).sum()
+    denominator = completion_mask.sum()
+    min_denom = 1.0
   elif loss_agg_mode == "sequence-mean-token-mean":
     seq_mask = completion_mask.sum(axis=-1)  # per-sequence token count
     seq_loss = ((per_token_loss * completion_mask).sum(axis=-1)) / jnp.clip(
-        seq_mask, min=1
+        seq_mask, min=1.0
     )
-    loss = seq_loss.sum() / non_zero_rows
+    unreduced_sum = seq_loss.sum()
+    denominator = non_zero_rows
+    min_denom = 1.0
   elif loss_agg_mode == "sequence-mean-token-scale":
     # Look up custom normalization factor, default to max response length.
     norm = _check_get_norm(kwargs, per_token_loss.shape[-1])
@@ -525,21 +527,24 @@ def aggregate_loss(
     seq_loss = (per_token_loss * completion_mask).sum(axis=-1) / jnp.clip(
         norm, min=1e-6
     )
-    loss = seq_loss.sum() / non_zero_rows
+    unreduced_sum = seq_loss.sum()
+    denominator = non_zero_rows
+    min_denom = 1.0
   elif loss_agg_mode == "seq-mean-token-sum":
     # 1) sum token losses within each sequence
     # 2) average only across sequences that have at least one valid token
     seq_loss = (per_token_loss * completion_mask).sum(axis=-1)
     seq_mask = (completion_mask.sum(axis=-1) > 0).astype(jnp.float32)
-    loss = (seq_loss * seq_mask).sum() / jnp.clip(seq_mask.sum(), min=1e-6)
+    unreduced_sum = (seq_loss * seq_mask).sum()
+    denominator = seq_mask.sum()
+    min_denom = 1e-6
   elif loss_agg_mode == "sequence-mean-token-sum-norm":
     # Get custom normalization factor from kwargs, default to number of
     # non-empty rows.
     norm = _check_get_norm(kwargs, non_zero_rows)
-
-    # Sum the per-sequence sums and normalize
-    # TODO(sizhi): Experiment with loss in precision if loss is fp16.
-    loss = (per_token_loss * completion_mask).sum() / jnp.clip(norm, min=1e-6)
+    unreduced_sum = (per_token_loss * completion_mask).sum()
+    denominator = norm
+    min_denom = 1e-6
   else:
     raise ValueError(
         f"Unsupported loss aggregation mode: {loss_agg_mode}. Supported modes:"
@@ -547,7 +552,11 @@ def aggregate_loss(
         " 'sequence-mean-token-scale', 'seq-mean-token-sum',"
         " 'sequence-mean-token-sum-norm'."
     )
-  return loss
+  return utils.WeightedMetric(
+      jnp.asarray(unreduced_sum, dtype=jnp.float32),
+      jnp.asarray(denominator, dtype=jnp.float32),
+      min_denom=min_denom,
+  )


 def _check_get_norm(
diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py
index fd30a5660..4cc3b4ca4 100644
--- a/tunix/rl/grpo/grpo_learner.py
+++ b/tunix/rl/grpo/grpo_learner.py
@@ -31,6 +31,7 @@
 from tunix.rl import function_registry
 from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl import rl_learner
+from tunix.sft import utils as sft_utils

 TrainingInputT = rl_learner.TrainingInputT
 RewardFn = rl_learner.RewardFn
@@ -539,12 +540,12 @@ def grpo_loss_fn(
   # When unpacked, they are 1D [B].
   adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1)

+  denominator = jnp.sum(completion_mask)
+
   # Compute pg_clipfrac
   pg_losses_1 = -coef_1 * adv
   pg_losses_2 = -coef_2 * adv
-  pg_clipfrac = jnp.sum(
-      (pg_losses_2 > pg_losses_1) * completion_mask
-  ) / jnp.clip(jnp.sum(completion_mask), min=1)
+  unreduced_pg_clipfrac = jnp.sum((pg_losses_2 > pg_losses_1) * completion_mask)

   # TODO(tsbao): We should handle token level advantages.
   per_token_loss = -jnp.minimum(
@@ -553,26 +554,24 @@ def grpo_loss_fn(
   )

   # add KL penalty
-  mean_kl = 0.0
+  unreduced_kl = 0.0
   if beta is not None and beta != 0.0:
     kl = common.compute_kl_divergence(
         per_token_logps, train_example.ref_per_token_logps
     )
     per_token_loss = per_token_loss + beta * kl
-    mean_kl = (kl * completion_mask).sum() / jnp.clip(
-        completion_mask.sum(), min=1
-    )
+    unreduced_kl = (kl * completion_mask).sum()

   aux = {
-      "kl": mean_kl,
-      "pg_clipfrac": pg_clipfrac,
+      "kl": sft_utils.WeightedMetric(unreduced_kl, denominator, min_denom=1.0),
+      "pg_clipfrac": sft_utils.WeightedMetric(unreduced_pg_clipfrac, denominator, min_denom=1.0),
   }

-  loss = common.aggregate_loss(
+  weighted_loss = common.aggregate_loss(
       per_token_loss, completion_mask, loss_aggregation_mode
   )

-  return loss, aux
+  return sft_utils.LossOutput(primary_loss=weighted_loss, aux_metrics=aux)


 @function_registry.register_advantage_estimator("grpo")
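A quick equivalence check on the `aggregate_loss` change (illustrative only, not part of the patch): for "token-mean", `WeightedMetric.compute()` reproduces the scalar the old code returned, `sum(loss * mask) / clip(sum(mask), min=1)`.

import jax.numpy as jnp
from tunix.rl import common

per_token_loss = jnp.array([[0.1, 0.2], [0.3, 0.4]])
mask = jnp.array([[1, 1], [1, 0]], dtype=jnp.float32)

# Returns a WeightedMetric carrying (unreduced_sum=0.6, denominator=3.0).
metric = common.aggregate_loss(per_token_loss, mask, "token-mean")
assert jnp.allclose(metric.compute(), (0.1 + 0.2 + 0.3) / 3.0)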
diff --git a/tunix/rl/ppo/ppo_learner.py b/tunix/rl/ppo/ppo_learner.py
index 512d40675..b2c2a1928 100644
--- a/tunix/rl/ppo/ppo_learner.py
+++ b/tunix/rl/ppo/ppo_learner.py
@@ -31,6 +31,7 @@
 from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl import rl_learner
 from tunix.rl.ppo import ppo_helpers
+from tunix.sft import utils as sft_utils

 TrainingInputT = rl_learner.TrainingInputT
 RewardFn = rl_learner.RewardFn
@@ -567,7 +568,8 @@ def ppo_value_loss_fn(
   vpreds = vpreds[:, -logits_to_keep - 1 : -1]
   if segment_ids is not None:
-    # Pad the first token's value with 0.0, since it has no preceding token to predict it.
+    # Pad the first token's value with 0.0, since it has no preceding token to
+    # predict it.
     vpreds = jnp.pad(vpreds, ((0, 0), (1, 0)), constant_values=0.0)

   vpred_clipped = jnp.clip(
       vpreds, values - clip_range_value, values + clip_range_value
@@ -576,17 +578,30 @@ def ppo_value_loss_fn(
   vf_losses2 = jnp.square(vpred_clipped - returns)
   clipped_vf_losses = jnp.maximum(vf_losses1, vf_losses2)

-  # "token mean" style of normalisation.
-  vf_loss = ppo_helpers.masked_mean(clipped_vf_losses, completion_mask)
-  vf_loss = 0.5 * vf_loss
+
+  denominator = jnp.sum(completion_mask)
+
+  unreduced_vf_loss = 0.5 * jnp.sum(clipped_vf_losses * completion_mask)
+
+  unreduced_vpred_mean = jnp.sum(vpreds * completion_mask)
+  unreduced_vf_clipfrac = jnp.sum(
+      (vf_losses2 > vf_losses1).astype(jnp.float32) * completion_mask
+  )

   aux = {
-      "vpred_mean": ppo_helpers.masked_mean(vpreds, completion_mask),
-      "vf_clipfrac": ppo_helpers.masked_mean(
-          (vf_losses2 > vf_losses1).astype(jnp.float32), completion_mask
+      "vpred_mean": sft_utils.WeightedMetric(
+          unreduced_vpred_mean, denominator, min_denom=1.0
+      ),
+      "vf_clipfrac": sft_utils.WeightedMetric(
+          unreduced_vf_clipfrac, denominator, min_denom=1.0
       ),
   }

-  return vf_loss, aux
+  return sft_utils.LossOutput(
+      primary_loss=sft_utils.WeightedMetric(
+          unreduced_vf_loss, denominator, min_denom=1.0
+      ),
+      aux_metrics=aux,
+  )


 @registry.register("policy_loss_fn", "ppo")
@@ -636,6 +651,11 @@ def ppo_policy_loss_fn(
   pg_losses_2 = -ratio_clipped * advantages
   clip_pg_losses_1 = jnp.maximum(pg_losses_1, pg_losses_2)

+  # For logging.
+  unreduced_pg_clipfrac = jnp.sum(
+      (pg_losses_2 > pg_losses_1).astype(jnp.float32) * completion_mask
+  )
+
   # Dual-clip PPO to avoid negative-advantage policy updates
   pg_losses = clip_pg_losses_1
   if use_dual_clip_ppo:
@@ -645,36 +665,47 @@ def ppo_policy_loss_fn(
     pg_losses = jnp.where(advantages < 0.0, clip_pg_losses_2, clip_pg_losses_1)

     # For logging.
-    unreduced_pg_clipfrac_lower = (
-        (clip_pg_losses_1 > pg_losses_3) & (advantages < 0.0)
-    ).astype(jnp.float32)
-    pg_clipfrac_lower = ppo_helpers.masked_mean(
-        unreduced_pg_clipfrac_lower, completion_mask
+    unreduced_pg_clipfrac_lower = jnp.sum(
+        ((clip_pg_losses_1 > pg_losses_3) & (advantages < 0.0)).astype(
+            jnp.float32
+        ) * completion_mask
     )

+  denominator = jnp.sum(completion_mask)
+  unreduced_policy_loss = jnp.sum(pg_losses * completion_mask)
+
   # Logging
   aux = {
-      "pg_clipfrac": ppo_helpers.masked_mean(
-          (pg_losses_2 > pg_losses_1).astype(jnp.float32), completion_mask
+      "pg_clipfrac": sft_utils.WeightedMetric(
+          unreduced_pg_clipfrac,
+          denominator,
+          min_denom=1.0,
       ),
   }
   if use_dual_clip_ppo:
-    aux["pg_clipfrac_lower"] = pg_clipfrac_lower  # pylint: disable=undefined-variable
-
-  # "token mean" style of normalisation
-  policy_loss = ppo_helpers.masked_mean(pg_losses, completion_mask)
+    aux["pg_clipfrac_lower"] = sft_utils.WeightedMetric(
+        unreduced_pg_clipfrac_lower,
+        denominator,
+        min_denom=1.0,
+    )

   # Compute entropy loss.
   if entropy_coef is not None and entropy_coef > 0.0:
     token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
-    # "token mean" style of normalisation.
-    entropy_loss = ppo_helpers.masked_mean(token_entropy, completion_mask)
-    policy_loss -= entropy_coef * entropy_loss
+    unreduced_entropy = jnp.sum(token_entropy * completion_mask)
+    unreduced_policy_loss -= entropy_coef * unreduced_entropy

     # Logging
-    aux["loss/entropy"] = entropy_loss
+    aux["loss/entropy"] = sft_utils.WeightedMetric(
+        unreduced_entropy, denominator, min_denom=1.0
+    )

-  return policy_loss, aux
+  return sft_utils.LossOutput(
+      primary_loss=sft_utils.WeightedMetric(
+          unreduced_policy_loss, denominator, min_denom=1.0
+      ),
+      aux_metrics=aux,
+  )


 PpoConfig = PPOConfig
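For intuition on why these loss functions now return raw sums plus a token count instead of calling `masked_mean` directly: deferring the division lets downstream code weight each batch by its own token count, whereas averaging per-batch means over-weights batches with few valid tokens. A small numeric sketch (illustrative only, not part of the patch):

import jax.numpy as jnp

sums = jnp.array([3.0, 10.0])    # per-microbatch sum(loss * mask)
counts = jnp.array([1.0, 10.0])  # per-microbatch sum(mask)

mean_of_means = jnp.mean(sums / counts)    # 2.0: the 1-token batch dominates
weighted_mean = sums.sum() / counts.sum()  # ~1.18: the true token-level mean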
diff --git a/tunix/sft/dpo/dpo_trainer.py b/tunix/sft/dpo/dpo_trainer.py
index 2bc40c763..a4fcdf3d9 100644
--- a/tunix/sft/dpo/dpo_trainer.py
+++ b/tunix/sft/dpo/dpo_trainer.py
@@ -26,17 +26,16 @@
 import jax.numpy as jnp
 import numpy as np
 import optax
+
 # TODO(abheesht): We should move TokenizerAdapter outside `generate`.
 from tunix.generate import tokenizer_adapter
 from tunix.rl import common
 from tunix.sft import peft_trainer
+from tunix.sft import utils as sft_utils
 from typing_extensions import override

-
 RawImageType = (
-    str
-    | np.ndarray
-    | list[str | np.ndarray | list[str | np.ndarray] | None]
+    str | np.ndarray | list[str | np.ndarray | list[str | np.ndarray] | None]
 )
@@ -318,11 +317,11 @@ def _prepare_inputs(
     # Duplicate images as well (for multimodal inputs only).
     images = training_input.images
     if images is not None:
-        images = jnp.concatenate([images, images], axis=0)
+      images = jnp.concatenate([images, images], axis=0)

     if hasattr(self.model, "get_attention_mask"):
       attention_mask = self.model.get_attention_mask(
-        input_ids, inputs_mask=mask
+          input_ids, inputs_mask=mask
       )
     else:
       attention_mask = common.make_causal_attn_mask(mask)
@@ -390,7 +389,7 @@ def dpo_loss_fn(
     beta: float = 0.1,
     lambda_orpo: float = 0.1,
     label_smoothing: float = 0.0,
-) -> tuple[jax.Array, dict[str, jax.Array]]:
+) -> sft_utils.LossOutput | tuple[jax.Array, dict[str, jax.Array]]:
   """DPO/ORPO loss function.

   Args:
@@ -453,19 +452,47 @@ def dpo_loss_fn(
     # Compute odds ratio for logging
     odds_ratio = jnp.exp(log_odds)

+    denominator = jnp.array(batch_size, dtype=jnp.float32)
     aux = {
-        "rewards/chosen": chosen_rewards.mean(),
-        "rewards/rejected": rejected_rewards.mean(),
-        "rewards/margin": (chosen_rewards - rejected_rewards).mean(),
-        "rewards/accuracy": (chosen_rewards > rejected_rewards).mean(),
-        "log_probs/chosen": chosen_logps.mean(),
-        "log_probs/rejected": rejected_logps.mean(),
-        "odds_ratio": odds_ratio.mean(),
-        "sft_loss": sft_loss.mean(),
-        "or_loss": or_loss.mean(),
+        "rewards/chosen": sft_utils.WeightedMetric(
+            chosen_rewards.sum(), denominator, min_denom=1.0
+        ),
+        "rewards/rejected": sft_utils.WeightedMetric(
+            rejected_rewards.sum(), denominator, min_denom=1.0
+        ),
+        "rewards/margin": sft_utils.WeightedMetric(
+            (chosen_rewards - rejected_rewards).sum(),
+            denominator,
+            min_denom=1.0,
+        ),
+        "rewards/accuracy": sft_utils.WeightedMetric(
+            (chosen_rewards > rejected_rewards).sum(),
+            denominator,
+            min_denom=1.0,
+        ),
+        "log_probs/chosen": sft_utils.WeightedMetric(
+            chosen_logps.sum(), denominator, min_denom=1.0
+        ),
+        "log_probs/rejected": sft_utils.WeightedMetric(
+            rejected_logps.sum(), denominator, min_denom=1.0
+        ),
+        "odds_ratio": sft_utils.WeightedMetric(
+            odds_ratio.sum(), denominator, min_denom=1.0
+        ),
+        "sft_loss": sft_utils.WeightedMetric(
+            sft_loss.sum(), denominator, min_denom=1.0
+        ),
+        "or_loss": sft_utils.WeightedMetric(
+            or_loss.sum(), denominator, min_denom=1.0
+        ),
     }

-    return total_loss.mean(), aux
+    return sft_utils.LossOutput(
+        primary_loss=sft_utils.WeightedMetric(
+            total_loss.sum(), denominator, min_denom=1.0
+        ),
+        aux_metrics=aux,
+    )
   else:
     # DPO loss
     chosen_log_ratio = chosen_logps
@@ -484,16 +511,39 @@ def dpo_loss_fn(
     chosen_rewards = beta * chosen_log_ratio
     rejected_rewards = beta * rejected_log_ratio

+    batch_size = train_example.completion_mask.shape[0] // 2
+    denominator = jnp.array(batch_size, dtype=jnp.float32)
     aux = {
-        "rewards/chosen": chosen_rewards.mean(),
-        "rewards/rejected": rejected_rewards.mean(),
-        "rewards/margin": (chosen_rewards - rejected_rewards).mean(),
-        "rewards/accuracy": (chosen_rewards > rejected_rewards).mean(),
-        "log_probs/chosen": chosen_logps.mean(),
-        "log_probs/rejected": rejected_logps.mean(),
+        "rewards/chosen": sft_utils.WeightedMetric(
+            chosen_rewards.sum(), denominator, min_denom=1.0
+        ),
+        "rewards/rejected": sft_utils.WeightedMetric(
+            rejected_rewards.sum(), denominator, min_denom=1.0
+        ),
+        "rewards/margin": sft_utils.WeightedMetric(
+            (chosen_rewards - rejected_rewards).sum(),
+            denominator,
+            min_denom=1.0,
+        ),
+        "rewards/accuracy": sft_utils.WeightedMetric(
+            (chosen_rewards > rejected_rewards).sum(),
+            denominator,
+            min_denom=1.0,
+        ),
+        "log_probs/chosen": sft_utils.WeightedMetric(
+            chosen_logps.sum(), denominator, min_denom=1.0
+        ),
+        "log_probs/rejected": sft_utils.WeightedMetric(
+            rejected_logps.sum(), denominator, min_denom=1.0
+        ),
     }

-    return losses.mean(), aux
+    return sft_utils.LossOutput(
+        primary_loss=sft_utils.WeightedMetric(
+            losses.sum(), denominator, min_denom=1.0
+        ),
+        aux_metrics=aux,
+    )


 def _generate_ids_and_masks(
@@ -524,7 +574,7 @@ def _tokenize(input_string: str, tokenizer: Any) -> jax.Array:
   input_ids = tokenizer.encode(input_string)
   bos_tok = [tokenizer.bos_id()] if tokenizer.bos_id() else []
   input_ids = jnp.array(
-    tokenizer.dedup_bos_ids(bos_tok + input_ids), dtype=jnp.int32
+      tokenizer.dedup_bos_ids(bos_tok + input_ids), dtype=jnp.int32
   )
   return input_ids
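A sanity sketch for the DPO/ORPO aux rewrite above (illustrative only, not part of the patch): with the denominator set to the batch size, `WeightedMetric(x.sum(), batch).compute()` equals the old `x.mean()`, so logged values are unchanged for a single batch while remaining correctly re-weightable across batches.

import jax.numpy as jnp
from tunix.sft import utils as sft_utils

rewards = jnp.array([0.5, 1.5, 2.0, 4.0])
metric = sft_utils.WeightedMetric(
    rewards.sum(), jnp.array(rewards.shape[0], jnp.float32), min_denom=1.0
)
assert jnp.allclose(metric.compute(), rewards.mean())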
@@ -550,11 +600,13 @@ def _preprocess_dict(
         for field in tokenized_input_fields
     })
   elif all(
-      field in training_input for field in data_input_fields if field != "images"
+      field in training_input
+      for field in data_input_fields
+      if field != "images"
   ):
-    return DataInput(
-        **{field: training_input.get(field, None) for field in data_input_fields}
-    )
+    return DataInput(**{
+        field: training_input.get(field, None) for field in data_input_fields
+    })
   else:
     raise ValueError(
         "Training input must contain either tokenized fields "
@@ -583,9 +635,9 @@ def process_dpo_record(
   Args:
     record: A dictionary, containing "prompts", "images", "chosen_responses",
       "rejected_responses" as keys. For text fields, the values can be a
-      single string, or a list of strings. For `"images"`, the fields can be
-      a path (str), a NumPy array, list of paths, list of arrays, list of
-      lists of paths/arrays, or just None.
+      single string, or a list of strings. For `"images"`, the fields can be a
+      path (str), a NumPy array, list of paths, list of arrays, list of lists
+      of paths/arrays, or just None.
     tokenizer: The tokenizer or processor to use for converting text into
       token IDs.
     max_prompt_length: The maximum length for the tokenized prompts. Any
diff --git a/tunix/sft/peft_trainer.py b/tunix/sft/peft_trainer.py
index b1f200178..b25bbe274 100644
--- a/tunix/sft/peft_trainer.py
+++ b/tunix/sft/peft_trainer.py
@@ -133,6 +133,17 @@ def loss(self):
     return np.mean(np.array([np.array(x) for x in self.losses]))


+def _compute_legacy_aux(loss_output: utils.LossOutput) -> Dict[str, Any]:
+  """Computes legacy aux metrics from a LossOutput."""
+  legacy_aux = {}
+  for k, v in loss_output.aux_metrics.items():
+    if isinstance(v, utils.WeightedMetric):
+      legacy_aux[k] = v.compute()
+    else:
+      legacy_aux[k] = v
+  return legacy_aux
+
+
 def _calculate_global_batch_size(train_example: Any) -> int:
   """Calculates the global batch size from a training example.
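An illustrative call of `_compute_legacy_aux` (not part of the patch; note the helper is module-private): `WeightedMetric` entries collapse to scalars via `compute()`, and anything else passes through untouched, which is what keeps `_post_process_train_step` hooks working on plain values.

import jax.numpy as jnp
from tunix.sft import peft_trainer
from tunix.sft import utils

out = utils.LossOutput(
    primary_loss=utils.WeightedMetric(jnp.array(6.0), jnp.array(3.0)),
    aux_metrics={
        "kl": utils.WeightedMetric(jnp.array(1.0), jnp.array(2.0)),
        "step": jnp.array(7.0),  # non-WeightedMetric values pass through as-is
    },
)
legacy = peft_trainer._compute_legacy_aux(out)  # {'kl': 0.5, 'step': 7.0}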
@@ -344,26 +355,49 @@ def _train_step(
     """
     inputs = self.gen_model_input_fn(inputs)

+    @functools.wraps(self.loss_fn)
+    def diff_fn(model, *args, **kwargs):
+      out = self.loss_fn(model, *args, **kwargs)
+      if isinstance(out, utils.LossOutput):
+        return out.primary_loss.unreduced_sum, out
+      elif self._has_aux:
+        return out[0], out[1]
+      else:
+        return out, None
+
     grad_fn = nnx.value_and_grad(
-        self.loss_fn,
+        diff_fn,
         argnums=nnx.DiffState(0, nnx.LoRAParam) if self._lora_enabled else 0,
-        has_aux=self._has_aux,
+        has_aux=True,
     )
-    out, grads = grad_fn(model, **inputs)
+    (loss_val, aux), grads = grad_fn(model, **inputs)
+
+    if isinstance(aux, utils.LossOutput):
+      # Scale the unreduced gradients using the metric's scale computation.
+      scale = aux.primary_loss.compute_scale()
+      grads = jax.tree.map(lambda g: g * scale, grads)
+
+      # Compute the exactly equivalent legacy loss value.
+      loss_val = aux.primary_loss.compute()
+
     grad_norm = optax.global_norm(grads)
     optimizer.update(model, grads)
-    if self._has_aux:
-      loss, aux = out
-      return loss, aux, grad_norm
+
+    if isinstance(aux, utils.LossOutput):
+      return loss_val, _compute_legacy_aux(aux), grad_norm
+    elif self._has_aux:
+      return loss_val, aux, grad_norm
     else:
-      return out, None, grad_norm
+      return loss_val, None, grad_norm

   def _eval_step(
       self, model: nnx.Module, inputs: Any
   ) -> ArrayLike | Tuple[ArrayLike, Any]:
     inputs = self.gen_model_input_fn(inputs)
     out = self.eval_loss_fn(model, **inputs)
-    if self._has_aux:
+    if isinstance(out, utils.LossOutput):
+      return out.primary_loss.compute(), _compute_legacy_aux(out)
+    elif self._has_aux:
       loss, aux = out
       return loss, aux
     else:
@@ -375,7 +409,9 @@ def create_train_step_fn(
     """Creates the train step function."""
     return self._train_step

-  def create_eval_step_fn(self) -> Callable[..., ArrayLike]:
+  def create_eval_step_fn(
+      self,
+  ) -> Callable[..., ArrayLike | Tuple[ArrayLike, Any]]:
     """Creates the eval step function."""
     return self._eval_step

@@ -856,7 +892,7 @@ def _default_loss_fn(
     positions: jax.Array,
     attention_mask: jax.Array,
     images: jax.Array | None = None,
-) -> ArrayLike:
+) -> utils.LossOutput | ArrayLike:
   """Default loss function for PEFT training."""
   # Weird kwargs workaround because not all models support `images` right now.
   kwargs = {} if images is None else {"images": images}
@@ -874,8 +910,12 @@ def _default_loss_fn(
   one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]

   # Define the normalization factor.
-  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
+  denominator = jnp.sum(target_mask)

   # Return the negative log likelihood (NLL) loss.
   # Equivalent to: optax.softmax_cross_entropy(logits, one_hot).mean()
-  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
+  unreduced_loss = -jnp.sum(jax.nn.log_softmax(logits) * one_hot)
+  return utils.LossOutput(
+      primary_loss=utils.WeightedMetric(unreduced_loss, denominator, eps=1e-8),
+      aux_metrics={},
+  )
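A sketch of the identity `_train_step` relies on (illustrative only, not part of the patch): when the denominator d is independent of the parameters, grad(sum/d) == grad(sum)/d, so differentiating the raw sum and rescaling the gradients by `compute_scale()` matches differentiating the mean directly.

import jax
import jax.numpy as jnp

w = jnp.array(2.0)
x = jnp.array([1.0, 2.0, 3.0])
d = jnp.array(3.0)  # token count; constant w.r.t. w

sum_grad = jax.grad(lambda w: jnp.sum(w * x))(w)
mean_grad = jax.grad(lambda w: jnp.sum(w * x) / d)(w)
assert jnp.allclose(sum_grad / d, mean_grad)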
diff --git a/tunix/sft/utils.py b/tunix/sft/utils.py
index 7e678806d..773995ac4 100644
--- a/tunix/sft/utils.py
+++ b/tunix/sft/utils.py
@@ -178,3 +178,58 @@ def show_hbm_usage(title=""):
         used / limit,
         devices[i],
     )
+
+
+import flax.struct
+from typing import Dict
+
+
+@flax.struct.dataclass
+class WeightedMetric:
+  """A metric that requires weighted reduction.
+
+  Attributes:
+    unreduced_sum: The sum of the metric values. Should be a scalar ().
+    denominator: The weight or count of valid tokens/examples. Should be a
+      scalar ().
+    eps: Optional epsilon added to the denominator for numerical stability.
+    min_denom: Optional minimum bound for the denominator.
+  """
+
+  unreduced_sum: jax.Array
+  denominator: jax.Array
+  eps: float | None = flax.struct.field(default=None, pytree_node=False)
+  min_denom: float | None = flax.struct.field(default=None, pytree_node=False)
+
+  def compute_scale(self) -> jax.Array:
+    """Safely computes the scale factor (1 / denominator) with bounds."""
+    denom = self.denominator
+    if self.eps is not None:
+      denom = denom + self.eps
+    if self.min_denom is not None:
+      denom = jnp.maximum(denom, self.min_denom)
+
+    # JAX-safe division: prevent division-by-zero NaNs from poisoning
+    # gradients by replacing 0s with 1.0 *before* dividing.
+    safe_denom = jnp.where(denom == 0, 1.0, denom)
+
+    # Calculate the scale, masking pure-zero denominators out to 0.0.
+    scale = 1.0 / safe_denom
+    return jnp.where(denom == 0, 0.0, scale)
+
+  def compute(self) -> jax.Array:
+    """Safely computes unreduced_sum / denominator, honoring eps and min_denom."""
+    return self.unreduced_sum * self.compute_scale()
+
+
+@flax.struct.dataclass
+class LossOutput:
+  """Output of a loss function: an unreduced primary loss plus aux metrics.
+
+  Attributes:
+    primary_loss: The main loss to be optimized.
+    aux_metrics: A dictionary of auxiliary metrics.
+  """
+
+  primary_loss: WeightedMetric
+  aux_metrics: Dict[str, WeightedMetric]
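A usage sketch for the two classes above (illustrative only, not part of the patch): `compute()` applies `eps` / `min_denom` before dividing, and an all-zero denominator yields 0.0 rather than NaN.

import jax.numpy as jnp
from tunix.sft import utils

m = utils.WeightedMetric(jnp.array(6.0), jnp.array(4.0), min_denom=1.0)
m.compute()  # 1.5

empty = utils.WeightedMetric(jnp.array(0.0), jnp.array(0.0))
empty.compute()  # 0.0, not NaN: zero denominators are masked out

Because both classes are `flax.struct` dataclasses, they are pytrees: they can cross `jax.jit` boundaries and be summed across microbatches before a single final division.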