@@ -207,18 +207,13 @@ def compute_ppo_loss_objective(
     loss_objective = jnp.minimum(non_clipped_objective, clipped_objective)
     if trainer_inference_importance_sampling_ratio is not None:
         loss_objective = trainer_inference_importance_sampling_ratio * loss_objective
-    # Mean over response tokens per batch
-    # loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
 
     if response_truncated_array is not None:
         batch_size, _ = loss_objective.shape
         loss_objective = loss_objective * (1 - response_truncated_array.reshape(batch_size, 1))
 
-    # Dr GRPO loss, token-level loss
-    # loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)
-
-    # more like DAPO loss
-    loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
+    # Default to DAPO loss (matches original active behavior)
+    loss = compute_dapo_loss(loss_objective, loss_masks)
 
     per_batch_loss = jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1)
     metadata = {
@@ -228,6 +223,32 @@ def compute_ppo_loss_objective(
     return loss, metadata
 
 
+def compute_ppo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+) -> jax.Array:
+    """Compute PPO loss (per-example normalization)."""
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
+
+
+def compute_dapo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+) -> jax.Array:
+    """Compute DAPO-like loss (normalize by the global token count)."""
+    # Global normalization over all unmasked tokens matches the original active behavior.
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
+
+
+def compute_grpo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+    max_output_tokens: int,
+) -> jax.Array:
+    """Compute Dr. GRPO loss (token-level, constant normalizer)."""
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)
+
+
 def importance_sampling_ratio(
     current_logprobs: jax.Array,
     policy_logprobs_array: jax.Array,
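
A minimal sketch (not part of this diff) of how the three normalizations diverge on a toy batch; the array values are made up for illustration:

import jax.numpy as jnp

# Toy batch: 2 responses padded to 4 tokens; the second response is shorter.
loss_objective = jnp.array([[0.2, 0.4, 0.1, 0.3],
                            [0.5, 0.1, 0.0, 0.0]])
loss_masks = jnp.array([[1.0, 1.0, 1.0, 1.0],
                        [1.0, 1.0, 0.0, 0.0]])
max_output_tokens = 4

# PPO: each example is normalized by its own token count, then averaged.
ppo = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))

# DAPO-like: every example is normalized by the batch-wide token count.
dapo = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))

# Dr. GRPO: constant normalizer, independent of actual response lengths.
grpo = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)

print(ppo, dapo, grpo)  # -0.275, -0.1333..., -0.2: same objective, three scales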