Commit dea9da5

colincai-mc and claude committed
Agentic GRPO improvements: sampler-IS correction, eval fix, flash attn
Key changes:

- Fix PeftTrainer early exit when is_managed_externally=True (prevents a spurious max_steps-triggered break inside the externally managed agentic training loop).
- Fix eval deduplication in agentic_rl_learner: an eval at a step boundary fired grad_accum_steps times instead of once; guard with _last_eval_train_step to skip repeat evals within the same train_step.
- Add token-level truncated importance-sampling (TIS) correction in agentic GRPO to account for sampler-trainer log-probability drift in multi-turn rollouts (sampler_is='token', configurable threshold; see the sketch below the file stats).
- Add a sampler_is_weights field to TrainExample; apply it in the GRPO loss before aggregation.
- Log sampler-trainer prob_diff and Pearson correlation every training step for diagnosing numerical alignment between rollout and trainer.
- Qwen3 model: thread segment_ids through the flash-attention (splash kernel) forward pass so left-padded prompts do not contaminate attention output.
- vllm_sampler: fix EOS token handling to prevent an off-by-one in completion mask construction.
- trajectory_collect_engine: include sampler logprobs in trajectory output so the GRPO learner can apply the TIS correction without a second forward pass; fix conversation mask alignment.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 90bb1c5 commit dea9da5

8 files changed

Lines changed: 419 additions & 79 deletions
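The token-level truncated importance-sampling correction referenced in the commit message is only summarized there. The following is a minimal sketch of the idea; the function and argument names (per_token_loss, trainer_logps, sampler_logps, completion_mask, threshold) are illustrative assumptions, not this commit's actual API or TrainExample fields.

import jax
import jax.numpy as jnp

def apply_sampler_is_weights(
    per_token_loss, trainer_logps, sampler_logps, completion_mask, threshold=2.0
):
  """Token-level truncated importance sampling (illustrative sketch only).

  Each token's loss is reweighted by exp(trainer_logp - sampler_logp), capped
  at `threshold`, so drift between the rollout sampler's log-probabilities and
  the trainer's recomputed log-probabilities does not bias the GRPO gradient.
  """
  is_weights = jnp.exp(trainer_logps - sampler_logps)   # [B, T] importance ratios
  is_weights = jnp.minimum(is_weights, threshold)       # truncate (the "T" in TIS)
  is_weights = jax.lax.stop_gradient(is_weights)        # weights enter as constants
  weighted = per_token_loss * is_weights * completion_mask
  return weighted.sum() / jnp.maximum(completion_mask.sum(), 1.0)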

tunix/generate/vllm_sampler.py

Lines changed: 15 additions & 4 deletions
@@ -349,10 +349,13 @@ def detokenize(
         input_strings, request_outputs
     ):
       for idx, single_output in enumerate(multi_sampling_output.outputs):
-        # vLLM still returns 1 eos id even if we ask it to stop at eos.
-        if single_output.token_ids[-1] == self.tokenizer.eos_id():
-          single_output.token_ids = single_output.token_ids[:-1]
-          single_output.logprobs = single_output.logprobs[:-1]
+        # KEEP the eos token in the returned token_ids — needed so multi-turn
+        # consumers (agentic engine) can reconstruct the exact sequence the
+        # next turn's prompt was rendered from. Combined with
+        # `include_stop_str_in_output=True`, vLLM emits one eos at the end of
+        # each generation. Stripping it (the previous behavior) made
+        # trainer-side concatenation miss `<|im_end|>` at every turn boundary
+        # and produced 30+ nat sampler-trainer logp diffs.

       out_tokens[idx].append(
           np.array(single_output.token_ids, dtype=np.int32)
@@ -461,6 +464,14 @@ def __call__(
     sampling_params.prompt_logprobs = 0
     sampling_params.stop_token_ids = [self.tokenizer.eos_id()]
     sampling_params.skip_special_tokens = True
+    # Keep the stop token in the returned ``token_ids`` so multi-turn
+    # consumers can reconstruct the exact sequence the model was sampled
+    # on. This makes the trainer-side concatenation align with what
+    # ``apply_chat_template`` produces for the next turn's prompt; without
+    # it, the trailing ``<|im_end|>`` (or equivalent eos token) is missing
+    # at every turn boundary in the recorded sequence, biasing logp
+    # recomputation against the model's actual sampling context.
+    sampling_params.include_stop_str_in_output = True

     if top_p is not None:
       sampling_params.top_p = top_p
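With the eos kept, the agentic consumer can rebuild the exact per-turn token stream by plain concatenation. A rough sketch of that assumption follows; the helper and its inputs are hypothetical, not part of this diff.

def reconstruct_rollout_tokens(turns):
  """turns: list of (prompt_token_ids, completion_token_ids), one pair per model turn.

  Because each completion now ends with the eos token (e.g. <|im_end|>), the
  concatenation below matches what apply_chat_template tokenizes for the next
  turn's prompt, so trainer-side logp recomputation sees the same context the
  sampler actually conditioned on.
  """
  full_sequence = []
  for prompt_ids, completion_ids in turns:
    full_sequence.extend(prompt_ids)
    full_sequence.extend(completion_ids)
  return full_sequence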

tunix/models/qwen3/model.py

Lines changed: 72 additions & 16 deletions
@@ -486,6 +486,7 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     seq_len = x.shape[1]

@@ -571,19 +572,59 @@ def block(
           shd.NamedSharding(mesh, P(shd_n, shd_t))
       )

-      @partial(
-          shard_map,
-          mesh=mesh,
-          in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
-          out_specs=shd_spec,
-          check_rep=False,
-      )
-      def sharded_splash_attn(kernel, q_block, k_block, v_block):
-        return jax.vmap(kernel)(q_block, k_block, v_block)
+      # Per-position segment ids let splash suppress cross-segment attention
+      # (e.g. real-token to pad-token, or sequence-packing cross-boundary).
+      # The pallas splash kernel only accepts a static causal mask kernel-side,
+      # so per-batch dynamic padding masks have to flow in via segment_ids.
+      if segment_ids is not None:
+        seg_spec = P(shd_b, shd_t)
+        unsharded_seg_spec = P(shd_b, None)
+
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(
+                kernel_spec,
+                shd_spec,
+                unsharded_seq,
+                unsharded_seq,
+                seg_spec,
+                unsharded_seg_spec,
+            ),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(
+            kernel, q_block, k_block, v_block, q_seg_block, kv_seg_block
+        ):
+          seg_ids = splash.SegmentIds(q=q_seg_block, kv=kv_seg_block)
+          return jax.vmap(kernel)(
+              q_block, k_block, v_block, segment_ids=seg_ids
+          )
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel,
+            query_proj,
+            key_proj,
+            value_proj,
+            segment_ids,
+            segment_ids,
+        )
+      else:

-      qkv = sharded_splash_attn(
-          splash_attn_kernel, query_proj, key_proj, value_proj
-      )
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(kernel, q_block, k_block, v_block):
+          return jax.vmap(kernel)(q_block, k_block, v_block)
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel, query_proj, key_proj, value_proj
+        )
       qkv = qkv.transpose(0, 2, 1, 3)
     else:
       # GQA
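For intuition, the masking rule that splash.SegmentIds implies can be written as a plain reference mask. This is an explanatory sketch, not the pallas splash kernel itself, and it assumes q_len == kv_len (self-attention without a cache).

import jax.numpy as jnp

def segment_causal_mask(q_segment_ids, kv_segment_ids):
  """Boolean [B, L, L'] mask equivalent to splash's segment-id handling.

  A query position may attend to a key position only when both carry the same
  segment id and the key is not in the future (causal). With a 1/0 padding
  mask used as segment ids, pad and real tokens fall into different segments,
  so left-padded prompt positions cannot contaminate the attention output.
  """
  q_len = q_segment_ids.shape[1]
  kv_len = kv_segment_ids.shape[1]
  same_segment = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]  # [B, L, L']
  causal = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))                # [L, L']
  return same_segment & causal[None, :, :]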
@@ -621,6 +662,7 @@ def __call__(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.BLOCK
@@ -629,10 +671,10 @@ def __call__(
       # nnx.remat needs to be applied to the unbound function and take self
       # as the first argument.
       return nnx.remat(self.block.__func__)(
-          self, x, segment_pos, cache, attn_mask
+          self, x, segment_pos, cache, attn_mask, segment_ids
       )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(x, segment_pos, cache, attn_mask, segment_ids=segment_ids)

   @property
   def head_dim(self):
@@ -1052,13 +1094,15 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     inputs_normalized = self.input_layernorm(x)
     cache, attn_output = self.attn(
         inputs_normalized,
         segment_pos,
         cache,
         attn_mask,
+        segment_ids=segment_ids,
     )
     attn_output += x
     residual = attn_output
@@ -1073,14 +1117,19 @@ def __call__(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.DECODER
         or self.config.remat_config == RematConfig.DECODER.value
     ):
-      return nnx.remat(self.block.__func__)(self, x, segment_pos, cache, attn_mask)
+      return nnx.remat(self.block.__func__)(
+          self, x, segment_pos, cache, attn_mask, segment_ids
+      )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(
+          x, segment_pos, cache, attn_mask, segment_ids=segment_ids
+      )


 class Qwen3(BackendMappingMixin, nnx.Module):
@@ -1146,6 +1195,7 @@ def __call__(
       cache: Cache | None,  # (sequence length L')
       attention_mask: jaxtyping.Array,  # [B, L, L']
       output_hidden_states: bool = False,
+      segment_ids: jaxtyping.Array | None = None,  # [B, L]
   ) -> tuple[jaxtyping.Array, Cache | None]:
     """Qwen3 model.

@@ -1155,6 +1205,11 @@ def __call__(
       cache: Attention KV cache or None.
       attention_mask: transformer input mask.
       output_hidden_states: whether to output the hidden states.
+      segment_ids: optional per-position segment identifiers, [B, L]. Used by
+        flash attention to suppress cross-segment attention (e.g. real-token
+        to pad-token, or sequence-packing across document boundaries). Pass a
+        1/0 mask to skip pad positions; pass increasing integer ids per packed
+        document for sequence packing.

     Returns:
       predicted_logits, new_cache
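A small illustration of the two conventions the docstring describes; the concrete values are made up.

import jax.numpy as jnp

# Padding: 1 for real tokens, 0 for (left-)pad positions.
padding_segment_ids = jnp.array([[0, 0, 1, 1, 1, 1]])       # [B=1, L=6]

# Sequence packing: a distinct increasing id per packed document.
packed_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 2, 3]])  # [B=1, L=8]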
@@ -1173,6 +1228,7 @@ def __call__(
           positions,
           layer_cache,
           attention_mask,
+          segment_ids=segment_ids,
       )
       if cache is not None:
         new_cache[layer_name] = layer_cache  # pytype: disable=container-type-mismatch
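Caller-side, the new argument can be derived directly from the padding pattern. A sketch under assumed names follows; input_tokens, positions, pad_id, and the positional argument order ahead of cache are assumptions, not read from this diff.

import jax.numpy as jnp

def forward_with_padding_mask(model, input_tokens, positions, cache, attention_mask, pad_id=0):
  """Sketch: build segment_ids from the pad pattern and forward them to Qwen3."""
  segment_ids = (input_tokens != pad_id).astype(jnp.int32)  # [B, L]; 1 = real, 0 = pad
  return model(
      input_tokens,
      positions,
      cache,
      attention_mask,
      segment_ids=segment_ids,
  )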
