7 changes: 5 additions & 2 deletions tests/rl/agentic/agentic_grpo_learner_test.py
@@ -450,13 +450,15 @@ def __call__(self, inputs, positions, cache, attention_mask):
policy_loss_fn = function_registry.get_policy_loss_fn(
algo_config.policy_loss_fn
)
loss, aux = policy_loss_fn(
loss_output = policy_loss_fn(
model=MockModel(rngs=nnx.Rngs(0)),
train_example=train_example,
algo_config=algo_config,
pad_id=0,
eos_id=2,
)
loss = loss_output.primary_loss.compute()
aux = loss_output.aux_metrics
chex.assert_shape(loss, ())
self.assertIn("kl", aux)

@@ -535,7 +537,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
policy_loss_fn = function_registry.get_policy_loss_fn(config.policy_loss_fn)

model = MockModel(rngs=nnx.Rngs(0))
loss, _ = policy_loss_fn(
loss_output = policy_loss_fn(
model=model,
train_example=train_example,
algo_config=config,
@@ -567,6 +569,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
else:
expected_loss = float(jnp.mean(per_sequence_loss))

loss = loss_output.primary_loss.compute()
np.testing.assert_allclose(loss, expected_loss, rtol=1e-6, atol=1e-6)

def test_process_results_extracts_assistant_text(self):
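The tests above exercise the PR's central change: loss functions now return a LossOutput container instead of a (loss, aux) tuple. As a reading aid, here is a minimal sketch of the contract these call sites imply, inferred from usage elsewhere in this diff (unreduced_sum, denominator, eps, min_denom, compute(), primary_loss, aux_metrics); the real definitions live in tunix.sft.utils and may differ in defaults and details:

import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class WeightedMetric:
  """A metric kept as a (weighted sum, weight) pair until explicitly reduced."""

  unreduced_sum: jax.Array  # e.g. the sum of per-token losses over the mask
  denominator: jax.Array  # e.g. the number of unmasked completion tokens
  eps: float = 0.0  # assumed numerical-stability term
  min_denom: float = 0.0  # assumed floor; callers pass 1.0 when masks can be empty

  def compute(self) -> jax.Array:
    # Reduce to a scalar; the floor guards against an all-zero mask.
    denom = jnp.maximum(self.denominator, self.min_denom) + self.eps
    return (self.unreduced_sum / denom).astype(jnp.float32)


@dataclasses.dataclass
class LossOutput:
  primary_loss: WeightedMetric
  aux_metrics: dict[str, WeightedMetric | jax.Array]

Under this reading, the test updates are mechanical: loss_output.primary_loss.compute() replaces the old scalar loss, and aux values may themselves be WeightedMetric instances that need .compute() before comparison.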
17 changes: 12 additions & 5 deletions tests/rl/common_test.py
@@ -24,6 +24,11 @@
jax.config.update("jax_threefry_partitionable", False)


def _compute_loss(*args, **kwargs):
out = getattr(common, "aggregate_loss")(*args, **kwargs)
return out.compute()


class CommonTest(parameterized.TestCase):

@parameterized.named_parameters(
@@ -446,7 +451,9 @@ def test_pad_to_length(self):
expected_loss=(0.1 + 0.2) / 4.0 / 1.0,
),
dict(
testcase_name="sequence_mean_token_sum_norm_partial_zero_mask_default",
testcase_name=(
"sequence_mean_token_sum_norm_partial_zero_mask_default"
),
loss_agg_mode="sequence-mean-token-sum-norm",
per_token_loss_list=[[0.1, 0.2], [0.3, 0.4]],
completion_mask_list=[[1, 1], [0, 0]],
@@ -496,7 +503,7 @@ def test_aggregate_loss_values(
):
per_token_loss = jnp.array(per_token_loss_list)
completion_mask = jnp.array(completion_mask_list)
actual_loss = common.aggregate_loss(
actual_loss = _compute_loss(
per_token_loss, completion_mask, loss_agg_mode, **kwargs
)
np.testing.assert_allclose(actual_loss, expected_loss, rtol=1e-6, atol=1e-6)
@@ -505,7 +512,7 @@ def test_invalid_mode(self):
with self.assertRaisesRegex(
ValueError, "Unsupported loss aggregation mode"
):
common.aggregate_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")
_compute_loss(jnp.ones((2, 2)), jnp.ones((2, 2)), "invalid-mode")

@parameterized.named_parameters(
dict(
@@ -541,7 +548,7 @@ def test_invalid_mode(self):
)
def test_invalid_norm(self, norm_val, loss_agg_mode):
with self.assertRaisesRegex(ValueError, "Invalid 'norm' value"):
common.aggregate_loss(
_compute_loss(
jnp.ones((2, 2)),
jnp.ones((2, 2)),
loss_agg_mode,
@@ -567,7 +574,7 @@ def test_aggregate_loss_bf16(self):
per_token_loss = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
completion_mask = jnp.array([1, 1, 0], dtype=jnp.int32)

loss = common.aggregate_loss(
loss = _compute_loss(
per_token_loss, completion_mask, loss_agg_mode="token-mean"
)
self.assertEqual(loss.dtype, jnp.float32)
16 changes: 11 additions & 5 deletions tests/rl/grpo/dapo_learner_test.py
@@ -90,19 +90,23 @@ def test_diff_loss(self):
)

# Call DAPO loss function
dapo_loss, dapo_aux = dapo_loss_fn_impl(
dapo_loss_output = dapo_loss_fn_impl(
model, train_example, dapo_config, pad_id, eos_id
)
dapo_loss = dapo_loss_output.primary_loss.compute()
dapo_aux = dapo_loss_output.aux_metrics

# Call GRPO loss function
grpo_loss, grpo_aux = grpo_loss_fn_impl(
grpo_loss_output = grpo_loss_fn_impl(
model, train_example, grpo_config, pad_id, eos_id
)
grpo_loss = grpo_loss_output.primary_loss.compute()
grpo_aux = grpo_loss_output.aux_metrics

# Assert that the loss values are different
self.assertNotEqual(
dapo_loss.item(),
grpo_loss.item(),
dapo_loss,
grpo_loss,
msg=(
"DAPO and GRPO loss values should be different for the same input"
" due to different loss aggregation logics."
@@ -111,7 +115,9 @@ def test_diff_loss(self):

self.assertIn("kl", dapo_aux)
self.assertIn("kl", grpo_aux)
self.assertEqual(dapo_aux["kl"], 0.0) # DAPO does not have KL term.
self.assertEqual(
dapo_aux["kl"].compute(), 0.0
) # DAPO does not have KL term.


class TestDAPOConfigPostInit(parameterized.TestCase):
4 changes: 3 additions & 1 deletion tests/rl/grpo/drgrpo_learner_test.py
@@ -124,9 +124,11 @@ def test_drgrpo_loss_fn(self):
)

# Call DrGRPO loss function
drgrpo_loss, drgrpo_aux = drgrpo_loss_fn_impl(
drgrpo_loss_output = drgrpo_loss_fn_impl(
model, train_example, drgrpo_config, pad_id, eos_id
)
drgrpo_loss = drgrpo_loss_output.primary_loss.compute()
drgrpo_aux = drgrpo_loss_output.aux_metrics

self.assertIn("kl", drgrpo_aux)
self.assertTrue(jnp.isfinite(drgrpo_loss).all())
6 changes: 4 additions & 2 deletions tests/sft/dpo/dpo_trainer_test.py
@@ -270,14 +270,16 @@ def test_dpo_loss_fn(self):
with mock.patch.object(
common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
):
loss, _ = dpo_lib.dpo_loss_fn(
loss_output = dpo_lib.dpo_loss_fn(
model, train_example, beta=0.1, label_smoothing=0
)
loss = loss_output.primary_loss.compute()
np.testing.assert_allclose(loss, 0.753059, atol=1e-5)

loss, _ = dpo_lib.dpo_loss_fn(
loss_output = dpo_lib.dpo_loss_fn(
model, train_example, beta=0.1, label_smoothing=0.3
)
loss = loss_output.primary_loss.compute()
np.testing.assert_allclose(loss, 0.925447, atol=1e-5)

def test_dpo_prepare_inputs_for_strings(self):
5 changes: 4 additions & 1 deletion tests/sft/dpo/orpo_trainer_test.py
@@ -253,13 +253,16 @@ def test_orpo_loss_fn(self):
"compute_logps",
return_value=(jnp.array(chosen_logps), jnp.array(rejected_logps)),
):
loss, aux = orpo_lib.dpo_loss_fn(
loss_output = orpo_lib.dpo_loss_fn(
model,
train_example,
algorithm="orpo",
lambda_orpo=0.1,
label_smoothing=0,
)
loss = loss_output.primary_loss.compute()
aux = loss_output.aux_metrics

# Loss should be a scalar and finite
self.assertEqual(loss.shape, ())
self.assertTrue(jnp.isfinite(loss))
61 changes: 61 additions & 0 deletions tests/sft/peft_trainer_test.py
@@ -34,6 +34,7 @@
from tunix.sft import hooks
from tunix.sft import peft_trainer
from tunix.sft import profiler
from tunix.sft import utils
from tunix.tests import test_common as tc
from tunix.utils import compat

@@ -634,7 +635,67 @@ def _post_process_eval_step(self, aux):
self.assertEqual(train_invoke, {'foo': 2, 'bar': 4})
self.assertEqual(eval_invoke, {'foo': 1, 'bar': 16})

def test_loss_output_format(self):
def custom_loss_fn(
model: nnx.Module,
input_tokens: jax.Array,
input_mask: jax.Array,
positions: jax.Array,
attention_mask: jax.Array,
images: jax.Array | None = None,
) -> utils.LossOutput:
del model, input_tokens, input_mask, positions, attention_mask, images
return utils.LossOutput(
primary_loss=utils.WeightedMetric(
jnp.array(2.0, dtype=jnp.float32),
jnp.array(2.0, dtype=jnp.float32),
),
aux_metrics={
'foo': utils.WeightedMetric(
jnp.array(10.0, dtype=jnp.float32),
jnp.array(5.0, dtype=jnp.float32),
),
'bar': utils.WeightedMetric(
jnp.array(6.0, dtype=jnp.float32),
jnp.array(2.0, dtype=jnp.float32),
),
},
)

train_invoke = {'foo': 0.0, 'bar': 0.0}
eval_invoke = {'foo': 0.0, 'bar': 0.0}

class CustomTrainer(peft_trainer.PeftTrainer):

def _post_process_train_step(self, aux):
train_invoke['foo'] += aux['foo']
train_invoke['bar'] += aux['bar']

def _post_process_eval_step(self, aux):
eval_invoke['foo'] += aux['foo']
eval_invoke['bar'] += aux['bar']

config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))

trainer = CustomTrainer(model, optax.sgd(1e-3), config)
trainer = trainer.with_gen_model_input_fn(
dummy_gen_model_input_fn
).with_loss_fn(
custom_loss_fn
) # Note: has_aux=False is the default, but LossOutput carries aux metrics natively

trainer.train(self.train_ds, self.eval_ds)
# The dataset provides 2 training steps.
# foo = 10.0 / 5.0 = 2.0 per step.
# bar = 6.0 / 2.0 = 3.0 per step.
self.assertEqual(train_invoke, {'foo': 4.0, 'bar': 6.0})

# Since eval_ds is length 2, it evaluates at step 2.
self.assertEqual(eval_invoke, {'foo': 8.0, 'bar': 12.0})

def test_injected_params(self):

config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))

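A note on the expected numbers in test_loss_output_format: each WeightedMetric is presumably reduced before the hooks run, so per train step foo = 10.0 / 5.0 = 2.0 and bar = 6.0 / 2.0 = 3.0; two train steps accumulate to {foo: 4.0, bar: 6.0}, and the eval hook accumulates to {foo: 8.0, bar: 12.0} across the evaluation passes. A hypothetical sketch of that reduction (the helper name _reduce_aux is invented for illustration, not taken from this PR):

def _reduce_aux(aux_metrics):
  # Collapse each WeightedMetric to a scalar before invoking
  # _post_process_train_step / _post_process_eval_step; plain arrays
  # pass through unchanged.
  return {
      name: m.compute() if hasattr(m, 'compute') else m
      for name, m in aux_metrics.items()
  }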
65 changes: 34 additions & 31 deletions tunix/rl/agentic/agentic_grpo_learner.py
@@ -49,9 +49,9 @@
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.environments import task_environment
from tunix.rl.ppo import ppo_helpers
from tunix.sft import utils as sft_utils
from tunix.utils import trajectory_logger


TrainingInputT = agentic_rl_learner.TrainingInputT
RewardFn = agentic_rl_learner.RewardFn
MetricFn = agentic_rl_learner.MetricFn
@@ -74,8 +74,8 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
num_iterations: Number of GRPO iterations per batch (μ in the paper).
beta: KL penalty coefficient.
kl_loss_mode: Method for computing the KL loss.
force_compute_kl: Whether to force compute KL divergence for logging
even when it would normally be skipped (e.g., when beta is 0.0).
force_compute_kl: Whether to force compute KL divergence for logging even
when it would normally be skipped (e.g., when beta is 0.0).
epsilon: PPO-style clipping epsilon.
epsilon_high: PPO-style clipping epsilon upper bound.
loss_algo: "grpo" or "gspo-token".
@@ -251,8 +251,7 @@ def __init__(
})
self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([
lambda: "kl"
if self.algo_config.force_compute_kl
or self.algo_config.beta != 0.0
if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0
else None,
])

@@ -594,9 +593,7 @@ def grpo_loss_fn(
else epsilon
)
epsilon_c = (
algo_config.epsilon_c
if hasattr(algo_config, "epsilon_c")
else 3.0
algo_config.epsilon_c if hasattr(algo_config, "epsilon_c") else 3.0
)
loss_aggregation_mode = algo_config.loss_agg_mode

@@ -633,7 +630,8 @@ def grpo_loss_fn(

seq_importance_ratio = per_token_logps - old_per_token_logps
# Record KL divergence before clipping.
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)
unreduced_ppo_kl = jnp.sum(-seq_importance_ratio * completion_mask)
token_denom = completion_mask.sum()

seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)

@@ -661,35 +659,38 @@

per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)

clipped_fraction = ppo_helpers.masked_mean(
jnp.greater(pg_loss_2, pg_loss_1), completion_mask
)
unreduced_clip_frac = jnp.sum(
    jnp.greater(pg_loss_2, pg_loss_1) * completion_mask
)

# dual-clip ppo loss
pg_loss_3 = -epsilon_c * adv

# pg_clipfrac_lower measures how often dual-clip ppo kicks in.
# It kicks in when the standard clipped loss is larger than pg_loss_3
# for instances with negative advantages.
unreduced_pg_clipfrac_lower = (
per_token_pg_clipfrac_lower = (
(per_token_loss > pg_loss_3) & (adv < 0.0)
).astype(jnp.float32)
pg_clipfrac_lower = common.aggregate_loss(
unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
unreduced_pg_clipfrac_lower = common.aggregate_loss(
per_token_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
)

pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)
loss = common.aggregate_loss(
weighted_loss = common.aggregate_loss(
per_token_loss, completion_mask, loss_aggregation_mode
)

aux = {
"kl": 0.0,
"kl_loss": 0.0,
"pg_loss": loss,
"pg_clipfrac": clipped_fraction,
"ppo_kl": ppo_kl,
"pg_clipfrac_lower": pg_clipfrac_lower,
"kl": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
"kl_loss": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
"pg_loss": weighted_loss,
"pg_clipfrac": sft_utils.WeightedMetric(
unreduced_clip_frac, token_denom, min_denom=1.0
),
"ppo_kl": sft_utils.WeightedMetric(
unreduced_ppo_kl, token_denom, min_denom=1.0
),
"pg_clipfrac_lower": unreduced_pg_clipfrac_lower,
}
# We do not always compute KL divergence (e.g. when beta is 0.0), unless
# force_compute_kl is True.
@@ -699,25 +700,27 @@ def grpo_loss_fn(
train_example.ref_per_token_logps,
algo_config.kl_loss_mode,
)
# Log mean KL.
aux["kl"] = jnp.astype(
(kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
jnp.float32,
)
kl_loss = common.aggregate_loss(
kl, completion_mask, loss_aggregation_mode
unreduced_kl = jnp.astype(jnp.sum(kl * completion_mask), jnp.float32)
aux["kl"] = sft_utils.WeightedMetric(
unreduced_kl, token_denom, min_denom=1.0
)
kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode)
aux["kl_loss"] = kl_loss
if beta is not None and beta != 0.0:
loss = loss + beta * kl_loss
weighted_loss = sft_utils.WeightedMetric(
weighted_loss.unreduced_sum + beta * kl_loss.unreduced_sum,
weighted_loss.denominator,
eps=weighted_loss.eps,
min_denom=weighted_loss.min_denom,
)

token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
entropy_loss = common.aggregate_loss(
token_entropy, completion_mask, loss_aggregation_mode
)
aux["entropy"] = entropy_loss

return loss, aux
return sft_utils.LossOutput(primary_loss=weighted_loss, aux_metrics=aux)


@function_registry.register_advantage_estimator("agentic_grpo")
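Why grpo_loss_fn now returns weighted sums plus denominators (token_denom) instead of pre-reduced means: when per-step means over different token counts are averaged again downstream (across microbatches, accumulation steps, or logging windows), the mean of means over-weights short sequences. A toy illustration with invented numbers, not taken from the PR:

import jax.numpy as jnp

# Step A: 2 completion tokens, mean loss 0.5 -> unreduced sum 1.0, denom 2.
# Step B: 8 completion tokens, mean loss 1.0 -> unreduced sum 8.0, denom 8.
sums = jnp.array([1.0, 8.0])
denoms = jnp.array([2.0, 8.0])

mean_of_means = jnp.mean(sums / denoms)  # 0.75 -- over-weights step A
weighted_mean = sums.sum() / denoms.sum()  # 0.90 -- the true token-level mean

Carrying (sum, denominator) pairs in WeightedMetric defers the division until compute(), so aggregation across steps stays exact.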