diff --git a/tunix/generate/sampler.py b/tunix/generate/sampler.py
index 4c3546aa2..80e2f68e5 100644
--- a/tunix/generate/sampler.py
+++ b/tunix/generate/sampler.py
@@ -119,8 +119,10 @@ def sample_top_p(
   # Upcast to float32 for numerical stability of softmax and subsequent cumsum.
   next_token_logits = logits[:, -1].astype(jnp.float32) / temperature
+  # top_k=0 or None both mean "no top-k filtering" — use full vocabulary.
+  _no_topk = top_k is None or top_k <= 0
   # Skip softmax and sorting if top_p is 1.0 and top_k is full vocab.
-  if top_p >= 1.0 and top_k is None:
+  if top_p >= 1.0 and _no_topk:
     next_token = jax.random.categorical(key, logits=next_token_logits)
     if not return_logprobs:
       return next_token, None
@@ -130,7 +132,7 @@ def sample_top_p(
     return next_token, logp_sampled

   probs = jax.nn.softmax(next_token_logits, axis=-1)
-  k = probs.shape[-1] if top_k is None else top_k
+  k = probs.shape[-1] if _no_topk else top_k
   probs_sorted, indices = jax.lax.top_k(probs, k=k)
   cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
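
For context on the fast path and the `k` fallback, here is the filtering semantics in one self-contained sketch. This is a reference only, not the in-tree kernel (which is batched and also returns logprobs); `sample_top_p_reference` is a name local to this note.

```python
import jax
import jax.numpy as jnp

def sample_top_p_reference(key, logits, temperature=1.0, top_p=1.0, top_k=None):
  logits = logits.astype(jnp.float32) / temperature
  no_topk = top_k is None or top_k <= 0  # 0 and None both mean "full vocab"
  if top_p >= 1.0 and no_topk:
    # Fast path: categorical over raw logits, no softmax/sort needed.
    return jax.random.categorical(key, logits)
  probs = jax.nn.softmax(logits, axis=-1)
  k = probs.shape[-1] if no_topk else top_k
  probs_sorted, indices = jax.lax.top_k(probs, k=k)
  # Nucleus rule: keep the smallest prefix reaching top_p mass; the token
  # that crosses the threshold stays included.
  keep = jnp.cumsum(probs_sorted) - probs_sorted < top_p
  probs_sorted = jnp.where(keep, probs_sorted, 0.0)
  choice = jax.random.categorical(key, jnp.log(probs_sorted + 1e-30))
  return indices[choice]

key = jax.random.PRNGKey(0)
token = sample_top_p_reference(key, jnp.array([2.0, 1.0, 0.5]), top_p=0.9, top_k=0)
```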
diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py
index afe7dd930..54fc14365 100644
--- a/tunix/generate/vllm_sampler.py
+++ b/tunix/generate/vllm_sampler.py
@@ -349,10 +349,13 @@ def detokenize(
       input_strings, request_outputs
   ):
     for idx, single_output in enumerate(multi_sampling_output.outputs):
-      # vLLM still returns 1 eos id even if we ask it to stop at eos.
-      if single_output.token_ids[-1] == self.tokenizer.eos_id():
-        single_output.token_ids = single_output.token_ids[:-1]
-        single_output.logprobs = single_output.logprobs[:-1]
+      # KEEP the eos token in the returned token_ids — needed so multi-turn
+      # consumers (agentic engine) can reconstruct the exact sequence the
+      # next turn's prompt was rendered from. Combined with
+      # `include_stop_str_in_output=True`, vLLM emits one eos at the end of
+      # each generation. Stripping it (the previous behavior) made
+      # trainer-side concatenation miss `<|im_end|>` at every turn boundary
+      # and produced 30+ nat sampler-trainer logp diffs.

       out_tokens[idx].append(
           np.array(single_output.token_ids, dtype=np.int32)
@@ -461,6 +464,14 @@ def __call__(
     sampling_params.prompt_logprobs = 0
     sampling_params.stop_token_ids = [self.tokenizer.eos_id()]
     sampling_params.skip_special_tokens = True
+    # Keep the stop token in the returned ``token_ids`` so multi-turn
+    # consumers can reconstruct the exact sequence the model was sampled
+    # on. This makes the trainer-side concatenation align with what
+    # ``apply_chat_template`` produces for the next turn's prompt; without
+    # it, the trailing ``<|im_end|>`` (or equivalent eos token) is missing
+    # at every turn boundary in the recorded sequence, biasing logp
+    # recomputation against the model's actual sampling context.
+    sampling_params.include_stop_str_in_output = True
     if top_p is not None:
       sampling_params.top_p = top_p
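
To make the turn-boundary claim concrete, here is a toy string-level reconstruction check. No real tokenizer is involved; `render` is a local stand-in approximating what `apply_chat_template` emits for a Qwen-style template.

```python
# If the sampler output keeps its trailing <|im_end|>, the recorded pieces
# concatenate to exactly what the next turn's prompt renderer produces.
EOT = "<|im_end|>"

def render(messages, add_generation_prompt=False):
  out = "".join(
      f"<|im_start|>{role}\n{text}{EOT}\n" for role, text in messages
  )
  return out + ("<|im_start|>assistant\n" if add_generation_prompt else "")

msgs = [("user", "hi")]
prompt = render(msgs, add_generation_prompt=True)
completion = "hello" + EOT                  # what vLLM now returns (eos kept)
env = "\n<|im_start|>user\nnext?" + EOT     # env-injected observation
recorded = prompt + completion + env + "\n<|im_start|>assistant\n"
expected = render(
    msgs + [("assistant", "hello"), ("user", "next?")],
    add_generation_prompt=True,
)
assert recorded == expected  # fails if `completion` had EOT stripped
```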
diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py
index 66403233a..f609d6373 100644
--- a/tunix/models/qwen3/model.py
+++ b/tunix/models/qwen3/model.py
@@ -486,6 +486,7 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     seq_len = x.shape[1]
@@ -571,19 +572,59 @@ def block(
           shd.NamedSharding(mesh, P(shd_n, shd_t))
       )

-      @partial(
-          shard_map,
-          mesh=mesh,
-          in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
-          out_specs=shd_spec,
-          check_rep=False,
-      )
-      def sharded_splash_attn(kernel, q_block, k_block, v_block):
-        return jax.vmap(kernel)(q_block, k_block, v_block)
+      # Per-position segment ids let splash suppress cross-segment attention
+      # (e.g. real-token to pad-token, or sequence-packing cross-boundary).
+      # The pallas splash kernel only accepts a static causal mask
+      # kernel-side, so per-batch dynamic padding masks have to flow in via
+      # segment_ids.
+      if segment_ids is not None:
+        seg_spec = P(shd_b, shd_t)
+        unsharded_seg_spec = P(shd_b, None)
+
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(
+                kernel_spec,
+                shd_spec,
+                unsharded_seq,
+                unsharded_seq,
+                seg_spec,
+                unsharded_seg_spec,
+            ),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(
+            kernel, q_block, k_block, v_block, q_seg_block, kv_seg_block
+        ):
+          seg_ids = splash.SegmentIds(q=q_seg_block, kv=kv_seg_block)
+          return jax.vmap(kernel)(
+              q_block, k_block, v_block, segment_ids=seg_ids
+          )
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel,
+            query_proj,
+            key_proj,
+            value_proj,
+            segment_ids,
+            segment_ids,
+        )
+      else:

-      qkv = sharded_splash_attn(
-          splash_attn_kernel, query_proj, key_proj, value_proj
-      )
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(kernel, q_block, k_block, v_block):
+          return jax.vmap(kernel)(q_block, k_block, v_block)
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel, query_proj, key_proj, value_proj
+        )
       qkv = qkv.transpose(0, 2, 1, 3)
     else:
       # GQA
@@ -621,6 +662,7 @@ def __call__(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.BLOCK
@@ -629,10 +671,10 @@ def __call__(
       # nnx.remat needs to be applied to the unbound function and take self
       # as the first argument.
       return nnx.remat(self.block.__func__)(
-          self, x, segment_pos, cache, attn_mask
+          self, x, segment_pos, cache, attn_mask, segment_ids
       )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(x, segment_pos, cache, attn_mask, segment_ids=segment_ids)

   @property
   def head_dim(self):
@@ -1052,6 +1094,7 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     inputs_normalized = self.input_layernorm(x)
     cache, attn_output = self.attn(
         inputs_normalized,
         segment_pos,
         cache,
         attn_mask,
+        segment_ids=segment_ids,
     )
     attn_output += x
     residual = attn_output
@@ -1073,14 +1117,19 @@ def __call__(
       segment_pos: jaxtyping.Array,
      cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.DECODER
         or self.config.remat_config == RematConfig.DECODER.value
     ):
-      return nnx.remat(self.block.__func__)(self, x, segment_pos, cache, attn_mask)
+      return nnx.remat(self.block.__func__)(
+          self, x, segment_pos, cache, attn_mask, segment_ids
+      )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(
+          x, segment_pos, cache, attn_mask, segment_ids=segment_ids
+      )


 class Qwen3(BackendMappingMixin, nnx.Module):
@@ -1146,6 +1195,7 @@ def __call__(
       cache: Cache | None,  # (sequence length L')
       attention_mask: jaxtyping.Array,  # [B, L, L']
       output_hidden_states: bool = False,
+      segment_ids: jaxtyping.Array | None = None,  # [B, L]
   ) -> tuple[jaxtyping.Array, Cache | None]:
     """Qwen3 model.

     Args:
       cache: Attention KV cache or None.
       attention_mask: transformer input mask.
       output_hidden_states: whether to output the hidden states.
+      segment_ids: optional per-position segment identifiers, [B, L]. Used by
+        flash attention to suppress cross-segment attention (e.g. real-token
+        to pad-token, or sequence-packing across document boundaries). Pass a
+        1/0 mask to skip pad positions; pass increasing integer ids per packed
+        document for sequence packing.

     Returns:
       predicted_logits, new_cache
@@ -1173,6 +1228,7 @@ def __call__(
           positions,
           layer_cache,
           attention_mask,
+          segment_ids=segment_ids,
       )
       if cache is not None:
         new_cache[layer_name] = layer_cache  # pytype: disable=container-type-mismatch
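
The segment-id contract the new splash path relies on, written out in plain `jnp`. This is a reference sketch of the mask the kernel effectively applies, not the pallas implementation.

```python
# With seg = 0 on pads and 1 on real tokens, the effective mask is the
# causal mask ANDed with "query and key belong to the same segment".
import jax.numpy as jnp

def effective_mask(seg):  # seg: [B, L] int32
  L = seg.shape[-1]
  causal = jnp.tril(jnp.ones((L, L), dtype=bool))
  same_segment = seg[:, :, None] == seg[:, None, :]
  return causal[None, :, :] & same_segment  # [B, L, L]

seg = jnp.array([[0, 0, 1, 1, 1]])  # left-padded: 2 pads, 3 real tokens
m = effective_mask(seg)
# Real positions (seg=1) never attend to pad positions (seg=0); packed
# documents with distinct ids likewise never attend across boundaries.
```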
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py
index e11d31760..375c91879 100644
--- a/tunix/rl/agentic/agentic_grpo_learner.py
+++ b/tunix/rl/agentic/agentic_grpo_learner.py
@@ -111,6 +111,20 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
       True  # Whether to mask out degenerate groups with all-0 advantages.
   )
   use_rollout_logps: bool = True
+  # Truncated importance-sampling (TIS) correction for the residual mismatch
+  # between the rollout sampler and the trainer's recomputed
+  # log-probabilities. Set to ``"token"`` to enable per-token TIS weights.
+  # When enabled, the loss path uses the trainer's start-of-step recomputed
+  # logp as ``old_per_token_logps`` (so the PPO ratio is taken against the
+  # trainer's own policy at step start, rather than directly against the
+  # sampler's logp) and multiplies each per-token pg-loss term by a detached
+  # weight
+  #   w_t = clip(exp(clip(trainer_logp_t - sampler_logp_t, ±20)), max=threshold)
+  # dampening positions where the trainer's recomputed probability disagrees
+  # significantly with the rollout sampler. Without this correction,
+  # importance ratios computed directly against the sampler's logp can spike
+  # on outlier tokens, producing large-variance gradient updates.
+  sampler_is: str | None = None  # None | "token"
+  sampler_is_threshold: float = 2.0

   def __post_init__(self):
     if self.num_generations <= 1:
@@ -250,6 +264,18 @@ def __init__(
         "pg_clipfrac": np.mean,
         "ppo_kl": np.mean,
         "kl_loss": np.mean,
+        "is_ratio/mean": np.mean,
+        "is_ratio/max": np.max,
+        "is_ratio/min": np.min,
+        "log_ratio/abs_mean": np.mean,
+        "pg_loss/unclipped_mean": np.mean,
+        "pg_loss/clipped_mean": np.mean,
+        "advantage/abs_mean": np.mean,
+        "advantage/max": np.max,
+        "advantage/min": np.min,
+        "advantage/nonzero_frac": np.mean,
+        "sampler_is/weight_mean": np.mean,
+        "sampler_is/weight_min": np.min,
     })
     self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([
         lambda: "kl"
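
The weight in the config comment above, as standalone code. Names are local to this sketch; `threshold` corresponds to `sampler_is_threshold`.

```python
import jax
import jax.numpy as jnp

def tis_weights(trainer_logp, sampler_logp, mask, threshold=2.0):
  # Inner clip bounds the log-ratio so exp() cannot overflow; the outer
  # minimum truncates the importance weight at the configured threshold.
  log_ratio = jnp.clip(trainer_logp - sampler_logp, -20.0, 20.0)
  w = jnp.minimum(jnp.exp(log_ratio), threshold) * mask
  # Detached: the weight scales the loss but carries no gradient path.
  return jax.lax.stop_gradient(w)
```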
@@ -396,19 +422,66 @@ def _process_results(
         completion_ids.shape,
     )

+    # Sampler-trainer log-probability mismatch diagnostic. We always compute
+    # the trainer's recomputed logprobs whenever rollout logprobs are present
+    # so the per-batch diff, max, and Pearson correlation metrics can be
+    # logged below. Training itself still uses whichever logp source is
+    # configured via ``use_rollout_logps``. Cost: one extra trainer forward
+    # pass per training step.
+    rollout_per_token_logps = None
+    trainer_per_token_logps = None
     if self.algo_config.use_rollout_logps and padded_old_logprobs:
-      old_per_token_logps = jnp.asarray(padded_old_logprobs)
+      rollout_per_token_logps = jnp.asarray(padded_old_logprobs)
+      # NOTE: pass a NON-PADDING mask (1 for both assistant AND env tokens)
+      # for attention/position construction inside compute_per_token_logps,
+      # not the assistant-vs-env mask. process_ids uses `completion_mask` to
+      # build the causal attention pattern AND to drive
+      # `build_positions_from_mask`. If we pass the asst-vs-env mask, env
+      # tokens get masked OUT of attention AND positions don't advance
+      # through them — so when predicting the first assistant token of turn
+      # k+1, the trainer's model has no memory of the env observation that
+      # triggered turn k+1. That makes the trainer's logp at multi-turn
+      # boundaries diverge from vllm's (which always sees the full
+      # conversation in attention) by 30-50 nat.
+      attn_completion_mask = (completion_ids != pad_value).astype(jnp.int32)
+      trainer_per_token_logps = self.rl_cluster.get_actor_per_token_logps(
+          prompt_tokens=prompt_ids,
+          completion_tokens=completion_ids,
+          pad_id=pad_value,
+          eos_id=eos_value,
+          micro_batch_size=self.rl_cluster.cluster_config.training_config.compute_logps_micro_batch_size,
+          completion_mask=attn_completion_mask,
+      )
+      old_per_token_logps = rollout_per_token_logps
+      # When sampler-IS correction is enabled, use the trainer's recomputed
+      # logp as ``old_per_token_logps`` so the PPO ratio is
+      # ``exp(current_logp - trainer_logp)`` rather than against the rollout
+      # sampler's logp directly. The IS weight computed below corrects for
+      # the trainer-vs-sampler divergence.
+      if self.algo_config.sampler_is == "token":
+        old_per_token_logps = trainer_per_token_logps
     elif self.algo_config.use_rollout_logps:
       old_per_token_logps = None
     else:
-      old_per_token_logps = self.rl_cluster.get_actor_per_token_logps(
+      # Same NON-PADDING mask requirement as the branch above: attention and
+      # positions must include env tokens, or the trainer's logp diverges
+      # from the sampler's at every multi-turn boundary.
+      attn_completion_mask = (completion_ids != pad_value).astype(jnp.int32)
+      trainer_per_token_logps = self.rl_cluster.get_actor_per_token_logps(
          prompt_tokens=prompt_ids,
          completion_tokens=completion_ids,
          pad_id=pad_value,
          eos_id=eos_value,
-          micro_batch_size=None,
-          completion_mask=completion_mask,
+          micro_batch_size=self.rl_cluster.cluster_config.training_config.compute_logps_micro_batch_size,
+          completion_mask=attn_completion_mask,
       )
+      old_per_token_logps = trainer_per_token_logps

     if self.algo_config.num_iterations > 1 and old_per_token_logps is None:
       raise RuntimeError(
@@ -440,7 +513,7 @@ def _process_results(
           completion_tokens=completion_ids,
           pad_id=pad_value,
           eos_id=eos_value,
-          micro_batch_size=None,
+          micro_batch_size=self.rl_cluster.cluster_config.training_config.compute_logps_micro_batch_size,
           completion_mask=completion_mask,
       )
       interval_v2.async_end([ref_per_token_logps])
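
A small worked example of the failure mode the NOTE describes, assuming `build_positions_from_mask` advances positions only where the mask is 1 (cumsum-minus-one semantics):

```python
import numpy as np

def build_positions(mask):  # 1 where the position participates
  return np.maximum(np.cumsum(mask, axis=-1) - 1, 0)

#                        asst asst env  env  asst
loss_mask   = np.array([[1,   1,   0,   0,   1]])
nonpad_mask = np.array([[1,   1,   1,   1,   1]])
print(build_positions(loss_mask))    # [[0 1 1 1 2]]  env tokens collapse
print(build_positions(nonpad_mask))  # [[0 1 2 3 4]]  matches the sampler
```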
@@ -523,6 +596,109 @@ def _process_results(
         "rewards/advantage/std": (np.std(advantages), np.mean),
     }

+    # Per-token sampler-vs-trainer log-probability agreement diagnostic. When
+    # this diverges from zero, importance ratios used in the policy update
+    # are biased and gradient quality degrades. A mean per-token diff well
+    # under 0.01 nat indicates the trainer and rollout sampler are computing
+    # log-probabilities consistently.
+    if (
+        rollout_per_token_logps is not None
+        and trainer_per_token_logps is not None
+    ):
+      # ``completion_mask`` is the assistant-vs-env mask built upstream
+      # (1 for assistant-generated tokens, 0 for env-injected tokens), and
+      # already correctly scopes the comparison to model-emitted positions.
+      # We deliberately do NOT additionally drop positions where the rollout
+      # logprob equals exactly 0.0 — that value can legitimately occur for
+      # near-certain tokens (e.g. format chars after a structured response)
+      # and excluding them removes the most consistent positions from the
+      # statistic, inflating the per-position mean.
+      mask = completion_mask.astype(jnp.bool_)
+      mask_f = mask.astype(jnp.float32)
+      mask_sum = jnp.maximum(mask_f.sum(), 1.0)
+      diff = jnp.abs(rollout_per_token_logps - trainer_per_token_logps)
+      diff_mean = float((diff * mask_f).sum() / mask_sum)
+      diff_max = float(jnp.where(mask, diff, 0.0).max())
+      # Per-position probability-space diff |exp(rollout) - exp(trainer)|.
+      # More representative than logp_diff for confidence agreement: logp
+      # can diverge arbitrarily for very low-probability tokens while their
+      # contribution to the importance ratio is negligible. prob_diff
+      # weights each position by its actual probability mass.
+      rp = jnp.exp(rollout_per_token_logps)
+      tp = jnp.exp(trainer_per_token_logps)
+      prob_diff = jnp.abs(rp - tp)
+      prob_diff_mean = float((prob_diff * mask_f).sum() / mask_sum)
+      prob_diff_max = float(jnp.where(mask, prob_diff, 0.0).max())
+      # Pearson correlation between exp(logp) at masked positions.
+      rp_flat = rp.reshape(-1)
+      tp_flat = tp.reshape(-1)
+      mf = mask_f.reshape(-1)
+      rp_mean = (rp_flat * mf).sum() / mask_sum
+      tp_mean = (tp_flat * mf).sum() / mask_sum
+      rp_d = (rp_flat - rp_mean) * mf
+      tp_d = (tp_flat - tp_mean) * mf
+      cov = (rp_d * tp_d).sum() / mask_sum
+      rp_var = (rp_d * rp_d).sum() / mask_sum
+      tp_var = (tp_d * tp_d).sum() / mask_sum
+      pearson = float(cov / jnp.sqrt(jnp.maximum(rp_var * tp_var, 1e-12)))
+      metrics_to_log.update({
+          "sampler_trainer/logp_diff_mean": (diff_mean, np.mean),
+          "sampler_trainer/logp_diff_max": (diff_max, np.max),
+          "sampler_trainer/prob_diff_mean": (prob_diff_mean, np.mean),
+          "sampler_trainer/prob_diff_max": (prob_diff_max, np.max),
+          "sampler_trainer/probs_pearson_corr": (pearson, np.mean),
+      })
+      logging.info(
+          "sampler-trainer: logp_diff=(%.5f,%.5f) prob_diff=(%.5f,%.5f)"
+          " pearson=%.5f",
+          diff_mean, diff_max, prob_diff_mean, prob_diff_max, pearson,
+      )
+
+    # Truncated importance-sampling (TIS) correction weights.
+    # Compute per-token TIS weights from the trainer-vs-sampler log-ratio,
+    # mask to assistant tokens only (we dampen offending model-emitted
+    # positions, not env tokens), clamp at the configured threshold, and
+    # detach. The policy loss picks these up via
+    # ``train_example.sampler_is_weights``.
+    sampler_is_weights = None
+    if (
+        self.algo_config.sampler_is == "token"
+        and rollout_per_token_logps is not None
+        and trainer_per_token_logps is not None
+    ):
+      asst_mask_f = completion_mask.astype(jnp.float32)
+      log_ratio = trainer_per_token_logps - rollout_per_token_logps
+      log_ratio = jnp.clip(log_ratio, min=-20.0, max=20.0)
+      sampler_is_weights = jax.lax.stop_gradient(
+          jnp.minimum(
+              jnp.exp(log_ratio),
+              self.algo_config.sampler_is_threshold,
+          )
+          * asst_mask_f
+      )
+      mask_sum = jnp.maximum(asst_mask_f.sum(), 1.0)
+      is_mean = float((sampler_is_weights * asst_mask_f).sum() / mask_sum)
+      is_max = float(
+          jnp.where(asst_mask_f > 0, sampler_is_weights, 0.0).max()
+      )
+      frac_clipped = float(
+          (
+              (jnp.exp(log_ratio) > self.algo_config.sampler_is_threshold)
+              & (asst_mask_f > 0)
+          )
+          .astype(jnp.float32)
+          .sum()
+          / mask_sum
+      )
+      metrics_to_log.update({
+          "sampler_is/weight_mean": (is_mean, np.mean),
+          "sampler_is/weight_max": (is_max, np.max),
+          "sampler_is/frac_clipped_at_threshold": (frac_clipped, np.mean),
+      })
+      logging.info(
+          "sampler_is: weight_mean=%.4f weight_max=%.4f frac_clipped=%.4f"
+          " (threshold=%.2f)",
+          is_mean, is_max, frac_clipped,
+          self.algo_config.sampler_is_threshold,
+      )
+
     # Extract time metrics (env_time and reward_time)
     for time_key in ["env_time", "reward_time"]:
       prefix = f"trajectory/{time_key}"
@@ -567,6 +743,7 @@ def _process_results(
         advantages=advantages,
         old_per_token_logps=old_per_token_logps,
         policy_version=policy_versions,
+        sampler_is_weights=sampler_is_weights,
     )
     return [combined_batch]
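
For offline verification of a dumped batch, the Pearson statistic above reduces to this numpy helper (standalone, not part of the change):

```python
import numpy as np

def masked_pearson(a, b, mask):
  m = mask.astype(np.float64).ravel()
  n = max(m.sum(), 1.0)
  a, b = a.ravel().astype(np.float64), b.ravel().astype(np.float64)
  am, bm = (a * m).sum() / n, (b * m).sum() / n
  ad, bd = (a - am) * m, (b - bm) * m          # deviations, zeroed off-mask
  cov = (ad * bd).sum() / n
  var = ((ad * ad).sum() / n) * ((bd * bd).sum() / n)
  return cov / np.sqrt(max(var, 1e-12))
```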
diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py
index b9bc3f65e..1ca8781d9 100644
--- a/tunix/rl/agentic/agentic_rl_learner.py
+++ b/tunix/rl/agentic/agentic_rl_learner.py
@@ -189,6 +189,18 @@ def __init__(
     # Current iter steps for micro-batch based training.
     self._iter_steps = self.rl_cluster.actor_trainer.iter_steps
     self._eval_iter_steps = 0
+    # Tracks the last train_step value at which evaluation was run. The
+    # optimizer is wrapped in ``optax.MultiSteps(grad_accum_steps)``, which
+    # keeps ``actor_trainer.train_steps`` constant for ``grad_accum_steps``
+    # consecutive micro-iterations. Without this guard, the
+    # ``train_steps % eval_every_n_steps == 0`` check would fire at every
+    # micro-iteration during an eval boundary, causing the full evaluation
+    # rollout to be replayed ``grad_accum_steps`` times for the same step.
+    # Initialized to 0 (not -1) so the eval-at-step-0 baseline pass is
+    # skipped; it adds ~1-2 min to startup without exercising any trained
+    # weights. Subsequent evals at train_steps == eval_every_n_steps still
+    # fire normally.
+    self._last_eval_train_step = 0

     # Sync weights if the actor model and rollout model are not sharing weights.
     self.should_sync_weights = not (
@@ -238,6 +250,16 @@ def run_loop_forever():
       self.loop = loop_queue.get()

     self._global_step_start_time = time.time()
+    # Per-step reward accumulators populated inside ``_compute_rewards``.
+    # Drained at the global-step boundary to emit a one-line per-step
+    # summary that mirrors what an external metric logger would show.
+    # The train window is capped at one full step's worth of values
+    # (``full_batch_size * num_generations``) so a producer that races
+    # ahead of the consumer does not double-count.
+    self._train_rewards_window: List[float] = []
+    self._eval_rewards_window: List[float] = []
+    self._rewards_window_lock = threading.Lock()

   def _validate_rollout_config(self):
     """Validates that the rollout config is properly aligned with the algo config."""
     rollout_config = self.rl_cluster.cluster_config.rollout_config
@@ -314,6 +336,25 @@ def _compute_rewards(
         rewards_info["log_metrics"], mode=mode, step=expected_step
     )

+    rewards_array = np.asarray(rewards_info["rewards"])
+    with self._rewards_window_lock:
+      target = (
+          self._train_rewards_window
+          if mode == rl_cluster_lib.Mode.TRAIN
+          else self._eval_rewards_window
+      )
+      target.extend(rewards_array.tolist())
+      # Cap the train window at full_batch_size * num_generations (one full
+      # step's worth of per-sequence rewards) to bound the
+      # producer-vs-consumer race: the producer can race up to
+      # ``off_policy_steps + 1`` batches ahead, so without a cap the window
+      # would over-count next-step rewards at the current step's boundary.
+      if mode == rl_cluster_lib.Mode.TRAIN and self._full_batch_size > 0:
+        cap = self._full_batch_size * self.algo_config.num_generations
+        excess = len(target) - cap
+        if excess > 0:
+          del target[:excess]
+
     return rewards_info["rewards"]

   def _create_micro_batch_iterator(
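
The dedup guard's interaction with gradient accumulation, simulated with toy numbers (`grad_accum_steps = 4`, eval every 10 steps; the floor-division stands in for what `optax.MultiSteps` exposes as `train_steps`):

```python
eval_every_n_steps, grad_accum_steps = 10, 4
last_eval = 0
for micro_iter in range(128):
  train_steps = micro_iter // grad_accum_steps  # constant across 4 micro-iters
  if train_steps % eval_every_n_steps == 0 and train_steps != last_eval:
    last_eval = train_steps
    print(f"eval at train_step={train_steps} (micro_iter={micro_iter})")
# Fires once per boundary (train_steps 10, 20, 30). Without the
# `train_steps != last_eval` dedup it would fire 4x per boundary, and
# step 0 would run a full eval rollout before any training.
```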
@@ -774,6 +815,7 @@ def train(
     )
     micro_batches_since_last_sync = 0
     micro_batches_per_full_batch = full_batch_size // train_micro_batch_size
+    did_eval_this_global_step = False
     for train_micro_batch in train_data_gen:
       if (
           self._training_config.max_steps
@@ -813,12 +855,13 @@ def train(

       # --- Evaluation Logic ---
       current_eval_dataset = None
+      current_train_step = self.rl_cluster.actor_trainer.train_steps
       if (
           all_eval_prompts
-          and self.rl_cluster.actor_trainer.train_steps
-          % training_config.eval_every_n_steps
-          == 0
+          and current_train_step % training_config.eval_every_n_steps == 0
+          and current_train_step != self._last_eval_train_step
       ):
+        self._last_eval_train_step = current_train_step
         self._eval_iter_steps = 0
         eval_orchestrator = self._build_orchestrator()
@@ -842,21 +885,49 @@ async def _eval_runner_async(current_eval_orchestrator):
         eval_examples = eval_future.result()
         self._eval_iter_steps += 1
         current_eval_dataset = eval_examples
+        did_eval_this_global_step = True

       # --- Training Step ---
       iterations = self._num_iterations() if self._process_in_consumer else 1
+      # When ``train_micro_batch_size < mini_batch_size`` we want the trainer
+      # to invoke ``train_step`` multiple times per outer iteration so the
+      # optimizer (which fires every ``gradient_accumulation_steps`` micro-
+      # steps) sees ``mini_batch_size``-shaped gradients while peak HBM is
+      # only ``train_micro_batch_size``-shaped. Slice the merged train
+      # example along its batch axis into chunks sized to one micro-step,
+      # and pass the list to ``update_actor``; ``peft_trainer.train``
+      # iterates the list and calls ``train_step`` once per chunk.
+      seqs_per_chunk = (
+          train_micro_batch_size * self.algo_config.num_generations
+      )
+      n_total = merged_train_micro_batch.completion_ids.shape[0]
+      if n_total > seqs_per_chunk:
+        chunked_train_micro_batch = [
+            jax.tree_util.tree_map(
+                lambda x: (
+                    x[i : i + seqs_per_chunk]
+                    if hasattr(x, "shape")
+                    and x.shape
+                    and x.shape[0] == n_total
+                    else x
+                ),
+                merged_train_micro_batch,
+            )
+            for i in range(0, n_total, seqs_per_chunk)
+        ]
+      else:
+        chunked_train_micro_batch = [merged_train_micro_batch]
+
       for i in range(iterations):
         if self._process_in_consumer and i > 0:
           # TODO(b/483779605) Sub-step checkpointing.
           self._iter_steps += 1
         self.rl_cluster.update_actor(
-            [merged_train_micro_batch], current_eval_dataset, skip_jit
+            chunked_train_micro_batch, current_eval_dataset, skip_jit
         )
         if hasattr(self.rl_cluster, "critic_trainer"):
           self.rl_cluster.update_critic(
-              [merged_train_micro_batch], current_eval_dataset, skip_jit
+              chunked_train_micro_batch, current_eval_dataset, skip_jit
           )

       # --- Weight Sync Logic ---
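
The chunking relies on `tree_map` slicing only leaves that carry the batch axis. A dict stand-in for `TrainExample` (which is a registered pytree) shows the intended behavior; the `lambda` closing over `i` is safe because `tree_map` runs eagerly inside each comprehension iteration.

```python
import numpy as np
import jax

batch = {
    "completion_ids": np.arange(12).reshape(6, 2),  # 6 sequences
    "advantages": np.ones((6,)),
    "policy_version": 3,  # scalar metadata: no batch axis, passes through
}
n_total, seqs_per_chunk = 6, 2
chunks = [
    jax.tree_util.tree_map(
        lambda x: x[i : i + seqs_per_chunk]
        if hasattr(x, "shape") and x.shape and x.shape[0] == n_total
        else x,
        batch,
    )
    for i in range(0, n_total, seqs_per_chunk)
]
assert len(chunks) == 3 and chunks[1]["completion_ids"].shape == (2, 2)
assert all(c["policy_version"] == 3 for c in chunks)
```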
@@ -867,6 +938,86 @@ async def _eval_runner_async(current_eval_orchestrator):
           f"Global step {self.rl_cluster.global_steps} completed in"
           f" {global_step_time:.2f} seconds."
       )
+      # One-line per-step diagnostic: raw rewards, solve rate, completion
+      # length, advantage scale, and eval (when an eval just fired this
+      # step). Mirrors the per-iter view a wandb dashboard would show
+      # without depending on the async metric logger pipeline.
+      with self._rewards_window_lock:
+        train_rewards = np.asarray(
+            self._train_rewards_window, dtype=np.float32
+        )
+        eval_rewards = np.asarray(self._eval_rewards_window, dtype=np.float32)
+        self._train_rewards_window.clear()
+        if did_eval_this_global_step:
+          self._eval_rewards_window.clear()
+      adv = np.asarray(merged_train_micro_batch.advantages, dtype=np.float32)
+      cmask = np.asarray(
+          merged_train_micro_batch.completion_mask, dtype=np.float32
+      )
+      compl_len = cmask.sum(axis=-1).mean() if cmask.size else 0.0
+      adv_abs_mean = float(np.abs(adv).mean()) if adv.size else float("nan")
+      train_r_mean = (
+          float(train_rewards.mean()) if train_rewards.size else float("nan")
+      )
+      train_solve = (
+          float((train_rewards > 0.1).mean())
+          if train_rewards.size
+          else float("nan")
+      )
+      if eval_rewards.size and did_eval_this_global_step:
+        eval_r_mean = float(eval_rewards.mean())
+        eval_solve = float((eval_rewards > 0.1).mean())
+        eval_str = (
+            f" eval_reward={eval_r_mean:.3f}"
+            f" eval_solve={eval_solve:.3f}"
+            f" eval_n={eval_rewards.size}"
+        )
+      else:
+        eval_str = ""
+      # Best-effort read of trainer-side per-step metrics (grad_norm,
+      # pg_loss, entropy, kl) directly from the actor trainer's metric
+      # buffer so they appear in the per-step absl log alongside the
+      # rollout metrics, independently of any external metric logger.
+      trainer_str = ""
+      try:
+        actor_trainer = self.rl_cluster.actor_trainer
+        trainer_buf = (
+            getattr(actor_trainer, "_prev_buffered_train_metrics", None)
+            or getattr(actor_trainer, "_buffered_train_metrics", None)
+        )
+        if trainer_buf is not None:
+          extras = []
+          if trainer_buf.losses:
+            extras.append(f"loss={float(trainer_buf.loss):.4f}")
+          am = trainer_buf.additional_metrics
+          for key, label in (
+              ("grad_norm", "grad_norm"),
+              ("pg_loss", "pg_loss"),
+              ("entropy", "entropy"),
+              ("kl", "kl"),
+              ("log_ratio/abs_mean", "log_ratio_abs"),
+              ("pg_clipfrac", "clipfrac"),
+          ):
+            if key in am:
+              vals, _ = am[key]
+              if vals:
+                v = float(np.mean([np.asarray(x) for x in vals]))
+                extras.append(f"{label}={v:.4f}")
+          if extras:
+            trainer_str = " " + " ".join(extras)
+      except Exception as e:  # pylint: disable=broad-except
+        logging.debug("Failed to read trainer buffered metrics: %s", e)
+      logging.info(
+          "[step %d] train_reward=%.3f train_solve=%.3f n=%d"
+          " adv_abs_mean=%.3f compl_len=%.1f time=%.1fs%s%s",
+          self.rl_cluster.global_steps,
+          train_r_mean,
+          train_solve,
+          int(train_rewards.size),
+          adv_abs_mean,
+          float(compl_len),
+          global_step_time,
+          trainer_str,
+          eval_str,
+      )
       self.rl_cluster.buffer_metrics_async(
           {"perf/global_step_time": (global_step_time, np.mean)},
           mode=rl_cluster_lib.Mode.TRAIN,
@@ -923,6 +1074,7 @@ async def _eval_runner_async(current_eval_orchestrator):
           mode=rl_cluster_lib.Mode.TRAIN,
       )
       micro_batches_since_last_sync = 0
+      did_eval_this_global_step = False
      self._global_step_start_time = time.time()

     _ = producer_future.result()
diff --git a/tunix/rl/agentic/parser/chat_template_parser/parser.py b/tunix/rl/agentic/parser/chat_template_parser/parser.py
index 46e687afa..d107b4a07 100644
--- a/tunix/rl/agentic/parser/chat_template_parser/parser.py
+++ b/tunix/rl/agentic/parser/chat_template_parser/parser.py
@@ -39,6 +39,10 @@ class TokenConfig:
   tool_end_token: str = ""
   tool_response_start_token: str = ""
   tool_response_end_token: str = ""
+  # Separator inserted between consecutive messages by `parse()`. Match what
+  # the model's `apply_chat_template` writes between successive messages so
+  # incremental per-message rendering can be concatenated without a fixup.
+  message_separator: str = ""


 class BaseChatTemplateParser(ABC):
@@ -71,18 +75,31 @@ def parse(
       add_generation_prompt: bool = False,
       is_first_msg: bool = False,
   ) -> str:
-    """Parse messages into chat template format."""
-    result = ""
+    """Parse messages into chat template format.
+
+    When `is_first_msg=False` the call renders a continuation of an existing
+    conversation, so the result is prefixed with `message_separator` to
+    re-introduce the inter-message boundary that the previous turn's eot did
+    not emit. This lets per-message incremental rendering be concatenated
+    directly to prior tokens without external fixups.
+    """
+    parts = []
     if is_first_msg:
-      result += self._handle_first_message(messages)
+      first_chunk = self._handle_first_message(messages)
+      if first_chunk:
+        parts.append(first_chunk)

     for message in messages:
-      result += self._parse_message(message)
+      parts.append(self._parse_message(message))

     if add_generation_prompt:
-      result += self.generation_prompt
+      parts.append(self.generation_prompt)
+
+    sep = self.tokens.message_separator
+    result = sep.join(parts)
+    if not is_first_msg and result:
+      result = sep + result
     return result

   def _handle_first_message(self, messages: List[Dict[str, str]]) -> str:
@@ -157,7 +174,7 @@ def _init_tokens(self) -> TokenConfig:
     return TokenConfig(
         bos_token=self.tokenizer.bos_token,
         eos_token=self.tokenizer.eos_token,
-        eot_token="<|im_end|>\n",
+        eot_token="<|im_end|>",
         system_token="<|im_start|>system\n",
         user_token="<|im_start|>user\n",
         assistant_token=self._get_assistant_token(),
@@ -165,6 +182,7 @@ def _init_tokens(self) -> TokenConfig:
         tool_end_token="\n",
         tool_response_start_token="\n",
         tool_response_end_token="\n",
+        message_separator="\n",
     )

   def _get_assistant_token(self) -> str:
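
The invariant the separator refactor preserves is that incremental rendering concatenates to the one-shot rendering. A minimal stub mirroring the Qwen token config above (bos and `_handle_first_message` handling omitted):

```python
EOT, SEP = "<|im_end|>", "\n"

def parse(messages, add_generation_prompt=False, is_first_msg=False):
  parts = [f"<|im_start|>{m['role']}\n{m['content']}{EOT}" for m in messages]
  if add_generation_prompt:
    parts.append("<|im_start|>assistant\n")
  out = SEP.join(parts)
  # Continuations re-introduce the boundary the previous eot no longer emits.
  return out if is_first_msg or not out else SEP + out

turn1 = [{"role": "user", "content": "hi"}]
turn2 = [{"role": "user", "content": "again"}]
incremental = parse(turn1, is_first_msg=True) + parse(
    turn2, add_generation_prompt=True
)
one_shot = parse(turn1 + turn2, add_generation_prompt=True, is_first_msg=True)
assert incremental == one_shot
```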
diff --git a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py
index 3e1c03b0a..008e6fe3b 100644
--- a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py
+++ b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py
@@ -273,25 +273,28 @@ async def collect(self, mode: str = "Conversation") -> Any:
     prompt_tokens = getattr(self.agent.trajectory, "prompt_tokens", [])

     for step in self.agent.trajectory.steps:
-      # assistant tokens
-      if getattr(step, "assistant_tokens", None) is not None:
-        conversation_tokens.append(step.assistant_tokens)
+      # Keep tokens/masks/logprobs appended in lockstep — a step with
+      # env_tokens but no vllm logprobs (initial observation, empty
+      # completion) would otherwise leave the logprobs array short by
+      # `len(env_tokens)` and offset every subsequent step.
+      assistant_tokens = getattr(step, "assistant_tokens", None)
+      env_tokens = getattr(step, "env_tokens", None)
+      step_logprobs = getattr(step, "logprobs", None)
+      if assistant_tokens is not None:
+        conversation_tokens.append(assistant_tokens)
         conversation_masks.append(step.assistant_masks)
-
-      # env tokens
-      if getattr(step, "env_tokens", None) is not None:
-        conversation_tokens.append(step.env_tokens)
+        if step_logprobs is not None:
+          assert len(step_logprobs) == len(assistant_tokens), (
+              f"Logprobs length {len(step_logprobs)} does not match assistant"
+              f" tokens length {len(assistant_tokens)}"
+          )
+          logprobs.append(step_logprobs)
+        else:
+          logprobs.append(np.zeros(len(assistant_tokens)))
+      if env_tokens is not None:
+        conversation_tokens.append(env_tokens)
         conversation_masks.append(step.env_masks)
-
-      # logprobs
-      if getattr(step, "logprobs", None) is not None:
-        assert len(step.logprobs) == len(step.assistant_tokens), (
-            f"Logprobs length {len(step.logprobs)} does not match assistant"
-            f" tokens length {len(step.assistant_tokens)}"
-        )
-        logprobs.append(step.logprobs)
-        if getattr(step, "env_tokens", None) is not None:
-          logprobs.append(np.zeros(len(step.env_tokens)))
+        logprobs.append(np.zeros(len(env_tokens)))

     conversation_tokens = [
         np.asarray(tokens)
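
The lockstep invariant, demonstrated on toy step data (dict stand-ins for trajectory steps): an env-only step keeps the token and logprob streams the same length.

```python
import numpy as np

steps = [
    {"assistant_tokens": [5, 6], "logprobs": [-0.1, -0.2], "env_tokens": [7]},
    {"assistant_tokens": None, "logprobs": None, "env_tokens": [8, 9]},  # env-only
]
tokens, logps = [], []
for s in steps:
  if s["assistant_tokens"] is not None:
    tokens.extend(s["assistant_tokens"])
    logps.extend(s["logprobs"] or np.zeros(len(s["assistant_tokens"])))
  if s["env_tokens"] is not None:
    tokens.extend(s["env_tokens"])
    logps.extend(np.zeros(len(s["env_tokens"])))  # placeholder for env tokens
assert len(tokens) == len(logps)  # holds even for the env-only step
```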
diff --git a/tunix/rl/algo_core.py b/tunix/rl/algo_core.py
index c4bb2887d..356cb5156 100644
--- a/tunix/rl/algo_core.py
+++ b/tunix/rl/algo_core.py
@@ -458,9 +458,37 @@ def grpo_loss_fn(
     pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
     per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)

+  # Optional truncated importance-sampling (TIS) correction for the residual
+  # sampler-vs-trainer log-probability mismatch. The weights are precomputed
+  # upstream (already detached and threshold-clipped) and applied per token
+  # BEFORE loss aggregation so they affect the gradient through the loss
+  # magnitude only, not as a stop-gradient bias on the ratio.
+  sampler_is_weights = getattr(train_example, "sampler_is_weights", None)
+  if sampler_is_weights is not None:
+    per_token_loss = per_token_loss * sampler_is_weights.astype(jnp.float32)
+
   loss = common.aggregate_loss(
       per_token_loss, completion_mask, loss_aggregation_mode
   )
+  # Per-token diagnostics — log only over assistant tokens (completion_mask).
+  is_ratio_mean = masked_mean(is_ratio, completion_mask)
+  is_ratio_max = jnp.max(jnp.where(completion_mask > 0, is_ratio, 0.0))
+  is_ratio_min = jnp.min(jnp.where(completion_mask > 0, is_ratio, jnp.inf))
+  log_ratio_abs_mean = masked_mean(
+      jnp.abs(seq_importance_ratio), completion_mask
+  )
+  pg_loss_1_mean = masked_mean(pg_loss_1, completion_mask)
+  pg_loss_2_mean = masked_mean(pg_loss_2, completion_mask)
+  adv_broadcast = jnp.broadcast_to(adv, completion_mask.shape)
+  adv_abs_mean = masked_mean(jnp.abs(adv_broadcast), completion_mask)
+  adv_max = jnp.max(jnp.where(completion_mask > 0, adv_broadcast, -jnp.inf))
+  adv_min = jnp.min(jnp.where(completion_mask > 0, adv_broadcast, jnp.inf))
+  nonzero_adv_frac = masked_mean(
+      (jnp.abs(adv_broadcast) > 1e-8).astype(jnp.float32), completion_mask
+  )
   aux = {
       "kl": 0.0,
       "kl_loss": 0.0,
@@ -468,7 +496,26 @@ def grpo_loss_fn(
       "pg_clipfrac": clipped_fraction,
       "ppo_kl": ppo_kl,
       "pg_clipfrac_lower": pg_clipfrac_lower,
+      "is_ratio/mean": is_ratio_mean,
+      "is_ratio/max": is_ratio_max,
+      "is_ratio/min": is_ratio_min,
+      "log_ratio/abs_mean": log_ratio_abs_mean,
+      "pg_loss/unclipped_mean": pg_loss_1_mean,
+      "pg_loss/clipped_mean": pg_loss_2_mean,
+      "advantage/abs_mean": adv_abs_mean,
+      "advantage/max": adv_max,
+      "advantage/min": adv_min,
+      "advantage/nonzero_frac": nonzero_adv_frac,
   }
+  if sampler_is_weights is not None:
+    sis = sampler_is_weights.astype(jnp.float32)
+    aux["sampler_is/weight_mean"] = masked_mean(sis, completion_mask)
+    aux["sampler_is/weight_min"] = jnp.min(
+        jnp.where(completion_mask > 0, sis, jnp.inf)
+    )
+  else:
+    aux["sampler_is/weight_mean"] = jnp.float32(1.0)
+    aux["sampler_is/weight_min"] = jnp.float32(1.0)
   # We do not always compute KL divergence (e.g. when beta is 0.0 unless
   # force_compute_kl is True).
   if train_example.ref_per_token_logps is not None:
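
The masked-extrema pattern used above, isolated. Note an all-masked row would yield ±inf; the learner presumably never reaches this loss with zero assistant tokens in a batch.

```python
import jax.numpy as jnp

x = jnp.array([0.9, 1.7, 0.2])
mask = jnp.array([1, 1, 0])  # last position is an env/pad token
# Fill masked-out positions with the reduction's identity element.
masked_max = jnp.max(jnp.where(mask > 0, x, -jnp.inf))  # 1.7
masked_min = jnp.min(jnp.where(mask > 0, x, jnp.inf))   # 0.9
```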
diff --git a/tunix/rl/common.py b/tunix/rl/common.py
index a8d507848..02199676c 100644
--- a/tunix/rl/common.py
+++ b/tunix/rl/common.py
@@ -105,6 +105,12 @@ class TrainExample:
   old_per_token_logps: jax.Array | None
   segment_ids: jax.Array | None = None
   segment_positions: jax.Array | None = None
+  # Truncated importance-sampling correction weights for off-policy
+  # correction between the rollout sampler and the trainer. Per-token,
+  # detached, multiplied into the policy-gradient loss BEFORE aggregation
+  # to dampen positions where the trainer's recomputed log-probability
+  # diverges from the rollout sampler's. ``None`` disables the correction.
+  sampler_is_weights: jax.Array | None = None


 def compute_kl_divergence(
@@ -230,20 +236,41 @@ def process_ids(
           "segment_positions must be explicitly provided for packed sequences. "
       )
     attn_mask = None  # Relies on segment_ids inside the model
-    return prompt_completion_ids, segment_positions, attn_mask
+    # Packed callers supply their own segment_ids that already separate
+    # distinct documents in the buffer; no need for an additional padding
+    # mask here.
+    return prompt_completion_ids, segment_positions, attn_mask, None

   prompt_mask = prompt_tokens != pad_id
-  if completion_mask is None:
-    completion_mask = make_completion_mask(completion_tokens, eos_tok=eos_id)
+  # The attention/RoPE-position mask MUST include every real token in the
+  # sequence. The caller-provided `completion_mask` (in multi-turn agentic
+  # learners) is the assistant-vs-env loss mask, with 0s on env tokens —
+  # using that for attention causes env-observation tokens to be masked out
+  # of context AND positions don't advance through them, so the trainer's
+  # logits for the first assistant token of turn k+1 are computed without
+  # seeing the env observation that triggered turn k+1, producing 30-50 nat
+  # sampler-trainer diffs at every turn boundary. Ignore the passed-in
+  # completion_mask for attention purposes; loss aggregation is the caller's
+  # responsibility.
+  del completion_mask
+  attn_completion_mask = completion_tokens != pad_id
   prompt_completion_mask = jnp.concatenate(
-      [prompt_mask, completion_mask], axis=-1
+      [prompt_mask, attn_completion_mask], axis=-1
   )
   positions = build_positions_from_mask(prompt_completion_mask)
   attn_mask = make_causal_attn_mask(prompt_completion_mask)
-  return prompt_completion_ids, positions, attn_mask
+  # 1-D per-position non-pad mask for the full prompt+completion sequence.
+  # Used as ``segment_ids`` by attention kernels that cannot consume the 2-D
+  # ``attn_mask`` directly (e.g. pallas splash attention takes only a causal
+  # mask kernel-side and respects per-position segment ids to suppress
+  # cross-segment attention). With pad=0 and real=1, a real position never
+  # attends to a pad position regardless of where padding lives in the
+  # sequence (typically left-padded for prompt-side alignment).
+  input_seg_ids = prompt_completion_mask.astype(jnp.int32)
+
+  return prompt_completion_ids, positions, attn_mask, input_seg_ids


 @partial(
@@ -304,7 +331,7 @@ def compute_per_token_logps(
     derivatives.
   """
   model = nnx.merge(graphdef, state)
-  input_tokens, calculated_positions, attn_mask = process_ids(
+  input_tokens, calculated_positions, attn_mask, input_seg_ids = process_ids(
       prompt_tokens,
       completion_tokens,
       pad_id,
@@ -319,8 +346,14 @@ def compute_per_token_logps(
       "cache": None,
       "attention_mask": attn_mask,
   }
+  # Pass through any segment ids so the model's attention kernel can respect
+  # them: caller-provided packing ids take precedence; otherwise we pass the
+  # per-position non-pad mask derived in ``process_ids`` so flash-attention
+  # variants that lack a separate padding-mask input still skip pad
+  # positions.
   if segment_ids is not None:
     model_kwargs["segment_ids"] = segment_ids
+  elif input_seg_ids is not None:
+    model_kwargs["segment_ids"] = input_seg_ids

   if images is not None:
     model_kwargs["images"] = images
@@ -373,7 +406,12 @@ def compute_score(
     segment_positions: jax.Array | None = None,
 ):
   """Computes reward using the provided model."""
-  prompt_completion_ids, calculated_positions, attn_mask = process_ids(
+  (
+      prompt_completion_ids,
+      calculated_positions,
+      attn_mask,
+      input_seg_ids,
+  ) = process_ids(
       prompt_tokens,
       completion_tokens,
       pad_id,
@@ -388,6 +426,8 @@ def compute_score(
     model_kwargs["segment_ids"] = segment_ids
   else:
     model_kwargs["attention_mask"] = attn_mask
+  # ``input_seg_ids`` is None on the packed path, so this never overrides
+  # caller-provided packing ids.
+  if input_seg_ids is not None:
+    model_kwargs["segment_ids"] = input_seg_ids

   out = model(prompt_completion_ids, **model_kwargs)
   per_token_scores = out[0] if isinstance(out, tuple) else out
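
What the unpacked path now returns for a left-padded batch, as a worked example. This assumes `build_positions_from_mask` uses cumsum-minus-one semantics (pad_id = 0):

```python
import jax.numpy as jnp

pad_id = 0
prompt = jnp.array([[0, 0, 11, 12]])   # 2 left pads + 2 prompt tokens
completion = jnp.array([[21, 22, 0]])  # 2 real tokens + 1 right pad
mask = jnp.concatenate([prompt != pad_id, completion != pad_id], axis=-1)
positions = jnp.maximum(jnp.cumsum(mask, axis=-1) - 1, 0)
seg_ids = mask.astype(jnp.int32)
print(positions)  # [[0 0 0 1 2 3 3]]  pads never advance RoPE positions
print(seg_ids)    # [[0 0 1 1 1 1 0]]  pads land in segment 0, real in 1
```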
diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py
index 89b6ea207..ee2c717ea 100644
--- a/tunix/rl/rl_cluster.py
+++ b/tunix/rl/rl_cluster.py
@@ -1079,8 +1079,17 @@ def get_actor_per_token_logps(
     eos_id: int,
     micro_batch_size: int | None = None,
     completion_mask: jax.Array | None = None,
+    temperature: float | None = None,
 ) -> jax.Array:
-  """Gets per-token logps from the actor model on the trainer side."""
+  """Gets per-token logps from the actor model on the trainer side.
+
+  Mirrors `get_ref_per_token_logps` — must pass through the rollout
+  temperature so the actor's recomputed logps match the temperature scaling
+  used at sampling time (otherwise log_softmax(logits/T_sample) vs
+  log_softmax(logits) yields a multi-nat artifact diff vs vllm's
+  `processed_logprobs`).
+  """
+  if temperature is None:
+    temperature = self.get_rollout_config(mode=Mode.TRAIN).temperature
   batch_size = prompt_tokens.shape[0]
   if batch_size == 0:
     raise ValueError(
@@ -1109,13 +1118,17 @@ def get_actor_per_token_logps(
   else:
     dest_completion_mask = None

+  # Use the anchor (start-of-global-step) actor weights so
+  # old_per_token_logps reference the same policy vllm sampled with even
+  # when mini_batch_size < full_batch_size or num_iterations > 1. Only
+  # offload the live actor when `offload_to_cpu` is enabled cluster-wide;
+  # otherwise the host round-trip is unnecessary and risks leaving stray
+  # weights pinned to host.
   actor_trainer_state_on_device = self._is_state_on_device(
       nnx.state(self.actor_trainer.model)
   )
-  if actor_trainer_state_on_device:
+  if actor_trainer_state_on_device and self.cluster_config.offload_to_cpu:
     self._put_model_on_memory_kind(self.actor_trainer.model, "pinned_host")
     gc.collect()
-
   graphdef, actor_state = nnx.split(self.actor_trainer.model)
   actor_pspecs = nnx.get_partition_spec(actor_state)
   actor_model_sharding = jax.tree.map(
@@ -1146,12 +1159,13 @@ def get_actor_per_token_logps(
             else dest_completion_mask[batch_slice],
             stop_gradient=True,
             return_logits=False,
+            temperature=temperature,
         )
     )
   actor_per_token_logps = jnp.concatenate(outs, axis=0)
   del state
   gc.collect()
-  if actor_trainer_state_on_device:
+  if actor_trainer_state_on_device and self.cluster_config.offload_to_cpu:
     self._put_model_on_memory_kind(
         self.actor_trainer.model, self._default_memory_kind
     )
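
Why the temperature pass-through matters, in a few lines. The sampler draws from softmax(logits / T), so recomputing "the same" logp with T silently reset to 1 produces a systematic offset even with bit-identical logits:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.5, -1.0])
token = 1
T = 0.7  # rollout temperature
sampled_logp = jax.nn.log_softmax(logits / T)[token]  # what vllm reports
naive_logp = jax.nn.log_softmax(logits)[token]        # T dropped on recompute
print(float(sampled_logp - naive_logp))  # nonzero: pure bookkeeping artifact
```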