
Agentic GRPO: TIS correction, eval dedup, flash-attn segment_ids#1523

Open
colincai-mc wants to merge 1 commit into google:main from colincai-mc:upstream-pr

Conversation

@colincai-mc

@colincai-mc colincai-mc commented May 15, 2026

Bug fixes

tunix/rl/agentic/agentic_rl_learner.py — eval deduplication
The eval condition train_steps % eval_every_n_steps == 0 evaluates to True for every micro-iteration within a grad_accum_steps window at a step boundary, causing the eval dataset to be replayed that many times per boundary. A _last_eval_train_step guard skips duplicate evals.
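The guard described above can be sketched as follows; the class and attribute names here are illustrative, not the actual code in agentic_rl_learner.py:

```python
# Hypothetical sketch of the eval dedup guard. With grad accumulation,
# should_eval() is called once per micro-iteration while train_steps
# stays constant within the window; the guard fires at most once.
class EvalDeduper:
    def __init__(self, eval_every_n_steps: int):
        self.eval_every_n_steps = eval_every_n_steps
        self._last_eval_train_step = -1

    def should_eval(self, train_steps: int) -> bool:
        if train_steps % self.eval_every_n_steps != 0:
            return False
        if train_steps == self._last_eval_train_step:
            return False  # already evaluated at this step boundary
        self._last_eval_train_step = train_steps
        return True
```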

Features

Token-level truncated importance-sampling (TIS) correction (sampler_is='token')
In multi-turn agentic rollouts the rollout engine and trainer recompute the same sequence with slightly different numerical paths (temperature scaling, attention kernel, padding). The resulting log-probability drift accumulates across turns and acts as unintended noise on the importance ratio. With sampler_is='token', the trainer applies a per-token weight min(exp(trainer_logp − rollout_logp), threshold) before aggregating the policy gradient loss. This keeps the effective IS ratio grounded at the trainer's start-of-step distribution rather than the rollout engine's.
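A minimal sketch of the per-token weight described above; the function name and threshold default are assumptions for illustration, not the tunix API:

```python
import numpy as np

# Truncated importance-sampling weight per token:
#   min(exp(trainer_logp - rollout_logp), threshold)
# Ratios near 1 pass through; tokens where the trainer's recomputed
# probability exceeds the rollout engine's are clipped at the threshold.
def tis_token_weights(trainer_logp, rollout_logp, threshold=2.0):
    ratio = np.exp(np.asarray(trainer_logp) - np.asarray(rollout_logp))
    return np.minimum(ratio, threshold)
```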

sampler_is_weights field in TrainExample + aggregate_loss application
The per-token weights computed in agentic_grpo_learner are attached to TrainExample.sampler_is_weights and applied in grpo_loss_fn before loss aggregation, so they affect the gradient through loss magnitude without introducing a stop-gradient bias on the ratio.
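How the weights might enter the aggregation, sketched with assumed names (grpo_loss_fn's real signature differs); the weights scale each per-token loss term and the mean is taken over completion tokens only:

```python
import numpy as np

# Hypothetical sketch: apply precomputed (detached) sampler IS weights
# to the per-token pg loss before masked aggregation.
def aggregate_weighted_loss(per_token_pg_loss, sampler_is_weights, completion_mask):
    weighted = np.asarray(per_token_pg_loss) * np.asarray(sampler_is_weights)
    mask = np.asarray(completion_mask, dtype=float)
    # Masked mean: pad / non-completion positions contribute nothing.
    return float((weighted * mask).sum() / np.maximum(mask.sum(), 1.0))
```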

Sampler-trainer prob_diff and pearson metrics
Logged every training step. prob_diff is the mean absolute probability difference |softmax(rollout_logit) − softmax(trainer_logit)| over completion tokens; pearson is the per-batch correlation between rollout and trainer log-probabilities. These serve as numerical alignment health checks for the two compute paths.
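The two diagnostics amount to the following (a sketch over per-token log-probabilities; names are illustrative):

```python
import numpy as np

def sampler_trainer_metrics(rollout_logp, trainer_logp):
    r = np.asarray(rollout_logp, dtype=float)
    t = np.asarray(trainer_logp, dtype=float)
    # Mean absolute difference of per-token probabilities (exp of logp).
    prob_diff = float(np.mean(np.abs(np.exp(r) - np.exp(t))))
    # Pearson correlation between the two log-prob vectors.
    pearson = float(np.corrcoef(r, t)[0, 1])
    return prob_diff, pearson
```

When the two compute paths agree, prob_diff approaches 0 and pearson approaches 1; sustained deviation signals numerical drift between rollout and trainer.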

Qwen3 model: segment_ids for flash attention (splash kernel)
The splash attention kernel requires per-position segment IDs to correctly compute the causal mask for left-padded sequences. Without this, padding tokens on the left contaminate real-token attention outputs during trainer recompute. This PR plumbs segment_ids derived from the non-pad mask through the Qwen3 forward pass.
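The derivation and its effect on the attention mask can be sketched like this (helper names assumed; the convention of segment_id 0 for pad and 1 for real tokens follows the PR discussion):

```python
import numpy as np

# Segment IDs from the non-pad mask: 0 for pad, 1 for real tokens.
def segment_ids_from_tokens(token_ids, pad_id=0):
    return (np.asarray(token_ids) != pad_id).astype(np.int32)

# Effective mask the kernel computes: a query attends a key only if the
# position is causal AND both share a segment, so real tokens never
# attend left-pad positions.
def segment_causal_mask(segment_ids):
    seg = np.asarray(segment_ids)
    same_segment = seg[:, None] == seg[None, :]
    causal = np.tril(np.ones((seg.size, seg.size), dtype=bool))
    return same_segment & causal
```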

vllm_sampler EOS token fix
Prevents an off-by-one error in the completion mask when the sampler returns the EOS token at the sequence boundary.

Trajectory collection: sampler logprobs in output
The trajectory_collect_engine now includes the rollout engine's per-token logprobs in the collected trajectory so the GRPO learner can apply TIS without a second forward pass through the rollout engine.
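A hypothetical sketch of the lockstep bookkeeping this implies (the real trajectory_collect_engine structures differ); the point is that tokens, masks, and logprobs always grow together, so an env turn with no sampler logprobs cannot offset the arrays:

```python
class TrajectoryBuffer:
    def __init__(self):
        self.tokens, self.masks, self.logprobs = [], [], []

    def append_turn(self, tokens, mask, logprobs=None):
        # Env turns carry no sampler logprobs; pad with zeros so the
        # three arrays never drift out of alignment.
        if logprobs is None:
            logprobs = [0.0] * len(tokens)
        assert len(tokens) == len(mask) == len(logprobs)
        self.tokens.extend(tokens)
        self.masks.extend(mask)
        self.logprobs.extend(logprobs)
```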

@google-cla

google-cla Bot commented May 15, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@colincai-mc colincai-mc changed the title Agentic GRPO: TIS correction, eval dedup, flash-attn segment_ids, PeftTrainer fix Agentic GRPO: TIS correction, eval dedup, flash-attn segment_ids May 15, 2026
segment_pos: jaxtyping.Array,
cache: LayerCache | None,
attn_mask: jaxtyping.Array | None,
segment_ids: jaxtyping.Array | None = None,
Collaborator


this is used for sequence packing right? right now we haven't enabled that yet since the seq packing support is still WIP


yeah, here it's not for packing; it's needed for per-batch padding masks. The splash attention kernel only accepts a static causal mask, so the per-batch dynamic pad positions can only flow in through SegmentIds. Without this, left-padded prompts leak attention from real queries onto pad keys, corrupting the softmax and shifting the logits. We use segment_id=0 for pad and segment_id=1 for real tokens to avoid it


# vLLM is configured with `include_stop_str_in_output=True`, so
# assistant_tokens from vllm DO include the trailing eos (`<|im_end|>`
# for Qwen3). But the chat template separates messages with `\n` after
Collaborator


chat template refers to the parser here? shall we consider removing \n there instead of manually fixing it here?

Author


addressed in new commit

Comment thread: tunix/rl/rl_cluster.py (Outdated)
# the next train_step's `mul` op crashed with memory_space mismatch.
# For sampler-trainer diff metric (single-iteration GRPO), the
# "anchor" state equals the current actor state anyway, so we just
# compute directly from the actor model.
Collaborator


when there are multiple model updates (e.g. mini_batch < global_batch), wouldn't we want to use the actor model weights from the very beginning of the global_step?

Author


addressed in new commit

…ention

Adds correctness fixes and observability for multi-turn agentic GRPO
runs, plus flash-attention support for prompts with left-padding.

Key changes:
  - PeftTrainer: skip the ``max_steps`` early-exit when the trainer is
    managed externally so the agentic outer loop is not interrupted
    mid-step.
  - agentic_rl_learner: dedupe eval at step boundaries (without this,
    ``train_steps % eval_every_n_steps == 0`` fires once per micro-step
    when the optimizer is wrapped in MultiSteps, replaying the held-out
    rollout ``grad_accum_steps`` times for the same step). Adds a
    per-global-step diagnostic log line for run health visibility.
  - agentic GRPO learner: optional truncated importance-sampling (TIS)
    correction (``sampler_is='token'``). When enabled the loss uses the
    trainer's start-of-step recomputed logp as ``old_per_token_logps``
    and multiplies each per-token pg-loss term by a detached weight
    ``min(exp(trainer_logp - sampler_logp), threshold)``. Dampens
    positions where the trainer's recomputed probability disagrees with
    the rollout sampler, reducing variance from sampler-trainer drift
    in multi-turn rollouts. Logs sampler-trainer logp_diff,
    prob_diff, and Pearson correlation every step.
  - TrainExample: ``sampler_is_weights`` field threaded through the
    GRPO loss before aggregation.
  - rl/common.process_ids: build the attention/RoPE-position mask from
    a non-pad mask rather than the caller-supplied
    ``completion_mask`` (which in multi-turn agentic learners is the
    assistant-vs-env loss mask). Emits a 1-D per-position non-pad mask
    as ``segment_ids`` so attention kernels that can't consume the
    2-D attn mask directly still skip pad positions.
  - rl_cluster: thread the anchor (start-of-step) actor weights into
    the per-token logp recompute path so old_per_token_logps reference
    the same policy the sampler used, even with
    ``mini_batch_size < full_batch_size`` or ``num_iterations > 1``.
  - algo_core: per-token diagnostics for the GRPO/GSPO loss
    (is_ratio mean/min/max, log_ratio absolute mean, pg_loss
    unclipped/clipped means, advantage abs mean/min/max/nonzero_frac,
    sampler-IS weight mean/min).
  - Qwen3 model: thread ``segment_ids`` through Attention / DecoderLayer
    / Qwen3 so the pallas splash-attention kernel receives per-position
    segment ids and can suppress cross-segment attention. Without
    this, left-padded prompts contaminate real-token attention output.
  - vllm_sampler: keep the trailing stop token in returned token_ids
    and enable ``include_stop_str_in_output=True`` so multi-turn
    consumers can reconstruct the exact sequence the model was
    sampled on. Previously the stripped stop token caused
    trainer-side concatenation to miss the per-turn ``<|im_end|>`` and
    produce 30+ nat sampler-trainer logp diffs at turn boundaries.
  - generate.sampler: treat ``top_k=0`` (or None) as "no top-k
    filter" rather than failing the fast-path check.
  - trajectory_collect_engine: append tokens, masks, and logprobs to
    the conversation buffer in lockstep, and inject a single ``\n``
    bridge between assistant and env tokens to match the
    apply_chat_template rendering. Previously a step with env_tokens
    but no logprobs offset the logprobs array against
    conversation_tokens for every subsequent step.
  - chat_template_parser: ``message_separator`` field + ``\n`` for
    QwenChatTemplateParser so incremental per-message rendering
    concatenates to match ``apply_chat_template`` output exactly.
