
Commit 4e1f0d4

s-noghabi authored and The tunix Authors committed
[seq packing] move to unreduced loss
PiperOrigin-RevId: 914999211
1 parent 30fec08 · commit 4e1f0d4

12 files changed

Lines changed: 407 additions & 123 deletions
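Editor's note on the pattern the diffs below all follow: loss functions no longer return a reduced (loss, aux) tuple. They return a LossOutput whose primary_loss is an unreduced WeightedMetric, collapsed on demand with .compute(), and whose aux_metrics holds the same kind of object per metric. The real definitions live in tunix.sft.utils; the sketch below is a reconstruction from the test expectations, not the actual implementation, and the field names total and weight are invented here (the tests construct WeightedMetric positionally).

# A minimal sketch, assuming WeightedMetric is a (weighted sum, weight)
# pair whose compute() is the weighted mean. Field names are hypothetical.
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class WeightedMetric:
  total: jax.Array   # weighted sum of per-token values
  weight: jax.Array  # total weight, e.g. the number of unmasked tokens

  def compute(self) -> jax.Array:
    # Final reduction, done once, after any merging across batches.
    return self.total / self.weight


@dataclasses.dataclass
class LossOutput:
  primary_loss: WeightedMetric
  aux_metrics: dict[str, WeightedMetric]


# Consistent with the peft_trainer test below: WeightedMetric(10.0, 5.0)
# reduces to 2.0.
assert WeightedMetric(jnp.array(10.0), jnp.array(5.0)).compute() == 2.0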

tests/rl/agentic/agentic_grpo_learner_test.py

Lines changed: 10 additions & 3 deletions
@@ -554,13 +554,15 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(
         algo_config.policy_loss_fn
     )
-    loss, aux = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=MockModel(rngs=nnx.Rngs(0)),
         train_example=train_example,
         algo_config=algo_config,
         pad_id=0,
         eos_id=2,
     )
+    loss = loss_output.primary_loss.compute()
+    aux = loss_output.aux_metrics
     chex.assert_shape(loss, ())
     self.assertIn("kl", aux)
 
@@ -639,7 +641,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     policy_loss_fn = function_registry.get_policy_loss_fn(config.policy_loss_fn)
 
     model = MockModel(rngs=nnx.Rngs(0))
-    loss, _ = policy_loss_fn(
+    loss_output = policy_loss_fn(
         model=model,
         train_example=train_example,
         algo_config=config,
@@ -671,10 +673,12 @@ def __call__(self, inputs, positions, cache, attention_mask):
     else:
       expected_loss = float(jnp.mean(per_sequence_loss))
 
+    loss = loss_output.primary_loss.compute()
     np.testing.assert_allclose(loss, expected_loss, rtol=1e-6, atol=1e-6)
 
   def test_process_results_extracts_assistant_text(self):
     class MockTraj:
+
       def __init__(self, index):
         self.traj = {
             "conversation_text": [
@@ -695,6 +699,7 @@ def __init__(self, index):
     trajectories = [MockTraj(0), MockTraj(1)]
 
     extracted_completions = []
+
     def mock_compute_rewards(prompts, completions, **kwargs):
       extracted_completions.extend(completions)
       return jnp.ones(len(completions), dtype=jnp.float32)
@@ -748,7 +753,9 @@ def mock_compute_rewards(prompts, completions, **kwargs):
         chat_parser=MockChatParser(),
     )
 
-    with mock.patch.object(learner, "_compute_rewards", side_effect=mock_compute_rewards):
+    with mock.patch.object(
+        learner, "_compute_rewards", side_effect=mock_compute_rewards
+    ):
       with mock.patch.object(
           learner.rl_cluster,
           "get_ref_per_token_logps",

tests/rl/common_test.py

Lines changed: 12 additions & 5 deletions
@@ -24,6 +24,11 @@
 jax.config.update("jax_threefry_partitionable", False)
 
 
+def _compute_loss(*args, **kwargs):
+  out = getattr(common, "aggregate_loss")(*args, **kwargs)
+  return out.compute()
+
+
 class CommonTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
@@ -446,7 +451,9 @@ def test_pad_to_length(self):
           expected_loss=(0.1 + 0.2) / 4.0 / 1.0,
       ),
       dict(
-          testcase_name="sequence_mean_token_sum_norm_partial_zero_mask_default",
+          testcase_name=(
+              "sequence_mean_token_sum_norm_partial_zero_mask_default"
+          ),
          loss_agg_mode="sequence-mean-token-sum-norm",
          per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
          completion_mask_list=[[1, 1], [0, 0]],
@@ -496,7 +503,7 @@ def test_aggregate_loss_values(
   ):
     per_token_loss = jnp.array(per_token_loss_list)
     completion_mask = jnp.array(completion_mask_list)
-    actual_loss = common.aggregate_loss(
+    actual_loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode, **kwargs
     )
     np.testing.assert_allclose(actual_loss, expected_loss, rtol=1e-6, atol=1e-6)
@@ -505,7 +512,7 @@ def test_invalid_mode(self):
     with self.assertRaisesRegex(
         ValueError, "Unsupported loss aggregation mode"
     ):
-      common.aggregate_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")
+      _compute_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")
 
   @parameterized.named_parameters(
       dict(
@@ -541,7 +548,7 @@ def test_invalid_mode(self):
   )
   def test_invalid_norm(self, norm_val, loss_agg_mode):
     with self.assertRaisesRegex(ValueError, "Invalid 'norm' value"):
-      common.aggregate_loss(
+      _compute_loss(
           jnp.ones((2, 2)),
           jnp.ones((2, 2)),
           loss_agg_mode,
@@ -567,7 +574,7 @@ def test_aggregate_loss_bf16(self):
     per_token_loss = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
     completion_mask = jnp.array([1, 1, 0], dtype=jnp.int32)
 
-    loss = common.aggregate_loss(
+    loss = _compute_loss(
         per_token_loss, completion_mask, loss_agg_mode="token-mean"
     )
     self.assertEqual(loss.dtype, jnp.float32)
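Editor's note: the _compute_loss shim above is the test-side view of the move to unreduced loss. The payoff for sequence packing, as the commit title suggests: packed batches carry different numbers of real (unmasked) tokens, so averaging already-reduced per-batch means over-weights the batches with few tokens, whereas merging the unreduced sums and weights first and reducing once recovers the true token mean. A small numeric illustration (not code from the commit):

import jax.numpy as jnp

losses_a = jnp.array([1.0, 2.0, 3.0, 4.0])  # 4 real tokens, sum 10.0
losses_b = jnp.array([10.0])                # 1 real token, sum 10.0

# Mean of per-batch means: (2.5 + 10.0) / 2 = 6.25, biased toward the
# short batch.
biased = (losses_a.mean() + losses_b.mean()) / 2

# Merge unreduced (sum, weight) pairs first, then reduce once:
# 20.0 / 5 = 4.0, the true mean over all 5 tokens.
unbiased = (losses_a.sum() + losses_b.sum()) / (losses_a.size + losses_b.size)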

tests/rl/grpo/dapo_learner_test.py

Lines changed: 13 additions & 6 deletions
@@ -89,22 +89,27 @@ def test_diff_loss(self):
         rngs=nnx.Rngs(0),
     )
 
-    # Call DAPO loss function (DAPO sets ref_per_token_logps to None as it doesn't fetch it)
+    # Call DAPO loss function (DAPO sets ref_per_token_logps to None as it
+    # doesn't fetch it).
     dapo_train_example = self.create_train_example()
     dapo_train_example.ref_per_token_logps = None
-    dapo_loss, dapo_aux = dapo_loss_fn_impl(
+    dapo_loss_output = dapo_loss_fn_impl(
         model, dapo_train_example, dapo_config, pad_id, eos_id
     )
+    dapo_loss = dapo_loss_output.primary_loss.compute()
+    dapo_aux = dapo_loss_output.aux_metrics
 
     # Call GRPO loss function
-    grpo_loss, grpo_aux = grpo_loss_fn_impl(
+    grpo_loss_output = grpo_loss_fn_impl(
         model, train_example, grpo_config, pad_id, eos_id
     )
+    grpo_loss = grpo_loss_output.primary_loss.compute()
+    grpo_aux = grpo_loss_output.aux_metrics
 
     # Assert that the loss values are different
     self.assertNotEqual(
-        dapo_loss.item(),
-        grpo_loss.item(),
+        dapo_loss,
+        grpo_loss,
         msg=(
             "DAPO and GRPO loss values should be different for the same input"
             " due to different loss aggregation logics."
@@ -113,7 +118,9 @@ def test_diff_loss(self):
 
     self.assertIn("kl", dapo_aux)
     self.assertIn("kl", grpo_aux)
-    self.assertEqual(dapo_aux["kl"], 0.0)  # DAPO does not have KL term.
+    self.assertEqual(
+        dapo_aux["kl"].compute(), 0.0
+    )  # DAPO does not have KL term.
 
 
 class TestDAPOConfigPostInit(parameterized.TestCase):

tests/rl/grpo/drgrpo_learner_test.py

Lines changed: 3 additions & 1 deletion
@@ -125,9 +125,11 @@ def test_drgrpo_loss_fn(self):
     )
 
     # Call DrGRPO loss function
-    drgrpo_loss, drgrpo_aux = drgrpo_loss_fn_impl(
+    drgrpo_loss_output = drgrpo_loss_fn_impl(
         model, train_example, drgrpo_config, pad_id, eos_id
     )
+    drgrpo_loss = drgrpo_loss_output.primary_loss.compute()
+    drgrpo_aux = drgrpo_loss_output.aux_metrics
 
     self.assertIn("kl", drgrpo_aux)
     self.assertTrue(jnp.isfinite(drgrpo_loss).all())

tests/sft/dpo/dpo_trainer_test.py

Lines changed: 4 additions & 2 deletions
@@ -270,14 +270,16 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
          model, train_example, beta=0.1, label_smoothing=0
       )
+      loss = loss_output.primary_loss.compute()
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)
 
-      loss, _ = dpo_lib.dpo_loss_fn(
+      loss_output = dpo_lib.dpo_loss_fn(
          model, train_example, beta=0.1, label_smoothing=0.3
       )
+      loss = loss_output.primary_loss.compute()
      np.testing.assert_allclose(loss, 0.925447, atol=1e-5)
 
   def test_dpo_prepare_inputs_for_strings(self):

tests/sft/dpo/orpo_trainer_test.py

Lines changed: 9 additions & 4 deletions
@@ -161,7 +161,9 @@ def test_orpo_trainer(
         orpo_trainer._train_steps,
     )
     self.assertLen(
-        orpo_trainer.metrics_logger.get_metric_history("", metric_name, "eval"),
+        orpo_trainer.metrics_logger.get_metric_history(
+            "", metric_name, "eval"
+        ),
         3,
     )
 
@@ -253,13 +255,16 @@ def test_orpo_loss_fn(self):
         "compute_logps",
         return_value=(jnp.array(chosen_logps), jnp.array(rejected_logps)),
     ):
-      loss, aux = orpo_lib.dpo_loss_fn(
+      loss_output = orpo_lib.dpo_loss_fn(
          model,
          train_example,
          algorithm="orpo",
          lambda_orpo=0.1,
          label_smoothing=0,
      )
+      loss = loss_output.primary_loss.compute()
+      aux = loss_output.aux_metrics
+
      # Loss should be a scalar and finite
      self.assertEqual(loss.shape, ())
      self.assertTrue(jnp.isfinite(loss))
@@ -274,8 +279,8 @@ def test_orpo_loss_fn(self):
     self.assertIn("odds_ratio", aux)
 
     # Check that accuracy is between 0 and 1
-    self.assertGreaterEqual(aux["rewards/accuracy"], 0.0)
-    self.assertLessEqual(aux["rewards/accuracy"], 1.0)
+    self.assertGreaterEqual(aux["rewards/accuracy"].compute(), 0.0)
+    self.assertLessEqual(aux["rewards/accuracy"].compute(), 1.0)
 
   def test_orpo_prepare_inputs_for_strings(self):
     tokenizer = tc.MockVocab()

tests/sft/peft_trainer_test.py

Lines changed: 61 additions & 0 deletions
@@ -34,6 +34,7 @@
 from tunix.sft import hooks
 from tunix.sft import peft_trainer
 from tunix.sft import profiler
+from tunix.sft import utils
 from tunix.tests import test_common as tc
 from tunix.utils import compat
 
@@ -634,7 +635,67 @@ def _post_process_eval_step(self, aux):
     self.assertEqual(train_invoke, {'foo': 2, 'bar': 4})
     self.assertEqual(eval_invoke, {'foo': 1, 'bar': 16})
 
+  def test_loss_output_format(self):
+    def custom_loss_fn(
+        model: nnx.Module,
+        input_tokens: jax.Array,
+        input_mask: jax.Array,
+        positions: jax.Array,
+        attention_mask: jax.Array,
+        images: jax.Array | None = None,
+    ) -> utils.LossOutput:
+      del model, input_tokens, input_mask, positions, attention_mask, images
+      return utils.LossOutput(
+          primary_loss=utils.WeightedMetric(
+              jnp.array(2.0, dtype=jnp.float32),
+              jnp.array(2.0, dtype=jnp.float32),
+          ),
+          aux_metrics={
+              'foo': utils.WeightedMetric(
+                  jnp.array(10.0, dtype=jnp.float32),
+                  jnp.array(5.0, dtype=jnp.float32),
+              ),
+              'bar': utils.WeightedMetric(
+                  jnp.array(6.0, dtype=jnp.float32),
+                  jnp.array(2.0, dtype=jnp.float32),
+              ),
+          },
+      )
+
+    train_invoke = {'foo': 0.0, 'bar': 0.0}
+    eval_invoke = {'foo': 0.0, 'bar': 0.0}
+
+    class CustomTrainer(peft_trainer.PeftTrainer):
+
+      def _post_process_train_step(self, aux):
+        train_invoke['foo'] += aux['foo']
+        train_invoke['bar'] += aux['bar']
+
+      def _post_process_eval_step(self, aux):
+        eval_invoke['foo'] += aux['foo']
+        eval_invoke['bar'] += aux['bar']
+
+    config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
+
+    trainer = CustomTrainer(model, optax.sgd(1e-3), config)
+    trainer = trainer.with_gen_model_input_fn(
+        dummy_gen_model_input_fn
+    ).with_loss_fn(
+        custom_loss_fn
+    )  # Note: has_aux=False is default but LossOutput returns aux natively
+
+    trainer.train(self.train_ds, self.eval_ds)
+    # The dataset provides 2 training steps.
+    # foo = 10.0 / 5.0 = 2.0 per step.
+    # bar = 6.0 / 2.0 = 3.0 per step.
+    self.assertEqual(train_invoke, {'foo': 4.0, 'bar': 6.0})
+
+    # Since eval_ds is length 2, it evaluates at step 2.
+    self.assertEqual(eval_invoke, {'foo': 8.0, 'bar': 12.0})
+
   def test_injected_params(self):
+
     config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
     model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
 
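Editor's note on test_loss_output_format above: the trainer evidently calls .compute() on each aux WeightedMetric before invoking the post-process hooks, since the hooks accumulate plain per-step scalars (foo = 10.0 / 5.0 = 2.0, bar = 6.0 / 2.0 = 3.0). The eval totals of 8.0 and 12.0 are consistent with four eval-step hook invocations at those per-step values. A sketch of the assumed hook-side reduction (_reduce_aux is a hypothetical name, not from the diff):

# Hypothetical: collapse each unreduced aux metric to a scalar before the
# _post_process_* hooks see it.
def _reduce_aux(aux_metrics):
  return {name: metric.compute() for name, metric in aux_metrics.items()}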