
Commit 79fd0e1

s-noghabi authored and The tunix Authors committed
move to unreduced loss fn
PiperOrigin-RevId: 912139415
1 parent 43f9eaa commit 79fd0e1

14 files changed

Lines changed: 412 additions & 139 deletions
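
Every diff below makes the same migration: loss functions stop returning a pre-reduced (loss, aux) tuple and instead return a LossOutput whose fields are WeightedMetric values, kept unreduced (a weighted sum plus its normalizer) until .compute() is called. The definitions live in tunix/sft/utils.py, which is not part of this commit; the sketch below reconstructs their apparent shape purely from the call sites in these diffs (primary_loss.compute(), aux_metrics, and WeightedMetric(unreduced_sum, denominator, eps=..., min_denom=...)), so the field names, defaults, and exact reduction formula are assumptions, not the actual implementation.

# Hypothetical reconstruction of the tunix/sft/utils.py types used in this
# commit. Inferred from the call sites below; the real classes are likely
# pytree-registered (e.g. via flax.struct) so they can cross jit boundaries.
import dataclasses
from typing import Any

import jax.numpy as jnp


@dataclasses.dataclass
class WeightedMetric:
  """A metric kept unreduced: a weighted sum plus the weight it sums over."""

  unreduced_sum: jnp.ndarray  # e.g. per-token losses summed over valid tokens
  denominator: jnp.ndarray  # e.g. the count of valid (unmasked) tokens
  eps: float = 0.0
  min_denom: float = 0.0

  def compute(self) -> jnp.ndarray:
    # Reduce only at read time. min_denom guards fully masked batches
    # (denominator == 0) against division by zero; this exact formula is an
    # assumption.
    return self.unreduced_sum / (
        jnp.maximum(self.denominator, self.min_denom) + self.eps
    )


@dataclasses.dataclass
class LossOutput:
  primary_loss: WeightedMetric
  aux_metrics: dict[str, Any]

Call sites then read out.primary_loss.compute() for the scalar loss and out.aux_metrics for logging, which is exactly the pattern the tests below switch to.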

tests/rl/agentic/agentic_grpo_learner_test.py

Lines changed: 5 additions & 2 deletions
@@ -450,13 +450,15 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(
         algo_config.policy_loss_fn
     )
-    loss, aux = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=MockModel(rngs=nnx.Rngs(0)),
         train_example=train_example,
         algo_config=algo_config,
         pad_id=0,
         eos_id=2,
     )
+    loss = loss_output.primary_loss.compute()
+    aux = loss_output.aux_metrics
     chex.assert_shape(loss, ())
     self.assertIn("kl", aux)

@@ -535,7 +537,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(config.policy_loss_fn)

     model = MockModel(rngs=nnx.Rngs(0))
-    loss, _ = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=model,
         train_example=train_example,
         algo_config=config,
@@ -567,6 +569,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     else:
       expected_loss = float(jnp.mean(per_sequence_loss))

+    loss = loss_output.primary_loss.compute()
     np.testing.assert_allclose(loss, expected_loss, rtol=1e-6, atol=1e-6)

   def test_process_results_extracts_assistant_text(self):

tests/rl/common_test.py

Lines changed: 12 additions & 5 deletions
@@ -24,6 +24,11 @@
 jax.config.update("jax_threefry_partitionable", False)


+def _compute_loss(*args, **kwargs):
+  out = getattr(common, "aggregate_loss")(*args, **kwargs)
+  return out.compute()
+
+
 class CommonTest(parameterized.TestCase):

   @parameterized.named_parameters(
@@ -446,7 +451,9 @@ def test_pad_to_length(self):
           expected_loss=(0.1 + 0.2) / 4.0 / 1.0,
       ),
       dict(
-          testcase_name="sequence_mean_token_sum_norm_partial_zero_mask_default",
+          testcase_name=(
+              "sequence_mean_token_sum_norm_partial_zero_mask_default"
+          ),
           loss_agg_mode="sequence-mean-token-sum-norm",
           per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
           completion_mask_list=[[1, 1], [0, 0]],
@@ -496,7 +503,7 @@ def test_aggregate_loss_values(
   ):
     per_token_loss = jnp.array(per_token_loss_list)
     completion_mask = jnp.array(completion_mask_list)
-    actual_loss = common.aggregate_loss(
+    actual_loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode, **kwargs
     )
     np.testing.assert_allclose(actual_loss, expected_loss, rtol=1e-6, atol=1e-6)
@@ -505,7 +512,7 @@ def test_invalid_mode(self):
     with self.assertRaisesRegex(
         ValueError, "Unsupported loss aggregation mode"
     ):
-      common.aggregate_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")
+      _compute_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")

   @parameterized.named_parameters(
       dict(
@@ -541,7 +548,7 @@ def test_invalid_mode(self):
   )
   def test_invalid_norm(self, norm_val, loss_agg_mode):
     with self.assertRaisesRegex(ValueError, "Invalid 'norm' value"):
-      common.aggregate_loss(
+      _compute_loss(
           jnp.ones((2, 2)),
           jnp.ones((2, 2)),
           loss_agg_mode,
@@ -567,7 +574,7 @@ def test_aggregate_loss_bf16(self):
     per_token_loss = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
     completion_mask = jnp.array([1, 1, 0], dtype=jnp.int32)

-    loss = common.aggregate_loss(
+    loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode="token-mean"
     )
     self.assertEqual(loss.dtype, jnp.float32)

tests/rl/grpo/dapo_learner_test.py

Lines changed: 11 additions & 5 deletions
@@ -90,19 +90,23 @@ def test_diff_loss(self):
     )

     # Call DAPO loss function
-    dapo_loss, dapo_aux = dapo_loss_fn_impl(
+    dapo_loss_output = dapo_loss_fn_impl(
         model, train_example, dapo_config, pad_id, eos_id
     )
+    dapo_loss = dapo_loss_output.primary_loss.compute()
+    dapo_aux = dapo_loss_output.aux_metrics

     # Call GRPO loss function
-    grpo_loss, grpo_aux = grpo_loss_fn_impl(
+    grpo_loss_output = grpo_loss_fn_impl(
         model, train_example, grpo_config, pad_id, eos_id
     )
+    grpo_loss = grpo_loss_output.primary_loss.compute()
+    grpo_aux = grpo_loss_output.aux_metrics

     # Assert that the loss values are different
     self.assertNotEqual(
-        dapo_loss.item(),
-        grpo_loss.item(),
+        dapo_loss,
+        grpo_loss,
         msg=(
             "DAPO and GRPO loss values should be different for the same input"
             " due to different loss aggregation logics."
@@ -111,7 +115,9 @@ def test_diff_loss(self):

     self.assertIn("kl", dapo_aux)
     self.assertIn("kl", grpo_aux)
-    self.assertEqual(dapo_aux["kl"], 0.0)  # DAPO does not have KL term.
+    self.assertEqual(
+        dapo_aux["kl"].compute(), 0.0
+    )  # DAPO does not have KL term.


 class TestDAPOConfigPostInit(parameterized.TestCase):

tests/rl/grpo/drgrpo_learner_test.py

Lines changed: 3 additions & 1 deletion
@@ -124,9 +124,11 @@ def test_drgrpo_loss_fn(self):
     )

     # Call DrGRPO loss function
-    drgrpo_loss, drgrpo_aux = drgrpo_loss_fn_impl(
+    drgrpo_loss_output = drgrpo_loss_fn_impl(
         model, train_example, drgrpo_config, pad_id, eos_id
     )
+    drgrpo_loss = drgrpo_loss_output.primary_loss.compute()
+    drgrpo_aux = drgrpo_loss_output.aux_metrics

     self.assertIn("kl", drgrpo_aux)
     self.assertTrue(jnp.isfinite(drgrpo_loss).all())

tests/sft/dpo/dpo_trainer_test.py

Lines changed: 4 additions & 2 deletions
@@ -270,14 +270,16 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
           model, train_example, beta=0.1, label_smoothing=0
       )
+      loss = loss_output.primary_loss.compute()
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)

-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
           model, train_example, beta=0.1, label_smoothing=0.3
       )
+      loss = loss_output.primary_loss.compute()
       np.testing.assert_allclose(loss, 0.925447, atol=1e-5)

   def test_dpo_prepare_inputs_for_strings(self):

tests/sft/dpo/orpo_trainer_test.py

Lines changed: 4 additions & 1 deletion
@@ -253,13 +253,16 @@ def test_orpo_loss_fn(self):
         "compute_logps",
         return_value=(jnp.array(chosen_logps), jnp.array(rejected_logps)),
     ):
-      loss, aux = orpo_lib.dpo_loss_fn(
+      loss_output = orpo_lib.dpo_loss_fn(
          model,
          train_example,
          algorithm="orpo",
          lambda_orpo=0.1,
          label_smoothing=0,
       )
+      loss = loss_output.primary_loss.compute()
+      aux = loss_output.aux_metrics
+
     # Loss should be a scalar and finite
     self.assertEqual(loss.shape, ())
     self.assertTrue(jnp.isfinite(loss))

tests/sft/peft_trainer_test.py

Lines changed: 61 additions & 0 deletions
@@ -34,6 +34,7 @@
 from tunix.sft import hooks
 from tunix.sft import peft_trainer
 from tunix.sft import profiler
+from tunix.sft import utils
 from tunix.tests import test_common as tc
 from tunix.utils import compat

@@ -634,7 +635,67 @@ def _post_process_eval_step(self, aux):
     self.assertEqual(train_invoke, {'foo': 2, 'bar': 4})
     self.assertEqual(eval_invoke, {'foo': 1, 'bar': 16})

+  def test_loss_output_format(self):
+    def custom_loss_fn(
+        model: nnx.Module,
+        input_tokens: jax.Array,
+        input_mask: jax.Array,
+        positions: jax.Array,
+        attention_mask: jax.Array,
+        images: jax.Array | None = None,
+    ) -> utils.LossOutput:
+      del model, input_tokens, input_mask, positions, attention_mask, images
+      return utils.LossOutput(
+          primary_loss=utils.WeightedMetric(
+              jnp.array(2.0, dtype=jnp.float32),
+              jnp.array(2.0, dtype=jnp.float32),
+          ),
+          aux_metrics={
+              'foo': utils.WeightedMetric(
+                  jnp.array(10.0, dtype=jnp.float32),
+                  jnp.array(5.0, dtype=jnp.float32),
+              ),
+              'bar': utils.WeightedMetric(
+                  jnp.array(6.0, dtype=jnp.float32),
+                  jnp.array(2.0, dtype=jnp.float32),
+              ),
+          },
+      )
+
+    train_invoke = {'foo': 0.0, 'bar': 0.0}
+    eval_invoke = {'foo': 0.0, 'bar': 0.0}
+
+    class CustomTrainer(peft_trainer.PeftTrainer):
+
+      def _post_process_train_step(self, aux):
+        train_invoke['foo'] += aux['foo']
+        train_invoke['bar'] += aux['bar']
+
+      def _post_process_eval_step(self, aux):
+        eval_invoke['foo'] += aux['foo']
+        eval_invoke['bar'] += aux['bar']
+
+    config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
+
+    trainer = CustomTrainer(model, optax.sgd(1e-3), config)
+    trainer = trainer.with_gen_model_input_fn(
+        dummy_gen_model_input_fn
+    ).with_loss_fn(
+        custom_loss_fn
+    )  # Note: has_aux=False is default but LossOutput returns aux natively
+
+    trainer.train(self.train_ds, self.eval_ds)
+    # The dataset provides 2 training steps.
+    # foo = 10.0 / 5.0 = 2.0 per step.
+    # bar = 6.0 / 2.0 = 3.0 per step.
+    self.assertEqual(train_invoke, {'foo': 4.0, 'bar': 6.0})
+
+    # Since eval_ds is length 2, it evaluates at step 2.
+    self.assertEqual(eval_invoke, {'foo': 8.0, 'bar': 12.0})
+
   def test_injected_params(self):
+
     config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
     model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))

tunix/rl/agentic/agentic_grpo_learner.py

Lines changed: 34 additions & 31 deletions
@@ -49,9 +49,9 @@
 from tunix.rl.agentic.environments import base_environment
 from tunix.rl.agentic.environments import task_environment
 from tunix.rl.ppo import ppo_helpers
+from tunix.sft import utils as sft_utils
 from tunix.utils import trajectory_logger

-
 TrainingInputT = agentic_rl_learner.TrainingInputT
 RewardFn = agentic_rl_learner.RewardFn
 MetricFn = agentic_rl_learner.MetricFn
@@ -74,8 +74,8 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
     num_iterations: Number of GRPO iterations per batch (μ in the paper).
     beta: KL penalty coefficient.
     kl_loss_mode: Method for computing the KL loss.
-    force_compute_kl: Whether to force compute KL divergence for logging
-      even when it would normally be skipped (e.g., when beta is 0.0).
+    force_compute_kl: Whether to force compute KL divergence for logging even
+      when it would normally be skipped (e.g., when beta is 0.0).
     epsilon: PPO-style clipping epsilon.
     epsilon_high: PPO-style clipping epsilon upper bound.
     loss_algo: "grpo" or "gspo-token".
@@ -251,8 +251,7 @@ def __init__(
     })
     self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([
         lambda: "kl"
-        if self.algo_config.force_compute_kl
-        or self.algo_config.beta != 0.0
+        if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0
         else None,
     ])

@@ -594,9 +593,7 @@ def grpo_loss_fn(
       else epsilon
   )
   epsilon_c = (
-      algo_config.epsilon_c
-      if hasattr(algo_config, "epsilon_c")
-      else 3.0
+      algo_config.epsilon_c if hasattr(algo_config, "epsilon_c") else 3.0
   )
   loss_aggregation_mode = algo_config.loss_agg_mode

@@ -633,7 +630,8 @@ def grpo_loss_fn(

   seq_importance_ratio = per_token_logps - old_per_token_logps
   # Record KL divergence before clipping.
-  ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)
+  unreduced_ppo_kl = jnp.sum(-seq_importance_ratio * completion_mask)
+  token_denom = completion_mask.sum()

   seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)

@@ -661,35 +659,38 @@ def grpo_loss_fn(

   per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)

-  clipped_fraction = ppo_helpers.masked_mean(
-      jnp.greater(pg_loss_2, pg_loss_1), completion_mask
-  )
+  unreduced_clip_frac = jnp.sum(jnp.greater(pg_loss_2, pg_loss_1) * completion_mask)

   # dual-clip ppo loss
   pg_loss_3 = -epsilon_c * adv

   # pg_clipfrac_lower measures how often dual-clip ppo kicks in.
   # It kicks in when the standard clipped loss is larger than pg_loss_3
   # for instances with negative advantages.
-  unreduced_pg_clipfrac_lower = (
+  per_token_pg_clipfrac_lower = (
       (per_token_loss > pg_loss_3) & (adv < 0.0)
   ).astype(jnp.float32)
-  pg_clipfrac_lower = common.aggregate_loss(
-      unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
+  unreduced_pg_clipfrac_lower = common.aggregate_loss(
+      per_token_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
   )

   pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
   per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)
-  loss = common.aggregate_loss(
+  weighted_loss = common.aggregate_loss(
       per_token_loss, completion_mask, loss_aggregation_mode
   )
+
   aux = {
-      "kl": 0.0,
-      "kl_loss": 0.0,
-      "pg_loss": loss,
-      "pg_clipfrac": clipped_fraction,
-      "ppo_kl": ppo_kl,
-      "pg_clipfrac_lower": pg_clipfrac_lower,
+      "kl": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
+      "kl_loss": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
+      "pg_loss": weighted_loss,
+      "pg_clipfrac": sft_utils.WeightedMetric(
+          unreduced_clip_frac, token_denom, min_denom=1.0
+      ),
+      "ppo_kl": sft_utils.WeightedMetric(
+          unreduced_ppo_kl, token_denom, min_denom=1.0
+      ),
+      "pg_clipfrac_lower": unreduced_pg_clipfrac_lower,
   }
   # We do not always compute KL divergence (e.g. when beta is 0.0 unless
   # force_compute_kl is True).
@@ -699,25 +700,27 @@ def grpo_loss_fn(
         train_example.ref_per_token_logps,
         algo_config.kl_loss_mode,
     )
-    # Log mean KL.
-    aux["kl"] = jnp.astype(
-        (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
-        jnp.float32,
-    )
-    kl_loss = common.aggregate_loss(
-        kl, completion_mask, loss_aggregation_mode
+    unreduced_kl = jnp.astype(jnp.sum(kl * completion_mask), jnp.float32)
+    aux["kl"] = sft_utils.WeightedMetric(
+        unreduced_kl, token_denom, min_denom=1.0
     )
+    kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode)
     aux["kl_loss"] = kl_loss
     if beta is not None and beta != 0.0:
-      loss = loss + beta * kl_loss
+      weighted_loss = sft_utils.WeightedMetric(
+          weighted_loss.unreduced_sum + beta * kl_loss.unreduced_sum,
+          weighted_loss.denominator,
+          eps=weighted_loss.eps,
+          min_denom=weighted_loss.min_denom,
+      )

   token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
   entropy_loss = common.aggregate_loss(
       token_entropy, completion_mask, loss_aggregation_mode
   )
   aux["entropy"] = entropy_loss

-  return loss, aux
+  return sft_utils.LossOutput(primary_loss=weighted_loss, aux_metrics=aux)


 @function_registry.register_advantage_estimator("agentic_grpo")
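
For intuition on why grpo_loss_fn now carries an unreduced sum plus denominator instead of a pre-reduced mean: pre-reduced per-micro-batch means cannot be averaged correctly once micro-batches contain different numbers of valid tokens. The numbers below are invented for illustration and do not come from this commit:

# Micro-batch A: 10 valid tokens, summed loss 5.0 -> per-token mean 0.5
# Micro-batch B:  2 valid tokens, summed loss 4.0 -> per-token mean 2.0
naive_mean_of_means = (5.0 / 10 + 4.0 / 2) / 2  # 1.25; over-weights batch B
token_weighted_mean = (5.0 + 4.0) / (10 + 2)  # 0.75; what the accumulated
                                              # sum/denominator form recovers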
