Agentic GRPO: TIS correction, eval dedup, flash-attn segment_ids #1523
colincai-mc wants to merge 1 commit into
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
    segment_pos: jaxtyping.Array,
    cache: LayerCache | None,
    attn_mask: jaxtyping.Array | None,
    segment_ids: jaxtyping.Array | None = None,
This is used for sequence packing, right? Right now we haven't enabled that yet, since the seq packing support is still WIP.
Yeah, here it's not for packing; it's needed for per-batch padding masks. The splash attention kernel only accepts a static causal mask, so the per-batch dynamic pad positions can only flow in through SegmentIds. Without this, left-padded prompts leak attention from real queries onto pad keys, producing garbage softmax outputs that shift the logits. We use segment_id=0 for pad and segment_id=1 for real tokens to avoid this.
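For concreteness, a minimal sketch of deriving such segment ids from the token ids of a left-padded batch (the helper name is hypothetical; in this PR the mask is built in `rl/common.process_ids`):

```python
import jax.numpy as jnp

# Hypothetical helper, not the PR's exact code: pad positions get
# segment_id=0 and real tokens get segment_id=1, so the splash-attention
# kernel treats pads as a separate segment and real queries never attend
# to pad keys.
def make_segment_ids(token_ids: jnp.ndarray, pad_id: int) -> jnp.ndarray:
  return (token_ids != pad_id).astype(jnp.int32)  # [batch, seq_len]

tokens = jnp.array([[0, 0, 11, 12, 13],   # two left-padded rows, pad_id=0
                    [0, 21, 22, 23, 24]])
print(make_segment_ids(tokens, pad_id=0))
# [[0 0 1 1 1]
#  [0 1 1 1 1]]
```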
    # vLLM is configured with `include_stop_str_in_output=True`, so
    # assistant_tokens from vllm DO include the trailing eos (`<|im_end|>`
    # for Qwen3). But the chat template separates messages with `\n` after
Does "chat template" refer to the parser here? Shall we consider removing the `\n` there instead of manually fixing it here?
    # the next train_step's `mul` op crashed with memory_space mismatch.
    # For the sampler-trainer diff metric (single-iteration GRPO), the
    # "anchor" state equals the current actor state anyway, so we just
    # compute directly from the actor model.
When there are multiple model updates (e.g. mini_batch < global_batch), wouldn't we want to use the actor model weights from the very beginning of the global_step?
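For reference, a hedged sketch of the anchor idea discussed here (loop and function names are hypothetical, not the PR's code): snapshot the actor weights once at the start of the global step and compute `old_per_token_logps` from that snapshot for every mini-batch update within the step.

```python
# Hypothetical outer loop; in JAX, params are immutable pytrees, so holding
# the start-of-step reference is enough to freeze the "anchor" snapshot.
def run_global_step(actor_params, mini_batches, compute_logps, apply_update):
  anchor_params = actor_params  # the weights the sampler actually used
  for batch in mini_batches:
    # old logps always reference the anchor, not the drifting actor
    old_logps = compute_logps(anchor_params, batch)
    actor_params = apply_update(actor_params, batch, old_logps)
  return actor_params
```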
Adds correctness fixes and observability for multi-turn agentic GRPO
runs, plus flash-attention support for prompts with left-padding.
Key changes:
- PeftTrainer: skip the ``max_steps`` early-exit when the trainer is
managed externally so the agentic outer loop is not interrupted
mid-step.
- agentic_rl_learner: dedupe eval at step boundaries (without this,
``train_steps % eval_every_n_steps == 0`` fires once per micro-step
when the optimizer is wrapped in MultiSteps, replaying the held-out
rollout ``grad_accum_steps`` times for the same step). Adds a
per-global-step diagnostic log line for run health visibility.
- agentic GRPO learner: optional truncated importance-sampling (TIS)
correction (``sampler_is='token'``). When enabled the loss uses the
trainer's start-of-step recomputed logp as ``old_per_token_logps``
and multiplies each per-token pg-loss term by a detached weight
``min(exp(trainer_logp - sampler_logp), threshold)``. Dampens
positions where the trainer's recomputed probability disagrees with
the rollout sampler, reducing variance from sampler-trainer drift
in multi-turn rollouts. Logs sampler-trainer logp_diff,
prob_diff, and Pearson correlation every step.
- TrainExample: ``sampler_is_weights`` field threaded through the
GRPO loss before aggregation.
- rl/common.process_ids: build the attention/RoPE-position mask from
a non-pad mask rather than the caller-supplied
``completion_mask`` (which in multi-turn agentic learners is the
assistant-vs-env loss mask). Emits a 1-D per-position non-pad mask
as ``segment_ids`` so attention kernels that can't consume the
2-D attn mask directly still skip pad positions.
- rl_cluster: thread the anchor (start-of-step) actor weights into
the per-token logp recompute path so old_per_token_logps reference
the same policy the sampler used, even with
``mini_batch_size < full_batch_size`` or ``num_iterations > 1``.
- algo_core: per-token diagnostics for the GRPO/GSPO loss
(is_ratio mean/min/max, log_ratio absolute mean, pg_loss
unclipped/clipped means, advantage abs mean/min/max/nonzero_frac,
sampler-IS weight mean/min).
- Qwen3 model: thread ``segment_ids`` through Attention / DecoderLayer
/ Qwen3 so the pallas splash-attention kernel receives per-position
segment ids and can suppress cross-segment attention. Without
this, left-padded prompts contaminate real-token attention output.
- vllm_sampler: keep the trailing stop token in returned token_ids
and enable ``include_stop_str_in_output=True`` so multi-turn
consumers can reconstruct the exact sequence the model was
sampled on. Previously the stripped stop token caused
trainer-side concatenation to miss the per-turn ``<|im_end|>`` and
produce 30+ nat sampler-trainer logp diffs at turn boundaries.
- generate.sampler: treat ``top_k=0`` (or None) as "no top-k
filter" rather than failing the fast-path check.
- trajectory_collect_engine: append tokens, masks, and logprobs to
the conversation buffer in lockstep, and inject a single ``\n``
bridge between assistant and env tokens to match the
apply_chat_template rendering. Previously a step with env_tokens
but no logprobs offset the logprobs array against
conversation_tokens for every subsequent step.
- chat_template_parser: ``message_separator`` field + ``\n`` for
  QwenChatTemplateParser so incremental per-message rendering
  concatenates to match ``apply_chat_template`` output exactly (see the
  sketch after this list).
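A minimal sketch of the invariant that last bullet enforces (the parser method names and the exact Qwen template string are assumptions, not the PR's code):

```python
# Hypothetical parser: rendering messages one at a time and joining with
# message_separator must reproduce apply_chat_template output exactly,
# otherwise trainer-side token concatenation drifts from what the model
# actually sampled.
class QwenLikeParser:
  message_separator = "\n"  # the field this PR adds for QwenChatTemplateParser

  def render_message(self, msg: dict) -> str:
    # Assumed Qwen-style rendering of a single message.
    return f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>"

  def render_incremental(self, messages: list[dict]) -> str:
    return self.message_separator.join(
        self.render_message(m) for m in messages)

# Invariant checked by the fix (apply_chat_template is the reference path):
#   parser.render_incremental(msgs) == apply_chat_template(msgs)
```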
Bug fixes

`tunix/rl/agentic/agentic_rl_learner.py` — eval deduplication

The eval condition `train_steps % eval_every_n_steps == 0` evaluates to `True` for every micro-iteration within a `grad_accum_steps` window at a step boundary, causing the eval dataset to be replayed that many times per boundary. A `_last_eval_train_step` guard skips duplicate evals.
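A hedged sketch of the guard (the attribute name comes from the description above; the surrounding structure is hypothetical):

```python
# With a MultiSteps-wrapped optimizer, train_steps repeats the same value
# for grad_accum_steps micro-steps, so the modulo condition alone fires
# once per micro-step at a step boundary.
class _EvalDedup:
  def __init__(self):
    self._last_eval_train_step = -1

  def should_eval(self, train_steps: int, eval_every_n_steps: int) -> bool:
    if train_steps % eval_every_n_steps != 0:
      return False
    if train_steps == self._last_eval_train_step:
      return False  # this step boundary was already evaluated; skip replay
    self._last_eval_train_step = train_steps
    return True
```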
Features

Token-level truncated importance-sampling (TIS) correction (`sampler_is='token'`)

In multi-turn agentic rollouts the rollout engine and trainer recompute the same sequence with slightly different numerical paths (temperature scaling, attention kernel, padding). The resulting log-probability drift accumulates across turns and acts as unintended noise on the importance ratio. With `sampler_is='token'`, the trainer applies a per-token weight `min(exp(trainer_logp − rollout_logp), threshold)` before aggregating the policy-gradient loss. This keeps the effective IS ratio grounded at the trainer's start-of-step distribution rather than the rollout engine's.
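A minimal sketch of that weight (the function name and threshold default are hypothetical; the PR computes this in `agentic_grpo_learner`):

```python
import jax
import jax.numpy as jnp

def tis_weights(trainer_logps, rollout_logps, threshold=2.0):
  # Per-token importance ratio between the trainer's start-of-step
  # recompute and the rollout sampler's logprobs.
  ratio = jnp.exp(trainer_logps - rollout_logps)
  ratio = jnp.minimum(ratio, threshold)  # truncate heavy tails
  return jax.lax.stop_gradient(ratio)    # detached: weights the loss only

# Applied multiplicatively to each per-token pg-loss term before
# masking and aggregation:
#   per_token_pg_loss = tis_weights(...) * per_token_pg_loss
```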
`sampler_is_weights` field in `TrainExample` + `aggregate_loss` application

The per-token weights computed in `agentic_grpo_learner` are attached to `TrainExample.sampler_is_weights` and applied in `grpo_loss_fn` before loss aggregation, so they affect the gradient through loss magnitude without introducing a stop-gradient bias on the ratio.
Sampler-trainer `prob_diff` and `pearson` metrics

Logged every training step. `prob_diff` is the mean absolute probability difference `|softmax(rollout_logit) − softmax(trainer_logit)|` over completion tokens; `pearson` is the per-batch correlation between rollout and trainer log-probabilities. These serve as numerical-alignment health checks for the two compute paths.
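A sketch of one plausible formulation of both metrics from per-token log-probs (a simplification: the PR's `prob_diff` is described over softmaxed logits, while this compares chosen-token probabilities only; `mask` selects completion tokens):

```python
import jax.numpy as jnp

def alignment_metrics(rollout_logps, trainer_logps, mask):
  n = jnp.maximum(mask.sum(), 1)
  # Mean absolute chosen-token probability difference over valid tokens.
  prob_diff = (jnp.abs(jnp.exp(rollout_logps) - jnp.exp(trainer_logps))
               * mask).sum() / n
  # Masked Pearson correlation between the two log-prob series.
  x = jnp.where(mask, rollout_logps, 0.0)
  y = jnp.where(mask, trainer_logps, 0.0)
  mx, my = x.sum() / n, y.sum() / n
  cov = ((x - mx) * (y - my) * mask).sum() / n
  sx = jnp.sqrt(((x - mx) ** 2 * mask).sum() / n)
  sy = jnp.sqrt(((y - my) ** 2 * mask).sum() / n)
  return prob_diff, cov / jnp.maximum(sx * sy, 1e-8)
```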
Qwen3 model: `segment_ids` for flash attention (splash kernel)

The splash attention kernel requires per-position segment IDs to correctly compute the causal mask for left-padded sequences. Without this, padding tokens on the left contaminate real-token attention outputs during trainer recompute. This PR plumbs `segment_ids` derived from the non-pad mask through the Qwen3 forward pass.

`vllm_sampler` EOS token fix

Prevents an off-by-one in the completion mask when the sampler returns the EOS token at the boundary.
Trajectory collection: sampler logprobs in output

The `trajectory_collect_engine` now includes the rollout engine's per-token logprobs in the collected trajectory, so the GRPO learner can apply TIS without a second forward pass through the rollout engine.
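A hedged sketch of the lockstep invariant this relies on (the buffer layout is hypothetical): every appended segment updates tokens, loss mask, and logprobs together, padding logprobs for env tokens so the arrays can never drift out of alignment.

```python
def append_turn(buf, assistant_toks, assistant_logps, env_toks, bridge_toks):
  # Assistant segment: trainable, carries the sampler's logprobs.
  buf["tokens"] += assistant_toks
  buf["loss_mask"] += [1] * len(assistant_toks)
  buf["logprobs"] += assistant_logps
  # "\n" bridge + env segment: not trained; 0.0 placeholders keep
  # len(buf["logprobs"]) == len(buf["tokens"]) at every step.
  for toks in (bridge_toks, env_toks):
    buf["tokens"] += toks
    buf["loss_mask"] += [0] * len(toks)
    buf["logprobs"] += [0.0] * len(toks)
  return buf
```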