
Commit 4870afe

AlienKevin, BabyChouSr, and ahmeda14960 authored
RL Loss Improvements (#2327)
Changes to RL loss, zero-variance prompt filtering, length penalty, and curriculum.

Co-authored-by: Christopher Chou <49086305+BabyChouSr@users.noreply.github.com>
Co-authored-by: Ahmed Ahmed <ahmedah@stanford.edu>
1 parent de3704f commit 4870afe
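
The commit message mentions zero-variance prompt filtering, which does not appear in the two diffs below. As a rough sketch of the idea (the helper name and shapes here are assumptions, not this repo's API): prompts whose sampled rollouts all receive the same reward have zero advantage under group-normalized RL objectives, so they contribute no gradient and can be dropped before the loss is computed.

import jax.numpy as jnp

def filter_zero_variance_prompts(rewards, eps=1e-6):
    """Mask out prompts whose rollouts all received the same reward.

    rewards: [num_prompts, rollouts_per_prompt] scalar reward per rollout.
    Returns a boolean mask, True where the prompt still carries signal.
    """
    return jnp.std(rewards, axis=1) > eps

# Example: the first prompt's rollouts all scored 1.0, so it is filtered out.
rewards = jnp.array([[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]])
keep = filter_zero_variance_prompts(rewards)  # [False, True]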

2 files changed: 38 additions & 7 deletions


lib/marin/src/marin/rl/curriculum.py

Lines changed: 10 additions & 0 deletions
@@ -85,6 +85,16 @@ class SamplingParams:
     max_output_tokens: int = 512
     stop_tokens: list[int] | None = None
 
+    def __post_init__(self):
+        if self.temperature < 1e-4:
+            logger.warning(
+                "SamplingParams.temperature is very low (%f). Greedy decoding is generally "
+                "not useful for RL training as it limits exploration.",
+                self.temperature,
+            )
+        if self.top_k == 1:
+            logger.warning("SamplingParams.top_k is 1. Greedy decoding is generally not useful for RL training.")
+
 
 @dataclass
 class LessonConfig:
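
For illustration, a minimal usage sketch of the new validation. The temperature and top_k fields exist on SamplingParams per the __post_init__ body above, but the assumption that they are keyword arguments with defaults (and the import path) is hypothetical:

from marin.rl.curriculum import SamplingParams

SamplingParams(temperature=0.0)  # logs: temperature is very low, limits exploration
SamplingParams(top_k=1)          # logs: top_k of 1 is greedy decoding
SamplingParams(temperature=0.7)  # no warning; sampling stays stochastic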

lib/marin/src/marin/rl/rl_losses.py

Lines changed: 28 additions & 7 deletions
@@ -207,18 +207,13 @@ def compute_ppo_loss_objective(
     loss_objective = jnp.minimum(non_clipped_objective, clipped_objective)
     if trainer_inference_importance_sampling_ratio is not None:
         loss_objective = trainer_inference_importance_sampling_ratio * loss_objective
-    # Mean over response tokens per batch
-    # loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
 
     if response_truncated_array is not None:
         batch_size, _ = loss_objective.shape
         loss_objective = loss_objective * (1 - response_truncated_array.reshape(batch_size, 1))
 
-    # Dr GRPO loss, token-level loss
-    # loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)
-
-    # more like DAPO loss
-    loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
+    # Default to DAPO loss (matches original active behavior)
+    loss = compute_dapo_loss(loss_objective, loss_masks)
 
     per_batch_loss = jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1)
     metadata = {
@@ -228,6 +223,32 @@ def compute_ppo_loss_objective(
     return loss, metadata
 
 
+def compute_ppo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+) -> jax.Array:
+    """Compute PPO loss (per-example normalization)."""
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
+
+
+def compute_dapo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+) -> jax.Array:
+    """Compute DAPO-like loss (per-example normalization)."""
+    # Use per-example normalization (averaging the per-example means)
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
+
+
+def compute_grpo_loss(
+    loss_objective: jax.Array,
+    loss_masks: jax.Array,
+    max_output_tokens: int,
+) -> jax.Array:
+    """Compute GRPO loss (token-level loss)."""
+    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)
+
+
 def importance_sampling_ratio(
     current_logprobs: jax.Array,
     policy_logprobs_array: jax.Array,
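
The three new helpers differ only in the denominator: compute_ppo_loss and compute_dapo_loss both average each response's masked objective over that response's own token count, while compute_grpo_loss divides by the fixed max_output_tokens budget, which down-weights short responses. A toy comparison on a two-example batch (illustration only, not part of the commit):

import jax.numpy as jnp

# Two responses: one with 2 valid tokens, one with 1 (max_output_tokens = 3).
loss_objective = jnp.array([[0.2, 0.4, 0.0], [0.1, 0.0, 0.0]])
loss_masks = jnp.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

per_example_sum = jnp.sum(loss_objective * loss_masks, axis=1)  # [0.6, 0.1]
token_counts = jnp.sum(loss_masks, axis=1)                      # [2.0, 1.0]

ppo_or_dapo = -1 * jnp.mean(per_example_sum / token_counts)     # -(0.3 + 0.1) / 2 = -0.2
grpo = -1 * jnp.mean(per_example_sum / 3)                       # -(0.2 + 0.0333) / 2 ≈ -0.117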
