@@ -49,9 +49,9 @@
 from tunix.rl.agentic.environments import base_environment
 from tunix.rl.agentic.environments import task_environment
 from tunix.rl.ppo import ppo_helpers
+from tunix.sft import utils as sft_utils
 from tunix.utils import trajectory_logger
 
-
 TrainingInputT = agentic_rl_learner.TrainingInputT
 RewardFn = agentic_rl_learner.RewardFn
 MetricFn = agentic_rl_learner.MetricFn
@@ -74,8 +74,8 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
     num_iterations: Number of GRPO iterations per batch (μ in the paper).
     beta: KL penalty coefficient.
     kl_loss_mode: Method for computing the KL loss.
-    force_compute_kl: Whether to force compute KL divergence for logging
-      even when it would normally be skipped (e.g., when beta is 0.0).
+    force_compute_kl: Whether to force compute KL divergence for logging even
+      when it would normally be skipped (e.g., when beta is 0.0).
     epsilon: PPO-style clipping epsilon.
     epsilon_high: PPO-style clipping epsilon upper bound.
     loss_algo: "grpo" or "gspo-token".
@@ -251,8 +251,7 @@ def __init__(
     })
     self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([
         lambda: "kl"
-        if self.algo_config.force_compute_kl
-        or self.algo_config.beta != 0.0
+        if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0
         else None,
     ])
 
@@ -594,9 +593,7 @@ def grpo_loss_fn(
       else epsilon
   )
   epsilon_c = (
-      algo_config.epsilon_c
-      if hasattr(algo_config, "epsilon_c")
-      else 3.0
+      algo_config.epsilon_c if hasattr(algo_config, "epsilon_c") else 3.0
   )
   loss_aggregation_mode = algo_config.loss_agg_mode
 
@@ -633,7 +630,8 @@ def grpo_loss_fn(
 
   seq_importance_ratio = per_token_logps - old_per_token_logps
   # Record KL divergence before clipping.
-  ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)
+  unreduced_ppo_kl = jnp.sum(-seq_importance_ratio * completion_mask)
+  token_denom = completion_mask.sum()
 
   seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)
@@ -661,35 +659,38 @@ def grpo_loss_fn(
 
   per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)
 
-  clipped_fraction = ppo_helpers.masked_mean(
-      jnp.greater(pg_loss_2, pg_loss_1), completion_mask
-  )
+  unreduced_clip_frac = jnp.sum(jnp.greater(pg_loss_2, pg_loss_1) * completion_mask)
 
   # dual-clip ppo loss
   pg_loss_3 = -epsilon_c * adv
 
   # pg_clipfrac_lower measures how often dual-clip ppo kicks in.
   # It kicks in when the standard clipped loss is larger than pg_loss_3
   # for instances with negative advantages.
-  unreduced_pg_clipfrac_lower = (
+  per_token_pg_clipfrac_lower = (
       (per_token_loss > pg_loss_3) & (adv < 0.0)
   ).astype(jnp.float32)
-  pg_clipfrac_lower = common.aggregate_loss(
-      unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
+  unreduced_pg_clipfrac_lower = common.aggregate_loss(
+      per_token_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
   )
 
   pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
   per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)
-  loss = common.aggregate_loss(
+  weighted_loss = common.aggregate_loss(
       per_token_loss, completion_mask, loss_aggregation_mode
   )
+
   aux = {
-      "kl": 0.0,
-      "kl_loss": 0.0,
-      "pg_loss": loss,
-      "pg_clipfrac": clipped_fraction,
-      "ppo_kl": ppo_kl,
-      "pg_clipfrac_lower": pg_clipfrac_lower,
+      "kl": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
+      "kl_loss": sft_utils.WeightedMetric(jnp.array(0.0), jnp.array(1.0)),
+      "pg_loss": weighted_loss,
+      "pg_clipfrac": sft_utils.WeightedMetric(
+          unreduced_clip_frac, token_denom, min_denom=1.0
+      ),
+      "ppo_kl": sft_utils.WeightedMetric(
+          unreduced_ppo_kl, token_denom, min_denom=1.0
+      ),
+      "pg_clipfrac_lower": unreduced_pg_clipfrac_lower,
   }
   # We do not always compute KL divergence (e.g. when beta is 0.0 unless
   # force_compute_kl is True).
@@ -699,25 +700,27 @@ def grpo_loss_fn(
         train_example.ref_per_token_logps,
         algo_config.kl_loss_mode,
     )
-    # Log mean KL.
-    aux["kl"] = jnp.astype(
-        (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
-        jnp.float32,
-    )
-    kl_loss = common.aggregate_loss(
-        kl, completion_mask, loss_aggregation_mode
+    unreduced_kl = jnp.astype(jnp.sum(kl * completion_mask), jnp.float32)
+    aux["kl"] = sft_utils.WeightedMetric(
+        unreduced_kl, token_denom, min_denom=1.0
     )
+    kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode)
     aux["kl_loss"] = kl_loss
     if beta is not None and beta != 0.0:
-      loss = loss + beta * kl_loss
+      weighted_loss = sft_utils.WeightedMetric(
+          weighted_loss.unreduced_sum + beta * kl_loss.unreduced_sum,
+          weighted_loss.denominator,
+          eps=weighted_loss.eps,
+          min_denom=weighted_loss.min_denom,
+      )
 
   token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
   entropy_loss = common.aggregate_loss(
       token_entropy, completion_mask, loss_aggregation_mode
   )
   aux["entropy"] = entropy_loss
 
-  return loss, aux
+  return sft_utils.LossOutput(primary_loss=weighted_loss, aux_metrics=aux)
 
 
 @function_registry.register_advantage_estimator("agentic_grpo")
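
Note on the metric change: throughout this diff, scalars that were previously reduced inside the loss function with ppo_helpers.masked_mean (ppo_kl, pg_clipfrac, kl) are instead carried out of grpo_loss_fn as sft_utils.WeightedMetric pairs of a masked sum and a token-count denominator (token_denom = completion_mask.sum()). The sketch below uses a simplified stand-in for WeightedMetric, with fields assumed from the usage above rather than from tunix's actual implementation, to show why: averaging pre-reduced per-microbatch means weights each microbatch equally regardless of how many valid tokens it contains, while accumulating (sum, denominator) pairs and dividing once at the end weights every token equally.

import jax.numpy as jnp
from typing import NamedTuple


class WeightedMetric(NamedTuple):
  # Simplified stand-in, assumed from usage in the diff; not tunix's class.
  unreduced_sum: jnp.ndarray  # masked sum, e.g. jnp.sum(x * mask)
  denominator: jnp.ndarray    # valid-token count, e.g. mask.sum()
  min_denom: float = 1.0      # guards against division by zero

  def reduce(self) -> jnp.ndarray:
    return self.unreduced_sum / jnp.maximum(self.denominator, self.min_denom)


# Two microbatches with different numbers of valid completion tokens.
kl_a, mask_a = jnp.array([0.5, 0.5, 0.0]), jnp.array([1.0, 1.0, 0.0])  # 2 tokens
kl_b, mask_b = jnp.array([0.1, 0.0, 0.0]), jnp.array([1.0, 0.0, 0.0])  # 1 token

# Averaging pre-reduced means weights microbatches equally:
# (0.5 + 0.1) / 2 = 0.30.
mean_of_means = (
    (kl_a * mask_a).sum() / mask_a.sum() + (kl_b * mask_b).sum() / mask_b.sum()
) / 2

# Accumulating (sum, denom) pairs weights tokens equally: 1.1 / 3 ≈ 0.3667.
m_a = WeightedMetric(jnp.sum(kl_a * mask_a), mask_a.sum())
m_b = WeightedMetric(jnp.sum(kl_b * mask_b), mask_b.sum())
token_mean = WeightedMetric(
    m_a.unreduced_sum + m_b.unreduced_sum, m_a.denominator + m_b.denominator
).reduce()

The same reasoning explains the beta * kl_loss change above: the KL penalty is folded into the loss numerator (unreduced_sum) while the token denominator is left untouched, so the combined loss still reduces to a correctly weighted per-token mean.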