From 21de817ac8261e142581f188df92badf0c0d0492 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 9 Mar 2026 15:56:03 -0400 Subject: [PATCH 1/3] async streaming grpo w prefetch --- trl/trainer/grpo_config.py | 73 +++ trl/trainer/grpo_data_producer.py | 310 ++++++++++ trl/trainer/grpo_trainer.py | 990 +++++++++++++++++++++++++----- 3 files changed, 1220 insertions(+), 153 deletions(-) create mode 100644 trl/trainer/grpo_data_producer.py diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 809ac441165..8ebc5cad3c4 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -160,6 +160,31 @@ class GRPOConfig(_BaseConfig): Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory usage low, but waking the engine adds host–device transfer latency. + + > Parameters that control the data producer (online rollout generation) + + use_data_producer (`bool`, *optional*, defaults to `False`): + Use the `DataProducer` protocol for rollout generation instead of the legacy `_prepare_inputs` + buffering path. When enabled, the trainer creates a `GRPODataProducer` that integrates with the + transformers `_OnlineEpochSource` for the training loop. + async_prefetch (`bool`, *optional*, defaults to `False`): + Enable asynchronous rollout prefetching. When `True`, the next rollout is produced in a background + thread while the current one is being trained on. Requires `use_data_producer=True`. Currently only + supported with a single process (`num_processes=1`). + prefetch_depth (`int`, *optional*, defaults to `1`): + Number of rollouts to produce ahead of training when `async_prefetch` is enabled. Higher values + keep the GPU more saturated but increase off-policy staleness. + streaming_partial_batch (`bool`, *optional*, defaults to `False`): + Enable verl-style streaming partial batch training. When `True`, training begins on prompt + groups as they are scored, rather than waiting for the full batch to be scored. 
This reduces + peak GPU memory (only one group's logits in memory at a time) and allows reward subprocess + computation to overlap with subsequent groups' scoring. Requires `use_data_producer=True`, + `async_prefetch=True`, and `scale_rewards="group"` or `"none"`. + streaming_min_groups (`int`, *optional*, defaults to `1`): + Minimum number of prompt groups to accumulate and score before yielding micro-batches for + training. Higher values give better inter-group shuffling at the cost of more latency. + Only relevant when `streaming_partial_batch=True`. + > Parameters that control the training beta (`float`, *optional*, defaults to `0.0`): @@ -564,6 +589,54 @@ class GRPOConfig(_BaseConfig): }, ) + # Parameters that control the data producer (online rollout generation) + vllm_sync_interval: int = field( + default=1, + metadata={ + "help": "Sync model weights to the vLLM server every N training steps when async_prefetch is enabled. " + "Higher values reduce sync overhead but increase off-policy staleness of generated data. " + "Set to 1 to sync every step (default)." + }, + ) + use_data_producer: bool = field( + default=False, + metadata={ + "help": "Use the DataProducer protocol for rollout generation instead of the legacy _prepare_inputs " + "buffering path. When enabled, the trainer creates a GRPODataProducer for the training loop." + }, + ) + async_prefetch: bool = field( + default=False, + metadata={ + "help": "Enable asynchronous rollout prefetching. When True, the next rollout is produced in a " + "background thread while the current one is being trained on. Requires use_data_producer=True. " + "Currently only supported with a single process." + }, + ) + prefetch_depth: int = field( + default=1, + metadata={ + "help": "Number of rollouts to produce ahead of training when async_prefetch is enabled. Higher values " + "keep the GPU more saturated but increase off-policy staleness." 
+ }, + ) + streaming_partial_batch: bool = field( + default=False, + metadata={ + "help": "Enable verl-style streaming partial batch training. Scores and trains on prompt groups " + "incrementally instead of waiting for the full batch. Reduces peak GPU memory and enables " + "reward/scoring overlap. Requires use_data_producer=True, async_prefetch=True, and " + "scale_rewards='group' or 'none'." + }, + ) + streaming_min_groups: int = field( + default=1, + metadata={ + "help": "Minimum number of prompt groups to accumulate before yielding micro-batches. " + "Higher values give better inter-group shuffling. Only used when streaming_partial_batch=True." + }, + ) + # Parameters that control the training beta: float = field( default=0.0, diff --git a/trl/trainer/grpo_data_producer.py b/trl/trainer/grpo_data_producer.py new file mode 100644 index 00000000000..226cd548ffe --- /dev/null +++ b/trl/trainer/grpo_data_producer.py @@ -0,0 +1,310 @@ +# Copyright 2020-2026 The HuggingFace Team & Axolotl AI +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GRPODataProducer: produces GRPO training rollouts using the transformers DataProducer protocol. + +This module bridges TRL's GRPO generation pipeline with the transformers Trainer's +online-training infrastructure (``DataProducer`` / ``_OnlineEpochSource``). 
+""" + +from __future__ import annotations + +import logging +from functools import partial +from typing import Any + +import torch +from torch.utils.data import DataLoader, Dataset + +from transformers.data_producer import BaseDataProducer, ProducerConfig +from transformers.trainer_utils import seed_worker + +from .utils import RepeatSampler, identity, shuffle_sequence_dict + + +logger = logging.getLogger(__name__) + +class RolloutDataset(Dataset): + """A ``torch.utils.data.Dataset`` wrapping the output dict from + ``_generate_and_score_completions``. + + The output dict contains two kinds of entries: + + * **Per-sample tensors** (batch dim > 0): ``prompt_ids``, ``completion_ids``, + ``advantages``, ``old_per_token_logps``, etc. + * **Shared metadata** (scalar, 0-dim tensor, non-tensor, or sentinel): + ``num_items_in_batch``, ``_pending_policy_logps``. + + ``__getitem__`` slices per-sample tensors at the requested index and passes + shared values through unchanged. A matching collator is created via + :func:`make_rollout_collator`. + """ + + # Keys that are always treated as shared (not per-sample) regardless of type. + _ALWAYS_SHARED = frozenset({"num_items_in_batch", "_pending_policy_logps"}) + + def __init__(self, data: dict[str, Any]): + self._data = data + + # Classify keys into shared vs per-sample. + self._shared_keys: set[str] = set() + self._sample_keys: set[str] = set() + + for key, val in data.items(): + if key in self._ALWAYS_SHARED: + self._shared_keys.add(key) + elif not isinstance(val, torch.Tensor): + # Non-tensor values (lists, ints, etc.) are treated as shared. + self._shared_keys.add(key) + elif val.dim() == 0: + self._shared_keys.add(key) + else: + self._sample_keys.add(key) + + # Determine number of samples from any per-sample tensor. 
+ self._num_samples = 0 + for key in self._sample_keys: + n = data[key].size(0) + if self._num_samples == 0: + self._num_samples = n + elif n != self._num_samples: + raise ValueError( + f"Inconsistent sample count: key '{key}' has {n} samples, " + f"expected {self._num_samples}" + ) + + if self._num_samples == 0: + raise ValueError("No per-sample tensors found in rollout data") + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, idx: int) -> dict[str, Any]: + item: dict[str, Any] = {} + for key in self._sample_keys: + item[key] = self._data[key][idx] + for key in self._shared_keys: + item[key] = self._data[key] + return item + + +def make_rollout_collator(shared_keys: set[str]): + """Return a collator that stacks per-sample tensors and passes shared + keys through (taken from the first element in the batch). + + Args: + shared_keys: Set of key names that should NOT be stacked. + """ + + def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]: + result: dict[str, Any] = {} + for key in batch[0]: + if key in shared_keys: + result[key] = batch[0][key] + else: + values = [item[key] for item in batch] + if isinstance(values[0], torch.Tensor): + result[key] = torch.stack(values) + else: + result[key] = values + return result + + return _collate + + +class GRPODataProducer(BaseDataProducer): + """Produces GRPO training rollouts using the trainer's generation pipeline. + + This producer is created *before* ``Trainer.__init__`` completes, so it + stores only serialisable config values at construction time. The live + trainer reference is injected later via :meth:`set_trainer`, which also + creates the prompt ``DataLoader``. + + Args: + config: :class:`ProducerConfig` controlling mini-epochs, async, etc. + prompt_dataset: The original prompt dataset (HF ``Dataset``). + num_generations: Completions per unique prompt. + generation_batch_size: Global generation batch size (``per_device * steps_per_gen * num_processes``). 
+ train_batch_size: Per-device training batch size. + steps_per_generation: Training steps per generation round. + shuffle_dataset: Whether to shuffle prompts. + seed: Random seed for the prompt sampler. + """ + + def __init__( + self, + config: ProducerConfig, + prompt_dataset, + *, + num_generations: int, + generation_batch_size: int, + train_batch_size: int, + steps_per_generation: int, + shuffle_dataset: bool, + seed: int, + ): + super().__init__(config) + self._dataset = prompt_dataset + self._num_generations = num_generations + self._generation_batch_size = generation_batch_size + self._train_batch_size = train_batch_size + self._steps_per_generation = steps_per_generation + self._shuffle_dataset = shuffle_dataset + self._seed = seed + + # Set later via set_trainer(). + self._trainer = None + self._prompt_dl: DataLoader | None = None + self._prompt_iter = None + + def set_trainer(self, trainer) -> None: + """Inject the live trainer reference and create the prompt DataLoader. + + Must be called after ``Trainer.__init__`` completes (so that + ``trainer.accelerator`` is available). + """ + self._trainer = trainer + self._init_prompt_dataloader() + + def _init_prompt_dataloader(self) -> None: + """Create a distributed-aware prompt DataLoader using RepeatSampler. + + * ``repeat_count=1`` so each ``produce()`` call draws a fresh batch. + * ``accelerator.prepare`` adds the ``DistributedSampler`` wrapper. + * The dataloader is immediately removed from ``accelerator._dataloaders`` + to prevent checkpoint / memory-lifecycle interference. 
+ """ + trainer = self._trainer + sampler = RepeatSampler( + data_source=self._dataset, + mini_repeat_count=self._num_generations, + batch_size=self._generation_batch_size // self._num_generations, + repeat_count=1, + shuffle=self._shuffle_dataset, + seed=self._seed, + ) + dl = DataLoader( + self._dataset, + batch_size=self._train_batch_size * self._steps_per_generation, + sampler=sampler, + collate_fn=identity, + num_workers=trainer.args.dataloader_num_workers, + pin_memory=trainer.args.dataloader_pin_memory, + persistent_workers=trainer.args.dataloader_persistent_workers, + worker_init_fn=partial( + seed_worker, + num_workers=trainer.args.dataloader_num_workers, + rank=trainer.args.process_index, + ), + ) + self._prompt_dl = trainer.accelerator.prepare(dl) + + # Don't let the accelerator track this dataloader (it's not the + # training dataloader and shouldn't be saved/restored with checkpoints). + acc_dls = trainer.accelerator._dataloaders + if self._prompt_dl in acc_dls: + acc_dls.remove(self._prompt_dl) + + self._prompt_iter = iter(self._prompt_dl) + + def _pre_produce_hook(self, inputs: list, global_step: int) -> list: + """Called before generation to allow prompt modification. + + Override in subclasses to inject new candidates, curriculum + prompts, or other prompt-level transformations. + + Args: + inputs: List of prompt dicts drawn from the dataloader. + global_step: Current training step. + + Returns: + (Possibly modified) list of prompt dicts. + """ + return inputs + + # -- produce ------------------------------------------------------------- + + def produce( + self, + model: Any, + global_step: int, + *, + skip_policy_logps: bool = False, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + **kwargs, + ) -> RolloutDataset: + """Generate a fresh GRPO training rollout. + + 1. Draw the next prompt batch from the internal prompt DataLoader. + 2. Delegate to ``trainer._generate_and_score_completions``. + 3. 
Shuffle the output to break prompt-group ordering. + 4. Wrap in a :class:`RolloutDataset`. + + Args: + model: Ignored (the trainer already holds a model reference). + global_step: Current training step. + skip_policy_logps: When ``True``, the generation pipeline skips + model forward passes (``old_per_token_logps``, IS ratio, + ``ref_per_token_logps``) and sets a ``_pending_policy_logps`` + sentinel. Used by ``AsyncDataProducer`` for background calls. + """ + # get the next prompt batch from iterator (start over on epoch exhaustion). + try: + inputs = next(self._prompt_iter) + except StopIteration: + self._prompt_iter = iter(self._prompt_dl) + inputs = next(self._prompt_iter) + + # Hook for subclasses to modify prompts before generation. + inputs = self._pre_produce_hook(inputs, global_step) + + # Generate completions, compute rewards & advantages. + output = self._trainer._generate_and_score_completions( + inputs, skip_policy_logps=skip_policy_logps + ) + + # Strip non-sequence metadata before shuffling. shuffle_sequence_dict + # expects every value to be a Tensor, list, or None — plain scalars + # (like the ``_pending_policy_logps: True`` sentinel or ``num_items_in_batch``) + # would cause a "not subscriptable" TypeError. + metadata = {} + for key in list(output.keys()): + val = output[key] + if not isinstance(val, (torch.Tensor, list)): + metadata[key] = output.pop(key) + elif isinstance(val, torch.Tensor) and val.dim() == 0: + metadata[key] = output.pop(key) + + # Shuffle to break prompt-group ordering before batching. + # When skip_policy_logps=True (async path), we defer the shuffle to the + # main thread — _compute_deferred_scores needs grouped (unshuffled) + # ordering to normalise advantages per prompt group. + if not skip_policy_logps: + output = shuffle_sequence_dict(output) + + # When running on a background thread (skip_policy_logps=True -> async), + # tensor creation (padding etc.) was done on this thread's CUDA stream. 
+ # Synchronize so all data is materialised before crossing the thread + # boundary. + if skip_policy_logps and torch.cuda.is_available(): + torch.cuda.synchronize() + + # Re-attach metadata that was stripped before the shuffle. + output.update(metadata) + + return RolloutDataset(output) \ No newline at end of file diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 5e44cee18d2..4d89332197d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -71,6 +71,7 @@ from .base_trainer import _BaseTrainer from .callbacks import SyncRefModelCallback from .grpo_config import GRPOConfig +from .grpo_data_producer import GRPODataProducer, RolloutDataset, make_rollout_collator from .utils import ( RepeatSampler, create_model_from_path, @@ -126,6 +127,90 @@ def reset(self, **kwargs) -> str | None: ... EnvironmentFactory = Callable[[], _SupportsReset] +class _StreamingDataLoader: + """A lazy dataloader that scores prompt groups incrementally. + Instead of scoring the entire batch upfront, this dataloader scores + ``min_groups`` prompt groups at a time, yielding micro-batches as they + are scored. This allows reward subprocess computation to overlap with + subsequent groups' policy logprob computation. + Used when ``streaming_partial_batch=True`` in GRPOConfig. 
+ """ + + def __init__(self, dataset, trainer, batch_size, num_generations, min_groups): + self._dataset = dataset + self._trainer = trainer + self._batch_size = batch_size + self._num_gen = num_generations + self._min_groups = max(1, min_groups) + n_samples = len(dataset) + self._n_groups = n_samples // num_generations + self._n_micro_batches = n_samples // batch_size + + def __len__(self): + return self._n_micro_batches + + def __iter__(self): + data = self._dataset._data + batch_size = self._batch_size + num_gen = self._num_gen + + # Extract deferred scoring data (same as _compute_deferred_scores) + inputs = data.pop("_deferred_inputs") + prompts = data.pop("_deferred_prompts") + completions = data.pop("_deferred_completions") + completion_ids_list = data.pop("_deferred_completion_ids_list") + for key in ("_deferred_inputs", "_deferred_prompts", + "_deferred_completions", "_deferred_completion_ids_list"): + self._dataset._shared_keys.discard(key) + self._dataset._sample_keys.discard(key) + + # Clean up sentinels + del data["_pending_policy_logps"] + del data["_streaming_pending"] + self._dataset._shared_keys.discard("_pending_policy_logps") + self._dataset._shared_keys.discard("_streaming_pending") + + # Shared keys (non-sample data like num_items_in_batch) + shared_keys = set(self._dataset._shared_keys) + + # Process groups in chunks of min_groups + for chunk_start_g in range(0, self._n_groups, self._min_groups): + chunk_end_g = min(chunk_start_g + self._min_groups, self._n_groups) + s_start = chunk_start_g * num_gen + s_end = chunk_end_g * num_gen + + # Score this chunk of groups + self._trainer._compute_streaming_group_scores( + data=data, + s_start=s_start, + s_end=s_end, + inputs=inputs[s_start:s_end], + prompts=prompts[s_start:s_end], + completions=completions[s_start:s_end], + completion_ids_list=completion_ids_list[s_start:s_end], + is_last_chunk=(chunk_end_g == self._n_groups), + ) + + # Shuffle within the scored chunk, then yield micro-batches + 
chunk_size = s_end - s_start + perm = torch.randperm(chunk_size) + for mb_offset in range(0, chunk_size, batch_size): + mb_indices = perm[mb_offset:mb_offset + batch_size] + abs_indices = mb_indices + s_start + micro_batch = {} + for key in data: + if key.startswith("_"): + continue + val = data[key] + if key in shared_keys: + micro_batch[key] = val + elif isinstance(val, torch.Tensor) and val.dim() > 0: + micro_batch[key] = val[abs_indices] + else: + micro_batch[key] = val + yield micro_batch + + class GRPOTrainer(_BaseTrainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -593,15 +678,20 @@ def __init__( args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + data_producer = None + if args.use_data_producer: + data_producer = self._create_data_producer(args, train_dataset) + super().__init__( model=model, args=args, data_collator=identity, # No data collation is needed in GRPO - train_dataset=train_dataset, + train_dataset=train_dataset if data_producer is None else None, eval_dataset=eval_dataset, processing_class=processing_class, callbacks=callbacks, optimizers=optimizers, + data_producer=data_producer, # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func` # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the # global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The @@ -610,6 +700,15 @@ def __init__( compute_loss_func="non-None value to disable scaling", ) + # Inject trainer reference into the data producer (needs accelerator, which is + # available only after super().__init__). + if self.data_producer is not None: + producer = self.data_producer + # Unwrap AsyncDataProducer to get the inner GRPODataProducer. 
+ if hasattr(producer, "_inner"): + producer = producer._inner + producer.set_trainer(self) + # Reference model self.beta = args.beta if self.beta == 0.0: @@ -678,6 +777,7 @@ def cast_outputs_to_original_dtype(module, args, output): use_ref_model=self.beta != 0.0, loss_type=self.loss_type, max_completion_length=self.max_completion_length, + importance_sampling_level=self.importance_sampling_level, ) # Initialize the metrics @@ -930,6 +1030,524 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler: seed=self.args.seed, ) + def _create_data_producer(self, args, train_dataset): + """Create and return the DataProducer (possibly wrapped in AsyncDataProducer). + Override in subclasses to use a custom data producer + """ + from transformers.data_producer import ProducerConfig + + producer_config = ProducerConfig( + mini_epochs=args.num_iterations, + max_rollouts=None, # bounded by max_steps + eval_during_produce=False, # GRPO manages its own eval/train mode + empty_cache_before_produce=True, + empty_cache_after_produce=True, + async_prefetch=args.async_prefetch, + prefetch_depth=args.prefetch_depth, + ) + data_producer = GRPODataProducer( + config=producer_config, + prompt_dataset=train_dataset, + num_generations=self.num_generations, + generation_batch_size=args.generation_batch_size, + train_batch_size=args.per_device_train_batch_size, + steps_per_generation=args.steps_per_generation, + shuffle_dataset=self.shuffle_dataset, + seed=args.seed, + ) + if args.async_prefetch: + from transformers.data_producer import AsyncDataProducer + + data_producer = AsyncDataProducer( + data_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + return data_producer + + def _compute_rewards_for_batch( + self, + inputs: list, + prompts: list, + completions: list, + completion_ids_list: list, + ): + """Compute rewards for a batch of samples. + Returns a ``(batch_size, num_reward_funcs)`` tensor. 
+ Override in subclasses to customize hoe workers get spawned to calculate rewards. + """ + return self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + def _post_advantage_hook( + self, + data: dict, + rewards_per_func, + advantages, + inputs: list, + num_generations: int, + mode: str, + s_start: int | None = None, + s_end: int | None = None, + is_last_chunk: bool = True, + ) -> None: + """Called after advantages are computed, before reward metrics. + Override in subclasses to implement replay buffers, etc. + Args: + data: The mutable dataset dict. + rewards_per_func: ``(batch, num_reward_funcs)`` reward tensor. + advantages: ``(local_batch,)`` computed advantages. + inputs: Original prompt input dicts. + num_generations: Completions per unique prompt. + mode: ``"train"`` or ``"eval"``. + s_start, s_end: Sample index range (streaming path only). + is_last_chunk: Whether this is the final chunk (streaming only). + """ + + def _produce_data(self, model): + """Override to handle vLLM weight sync and policy logprob computation.""" + if self.use_vllm and self.args.async_prefetch: + sync_interval = self.args.vllm_sync_interval + if (self.state.global_step != self._last_loaded_step and + self.state.global_step % sync_interval == 0): + # Wait for in-flight background futures so we don't mutate the + # vLLM engine while an in-flight generation is running. 
+ producer = self.data_producer + if hasattr(producer, "_queue"): + for future in list(producer._queue): + future.result() + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + dataset = super()._produce_data(model) + + if self.args.async_prefetch and torch.cuda.is_available(): + torch.cuda.synchronize() + + if isinstance(dataset, RolloutDataset) and dataset._data.get("_pending_policy_logps"): + if self.args.streaming_partial_batch: + # Streaming mode: defer scoring to the dataloader, which will score groups + # incrementally as micro-batches are consumed. + dataset._data["_streaming_pending"] = True + else: + self._compute_deferred_scores(dataset) + + return dataset + + def _get_online_dataloader(self, dataset): + """Create a DataLoader for a ``RolloutDataset`` produced by the data producer. + Unlike the base implementation, we do **not** call ``accelerator.prepare`` + because the data is already per-process and on-device (tensors were created + in ``_generate_and_score_completions`` on the correct device). + """ + if not isinstance(dataset, RolloutDataset): + return super()._get_online_dataloader(dataset) + + if self.args.streaming_partial_batch and dataset._data.get("_streaming_pending"): + return _StreamingDataLoader( + dataset=dataset, + trainer=self, + batch_size=self.args.per_device_train_batch_size, + num_generations=self.num_generations, + min_groups=self.args.streaming_min_groups, + ) + + return DataLoader( + dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=make_rollout_collator(dataset._shared_keys), + shuffle=False, # Already shuffled in produce() + ) + + @torch.no_grad() + def _compute_deferred_scores(self, dataset: RolloutDataset): + """Compute rewards, advantages, and policy logprobs on the main thread. 
+ Called after ``_produce_data`` returns a dataset whose + ``_pending_policy_logps`` sentinel indicates that the rollout was + produced on a background thread. The background thread only performs + generation + tensor padding; this method does all the scoring work + that was deferred. + """ + data = dataset._data + device = self.accelerator.device + batch_size = self.args.per_device_train_batch_size + mode = "train" + + # ---- Recover deferred inputs for reward computation ---- + inputs = data.pop("_deferred_inputs") + prompts = data.pop("_deferred_prompts") + completions = data.pop("_deferred_completions") + completion_ids_list = data.pop("_deferred_completion_ids_list") + for key in ("_deferred_inputs", "_deferred_prompts", + "_deferred_completions", "_deferred_completion_ids_list"): + dataset._shared_keys.discard(key) + dataset._sample_keys.discard(key) + + # ---- Compute policy logprobs (GPU-bound, overlaps with BG rewards) ---- + prompt_completion_ids = torch.cat([data["prompt_ids"], data["completion_ids"]], dim=1) + attention_mask = torch.cat([data["prompt_mask"], data["completion_mask"]], dim=1) + logits_to_keep = data["completion_ids"].size(1) + + forward_kwargs = {} + for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", + "image_sizes", "token_type_ids", "mm_token_type_ids"): + if key in data: + forward_kwargs[key] = data[key] + num_images = data.get("num_images") + + # Use larger batch for inference-only logprob computation (no grad = less memory) + logprob_batch_size = min(batch_size * 4, len(prompt_completion_ids)) + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + generate_every = self.args.steps_per_generation * self.num_iterations + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + data["old_per_token_logps"], _ = self._get_per_token_logps_and_entropies( + self.model, prompt_completion_ids, 
attention_mask, + logits_to_keep, logprob_batch_size, num_images=num_images, **forward_kwargs, + ) + + if self.beta != 0.0: + if self.ref_model is not None: + data["ref_per_token_logps"], _ = self._get_per_token_logps_and_entropies( + self.ref_model, prompt_completion_ids, attention_mask, + logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, + ) + else: + unwrapped = self.accelerator.unwrap_model(self.model) + adapter_name = "ref" if "ref" in unwrapped.peft_config else None + with use_adapter(unwrapped, adapter_name=adapter_name): + data["ref_per_token_logps"], _ = self._get_per_token_logps_and_entropies( + self.model, prompt_completion_ids, attention_mask, + logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, + ) + + # ---- Compute rewards and advantages ---- + rewards_per_func = self._compute_rewards_for_batch( + inputs, prompts, completions, completion_ids_list + ) + + num_generations = self.num_generations + + if self.multi_objective_aggregation == "sum_then_normalize": + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0) + if self.scale_rewards in ["group", "none"]: + if num_generations > 1: + std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(num_generations, dim=0) + else: + std_rewards = torch.zeros_like(rewards) + elif self.scale_rewards == "batch": + if rewards.numel() > 1: + std_rewards = rewards.std().expand_as(rewards) + else: + std_rewards = torch.zeros_like(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. " + "Must be one of 'batch', 'group', or 'none'." 
+ ) + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = torch.nanmean(grouped, dim=1, keepdim=True) + std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) + reward_k = (grouped - mean_k) / (std_k + 1e-4) + reward_k = reward_k.view(-1, len(self.reward_funcs)) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + else: + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. " + "Must be 'sum_then_normalize' or 'normalize_then_sum'." 
+ ) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Replace placeholder advantages in the dataset + data["advantages"] = advantages + + # Hook for subclasses + self._post_advantage_hook( + data=data, + rewards_per_func=rewards_per_func, + advantages=advantages, + inputs=inputs, + num_generations=num_generations, + mode=mode, + ) + + # ---- Accumulate reward metrics (main thread — safe) ---- + for i, reward_func_name in enumerate(self.reward_func_names): + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{reward_func_name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) + agg_rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(agg_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # ---- Clean up sentinels & register new keys ---- + del data["_pending_policy_logps"] + dataset._shared_keys.discard("_pending_policy_logps") + + for key in ("old_per_token_logps", "ref_per_token_logps"): + dataset._shared_keys.discard(key) + if key in data: + dataset._sample_keys.add(key) + + # ---- Shuffle (deferred from produce()) ---- + shuffled = shuffle_sequence_dict(data) + dataset._data = shuffled + + @torch.no_grad() + def _compute_streaming_group_scores( + self, data, s_start, s_end, inputs, prompts, completions, completion_ids_list, is_last_chunk + ): + """Score a chunk of prompt groups: rewards, policy logprobs, advantages. + Called by ``_StreamingDataLoader`` to incrementally score groups. Writes + results directly into the ``data`` dict at positions ``s_start:s_end``. + Args: + data: The dataset._data dict (mutable). 
+ s_start, s_end: Sample index range for this chunk. + inputs: Deferred inputs for this chunk's samples. + prompts, completions, completion_ids_list: Deferred reward inputs. + is_last_chunk: Whether this is the last chunk (for metrics/buffer logging). + """ + device = self.accelerator.device + batch_size = self.args.per_device_train_batch_size + num_generations = self.num_generations + mode = "train" + chunk_size = s_end - s_start + + # ---- Compute policy logprobs for this chunk ---- + chunk_prompt_ids = data["prompt_ids"][s_start:s_end] + chunk_completion_ids = data["completion_ids"][s_start:s_end] + chunk_prompt_mask = data["prompt_mask"][s_start:s_end] + chunk_completion_mask = data["completion_mask"][s_start:s_end] + prompt_completion_ids = torch.cat([chunk_prompt_ids, chunk_completion_ids], dim=1) + attention_mask = torch.cat([chunk_prompt_mask, chunk_completion_mask], dim=1) + logits_to_keep = chunk_completion_ids.size(1) + + forward_kwargs = {} + for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", + "image_sizes", "token_type_ids", "mm_token_type_ids"): + if key in data: + forward_kwargs[key] = data[key] + + logprob_batch_size = min(batch_size * 4, chunk_size) + with disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + generate_every = self.args.steps_per_generation * self.num_iterations + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_logps, _ = self._get_per_token_logps_and_entropies( + self.model, prompt_completion_ids, attention_mask, + logits_to_keep, logprob_batch_size, **forward_kwargs, + ) + if "old_per_token_logps" not in data: + # Initialize the full-batch tensor on first chunk + total_samples = len(data["prompt_ids"]) + data["old_per_token_logps"] = torch.zeros( + total_samples, old_logps.size(1), device=device, dtype=old_logps.dtype + ) + data["old_per_token_logps"][s_start:s_end] = old_logps + + if 
self.beta != 0.0: + if self.ref_model is not None: + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, prompt_completion_ids, attention_mask, + logits_to_keep, batch_size, **forward_kwargs, + ) + else: + unwrapped = self.accelerator.unwrap_model(self.model) + adapter_name = "ref" if "ref" in unwrapped.peft_config else None + with use_adapter(unwrapped, adapter_name=adapter_name): + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.model, prompt_completion_ids, attention_mask, + logits_to_keep, batch_size, **forward_kwargs, + ) + if "ref_per_token_logps" not in data: + total_samples = len(data["prompt_ids"]) + data["ref_per_token_logps"] = torch.zeros( + total_samples, ref_logps.size(1), device=device, dtype=ref_logps.dtype + ) + data["ref_per_token_logps"][s_start:s_end] = ref_logps + + # ---- Compute rewards ---- + rewards_per_func = self._compute_rewards_for_batch( + inputs, prompts, completions, completion_ids_list + ) + + # ---- Compute advantages (group-level normalization) ---- + # Streaming requires group-level normalization (scale_rewards="group" or "none") + if self.multi_objective_aggregation == "sum_then_normalize": + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0) + if num_generations > 1: + std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(num_generations, dim=0) + else: + std_rewards = torch.zeros_like(rewards) + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = 
torch.nanmean(grouped, dim=1, keepdim=True) + std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) + reward_k = (grouped - mean_k) / (std_k + 1e-4) + reward_k = reward_k.view(-1, len(self.reward_funcs)) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + std_rewards = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0) + mean_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0) + advantages = (rewards - mean_rewards) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + else: + raise ValueError(f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}") + + # Slice for local process + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Write advantages into dataset + if "advantages" not in data or not isinstance(data["advantages"], torch.Tensor): + total_samples = len(data["prompt_ids"]) + data["advantages"] = torch.zeros(total_samples, device=device) + data["advantages"][s_start:s_end] = advantages + + # Hook for subclasses + self._post_advantage_hook( + data=data, + rewards_per_func=rewards_per_func, + advantages=advantages, + inputs=inputs, + num_generations=num_generations, + mode=mode, + s_start=s_start, + s_end=s_end, + is_last_chunk=is_last_chunk, + ) + + # ---- Reward metrics ---- + for i, reward_func_name in enumerate(self.reward_func_names): + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{reward_func_name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) + agg_rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(agg_rewards.mean().item()) + 
self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # ---- Completion length & IS metrics (parity with sync path) ---- + if is_last_chunk: + # Token count tracking + all_prompt_mask = data["prompt_mask"] + all_completion_mask = data["completion_mask"] + all_completion_ids = data["completion_ids"] + total_prompt_tokens = self.accelerator.gather(all_prompt_mask.sum()) + total_completion_tokens = self.accelerator.gather(all_completion_mask.sum()) + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Completion length metrics (use full batch, not just this chunk) + completion_lengths = all_completion_mask.sum(dim=1) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Clipped ratio and terminated lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor( + [ids[-1].item() not in eos_and_pad for ids in all_completion_ids], device=device + ) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append( + term_completion_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term_completion_lengths.float().min().item() + ) + 
self._metrics[mode]["completions/max_terminated_length"].append( + term_completion_lengths.float().max().item() + ) + + # Importance sampling metrics + if ( + self.use_vllm + and self.vllm_importance_sampling_correction + and "sampling_per_token_logps" in data + and "old_per_token_logps" in data + ): + old_logps = data["old_per_token_logps"] + sampling_logps = data["sampling_per_token_logps"] + mask = all_completion_mask.bool() + delta = torch.abs(old_logps - sampling_logps) + delta_masked = delta[mask] + mean_delta = torch.mean(delta_masked) if delta_masked.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta_masked) if delta_masked.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + + # Compute IS ratio for metrics + is_mask = all_completion_mask if "tool_mask" not in data else all_completion_mask * data["tool_mask"] + per_token_logps_diff = (old_logps - sampling_logps) * is_mask + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + else: + logps_diff = per_token_logps_diff + is_ratio = torch.exp(logps_diff) + + if sequence_level_is: + flat_is_ratio = is_ratio.flatten() + else: + flat_is_ratio = is_ratio[mask] + + if flat_is_ratio.numel() > 0: + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(torch.min(flat_is_ratio))).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(torch.mean(flat_is_ratio)).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + 
nanmax(self.accelerator.gather(torch.max(flat_is_ratio))).item() + ) + @profiling_decorator def _get_last_hidden_state( self, @@ -1104,18 +1722,32 @@ def training_step(self, model, inputs, num_items_in_batch): def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: # Prepares inputs for model training/evaluation by managing completion generation and batch handling. # During training: - # - Receives the local generation batch (Per-GPU batch size × steps per generation) - # from the modified training dataloader instead of the standard local batch - # - Generates completions once for the entire generation batch and splits it into batches of size - # `per_device_train_batch_size` - # - Buffers these completions and returns the appropriate slice for the current accumulation step - # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # - DataProducer path: + # - data is already generated and batched by the RolloutDataset / DataLoader. + # - Just validate and pass through. + # - Legacy path: + # - Receives the local generation batch (Per-GPU batch size × steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) # During evaluation: # - The input is treated as a standard local batch (no accumulation, no multiple iterations) # - Completions are generated for each batch without buffering or reuse # Returns a single local batch in both cases. mode = "train" if self.model.training else "eval" + + # DataProducer path: the RolloutDataset DataLoader yields ready-to-train batches. 
+ if mode == "train" and self.data_producer is not None: + assert "_pending_policy_logps" not in generation_batch, ( + "Batch still has pending policy logps — _compute_deferred_scores should have " + "been called on the full rollout before training started." + ) + return generation_batch + + # Legacy path: if mode == "train": generate_every = self.args.steps_per_generation * self.num_iterations if self._step % generate_every == 0 or self._buffered_inputs is None: @@ -1216,14 +1848,15 @@ async def _run_async_funcs(): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list): + def _generate_single_turn(self, prompts: list, mode: str = "train"): device = self.accelerator.device - mode = "train" if self.model.training else "eval" # Generate completions using either vLLM or regular generation if self.use_vllm: # Sync weights if training step changed - if self.state.global_step != self._last_loaded_step: + # When async_prefetch is enabled, weight sync is handled by _produce_data + # on the main thread — skip here to avoid syncing from a background thread. 
+ if not self.args.async_prefetch and self.state.global_step != self._last_loaded_step: with profiling_context(self, "sync_weights"): self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step @@ -1329,7 +1962,7 @@ def _generate_single_turn(self, prompts: list): return prompt_ids, completion_ids, logprobs, extra_fields - def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs): + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, mode="train"): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] @@ -1433,7 +2066,7 @@ async def _run_async_tools(async_coros): # Generate new completions after tool execution prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn( - prompt_completion_tools + prompt_completion_tools, mode=mode, ) # Sanity check: from experience, this is useful to catch bugs in the chat template @@ -1501,12 +2134,19 @@ async def _run_async_tools(async_coros): def _generate(self, prompts: list): device = self.accelerator.device - mode = "train" if self.model.training else "eval" + # When called from a background thread (async DataProducer), reading + # self.model.training can cause a race condition + import threading + _on_bg_thread = threading.current_thread() is not threading.main_thread() + if _on_bg_thread: + mode = "train" + else: + mode = "train" if self.model.training else "eval" # Copy the prompts to avoid modifying the original list prompts = copy.deepcopy(prompts) - prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts, mode=mode) # Decode completions. 
It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): @@ -1532,7 +2172,7 @@ def _generate(self, prompts: list): logprobs, tool_call_count, tool_failure_count, - ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) + ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs, mode=mode) else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) @@ -1550,28 +2190,31 @@ def _generate(self, prompts: list): total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - 
self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - if self.tools: + # Skip on background threads to avoid race updates to + # self.state / self._metrics while the main thread trains. + if not _on_bg_thread: + if mode == "train": + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + if self.tools and not 
_on_bg_thread:
+            agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum()
+            tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item()
+            self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency)
@@ -1592,9 +2235,12 @@ def _generate(self, prompts: list):
         )
 
     def _generate_and_score_completions(
-        self, inputs: list[dict[str, torch.Tensor | Any]]
+        self, inputs: list[dict[str, torch.Tensor | Any]], skip_policy_logps: bool = False,
     ) -> dict[str, torch.Tensor | Any]:
         device = self.accelerator.device
+
+        # When skip_policy_logps=True, we're being called by the data producer on a
+        # background thread. Avoid reading self.model.training due to race conditions.
         mode = "train" if self.model.training else "eval"
 
         prompts = [x["prompt"] for x in inputs]
@@ -1711,7 +2357,7 @@ def _generate_and_score_completions(
         # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a
         # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
         # Temporarily disable checkpointing to avoid this warning during inference.
-        with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
+        with (torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs)):
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
             # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
             # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
             # distribution mismatch between vLLM and the training model can be large and harm the training.
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency - if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm and self.vllm_importance_sampling_correction - ): - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - old_per_token_logps = None - - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch - if self.use_vllm and self.vllm_importance_sampling_correction: - mask = completion_mask if tool_mask is None else completion_mask * tool_mask - per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask - - sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] - if sequence_level_is: - per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) - logps_diff = per_sequence_logps_diff - else: - logps_diff = per_token_logps_diff - - vllm_importance_sampling_ratio = torch.exp(logps_diff) - - # vllm_importance_sampling_ratio.shape: - # token_* modes: (B, T) (per-token ratio) - # sequence_* modes: (B, 1) (per-sequence ratio) - - if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: - vllm_importance_sampling_ratio = torch.clamp( - vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap - ) - elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: - vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( - vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 - ) - else: - raise ValueError( - f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. 
Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." - ) - - # Compute the per-token log probabilities for the reference model - if self.beta != 0.0: - if self.ref_model is not None: - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.ref_model, + if not skip_policy_logps: + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, prompt_completion_ids, attention_mask, logits_to_keep, - batch_size=batch_size, + batch_size, num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: - # When training a PEFT adapter, how we obtain the reference depends on the setup: - # - New adapter: disabling adapters yields the base model. - # - Re-training an existing adapter: an initial copy is loaded under the name "ref". 
- model = self.accelerator.unwrap_model(self.model) - with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + mask = completion_mask if tool_mask is None else completion_mask * tool_mask + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask + + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff + + vllm_importance_sampling_ratio = torch.exp(logps_diff) + + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 + ) + else: + raise ValueError( + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." 
+ ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, + self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep, @@ -1793,7 +2425,25 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) + else: + # When training a PEFT adapter, how we obtain the reference depends on the setup: + # - New adapter: disabling adapters yields the base model. + # - Re-training an existing adapter: an initial copy is loaded under the name "ref". + model = self.accelerator.unwrap_model(self.model) + with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None else: + old_per_token_logps = None ref_per_token_logps = None # Decode @@ -1809,6 +2459,38 @@ def _generate_and_score_completions( elif not isinstance(values, list): inp[key] = values + # When skip_policy_logps=True the code runs on a background thread. + # Rather than trying to make reward computation / advantage normalisation + # / distributed collectives thread-safe, we return the generation output + # immediately and defer *all* scoring to the main thread. 
+ if skip_policy_logps: + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "num_items_in_batch": num_items_in_batch, + # placeholders + "advantages": torch.zeros(completion_ids.size(0), device=device), + "_pending_policy_logps": True, + # raw data for deferred reward / advantage computation + "_deferred_inputs": inputs, + "_deferred_prompts": prompts, + "_deferred_completions": completions, + "_deferred_completion_ids_list": completion_ids_list, + } + if sampling_per_token_logps is not None: + output["sampling_per_token_logps"] = sampling_per_token_logps + if tool_mask is not None: + output["tool_mask"] = tool_mask + if images is not None: + output["num_images"] = num_images + for k in ("pixel_values", "image_grid_thw", "pixel_attention_mask", + "image_sizes", "token_type_ids", "mm_token_type_ids"): + if k in forward_kwargs: + output[k] = forward_kwargs[k] + return output + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. 
@@ -1868,62 +2550,64 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) - for i, reward_func_name in enumerate(self.reward_func_names): - mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) - std_func_rewards = nanstd(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) - rewards = rewards_per_func.nansum(dim=1) - self._metrics[mode]["reward"].append(rewards.mean().item()) - self._metrics[mode]["reward_std"].append(rewards.std().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) - - # Log prompt and completion texts - self._logs["prompt"].extend(gather_object(prompts_text)) - self._logs["completion"].extend(gather_object(completions_text)) - for i, name in enumerate(self.reward_func_names): - self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - self._logs["advantages"].extend(all_process_advantages.tolist()) + # don't update trainer state from background worker + if not skip_policy_logps: + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(rewards.mean().item()) + self._metrics[mode]["reward_std"].append(rewards.std().item()) + 
self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + if images is not None: + self._logs["images"].extend(gather_object(images)) - if images is not None: - self._logs["images"].extend(gather_object(images)) - - if self.use_vllm and self.vllm_importance_sampling_correction: - delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() - delta = delta[mask] - mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( - self.accelerator.gather(mean_delta).mean().item() - ) - self._metrics[mode]["sampling/sampling_logp_difference/max"].append( - self.accelerator.gather(max_delta).max().item() - ) - if sequence_level_is: - flat_is_ratio = vllm_importance_sampling_ratio.flatten() - else: - flat_is_ratio = vllm_importance_sampling_ratio[mask] + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() + delta = delta[mask] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + 
self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[mask] - min_importance_sampling_ratio = ( - torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - mean_importance_sampling_ratio = ( - torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - max_importance_sampling_ratio = ( - torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() - ) + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) output = { "prompt_ids": 
prompt_ids, @@ -1935,7 +2619,7 @@ def _generate_and_score_completions( } if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps - if self.use_vllm and self.vllm_importance_sampling_correction: + if not skip_policy_logps and self.use_vllm and self.vllm_importance_sampling_correction: output["importance_sampling_ratio"] = vllm_importance_sampling_ratio if sampling_per_token_logps is not None: output["sampling_per_token_logps"] = sampling_per_token_logps @@ -2149,7 +2833,7 @@ def _compute_loss(self, model, inputs): if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask - if self.use_vllm and self.vllm_importance_sampling_correction: + if self.use_vllm and self.vllm_importance_sampling_correction and "importance_sampling_ratio" in inputs: per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: From 2b4276cdc5121922371e80b453ccbfd35ee0e658 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 9 Mar 2026 17:08:23 -0400 Subject: [PATCH 2/3] address PR feedback --- trl/trainer/grpo_trainer.py | 53 ++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 4d89332197d..7d89d2f351f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -144,7 +144,13 @@ def __init__(self, dataset, trainer, batch_size, num_generations, min_groups): self._min_groups = max(1, min_groups) n_samples = len(dataset) self._n_groups = n_samples // num_generations - self._n_micro_batches = n_samples // batch_size + # Compute exact micro-batch count: each chunk of min_groups yields ceil(chunk_size / batch_size) micro-batches + n_micro = 0 + for chunk_start_g in range(0, self._n_groups, self._min_groups): + chunk_end_g = min(chunk_start_g + self._min_groups, self._n_groups) + chunk_size = (chunk_end_g - chunk_start_g) * num_generations + n_micro += -(-chunk_size // batch_size) # ceil div + 
self._n_micro_batches = n_micro def __len__(self): return self._n_micro_batches @@ -1348,7 +1354,14 @@ def _compute_streaming_group_scores( for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes", "token_type_ids", "mm_token_type_ids"): if key in data: - forward_kwargs[key] = data[key] + val = data[key] + if isinstance(val, torch.Tensor) and val.dim() > 0 and val.size(0) == len(data["prompt_ids"]): + forward_kwargs[key] = val[s_start:s_end] + else: + forward_kwargs[key] = val + num_images = data.get("num_images") + if num_images is not None and hasattr(num_images, '__getitem__') and len(num_images) == len(data["prompt_ids"]): + num_images = num_images[s_start:s_end] logprob_batch_size = min(batch_size * 4, chunk_size) with disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): @@ -1358,7 +1371,7 @@ def _compute_streaming_group_scores( ): old_logps, _ = self._get_per_token_logps_and_entropies( self.model, prompt_completion_ids, attention_mask, - logits_to_keep, logprob_batch_size, **forward_kwargs, + logits_to_keep, logprob_batch_size, num_images=num_images, **forward_kwargs, ) if "old_per_token_logps" not in data: # Initialize the full-batch tensor on first chunk @@ -1368,11 +1381,36 @@ def _compute_streaming_group_scores( ) data["old_per_token_logps"][s_start:s_end] = old_logps + # Compute importance sampling ratio for this chunk and store it + if "sampling_per_token_logps" in data: + sampling_logps_chunk = data["sampling_per_token_logps"][s_start:s_end] + is_mask = chunk_completion_mask if "tool_mask" not in data else ( + chunk_completion_mask * data["tool_mask"][s_start:s_end] + ) + per_token_logps_diff = (old_logps - sampling_logps_chunk) * is_mask + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + else: + logps_diff = per_token_logps_diff + is_ratio = 
torch.exp(logps_diff) + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + is_ratio = torch.clamp(is_ratio, max=self.vllm_importance_sampling_cap) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + is_ratio = is_ratio.masked_fill(is_ratio > self.vllm_importance_sampling_cap, value=0.0) + if "importance_sampling_ratio" not in data: + total_samples = len(data["prompt_ids"]) + is_shape = (total_samples, 1) if sequence_level_is else (total_samples, is_ratio.size(1)) + data["importance_sampling_ratio"] = torch.ones( + *is_shape, device=device, dtype=is_ratio.dtype + ) + data["importance_sampling_ratio"][s_start:s_end] = is_ratio + if self.beta != 0.0: if self.ref_model is not None: ref_logps, _ = self._get_per_token_logps_and_entropies( self.ref_model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, **forward_kwargs, + logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, ) else: unwrapped = self.accelerator.unwrap_model(self.model) @@ -1380,7 +1418,7 @@ def _compute_streaming_group_scores( with use_adapter(unwrapped, adapter_name=adapter_name): ref_logps, _ = self._get_per_token_logps_and_entropies( self.model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, **forward_kwargs, + logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, ) if "ref_per_token_logps" not in data: total_samples = len(data["prompt_ids"]) @@ -2241,7 +2279,10 @@ def _generate_and_score_completions( # When skip_policy_logps=True, then we're being called by the data producer on a # background thread. 
avoid reading self.model.training due to race conditions - mode = "train" if self.model.training else "eval" + if skip_policy_logps: + mode = "train" + else: + mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] From 91388df763b0f43c707b583477277a2e1cc01d85 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 9 Mar 2026 17:10:04 -0400 Subject: [PATCH 3/3] remove the loop and use maths --- trl/trainer/grpo_trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7d89d2f351f..3ed3d4ef755 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -145,12 +145,11 @@ def __init__(self, dataset, trainer, batch_size, num_generations, min_groups): n_samples = len(dataset) self._n_groups = n_samples // num_generations # Compute exact micro-batch count: each chunk of min_groups yields ceil(chunk_size / batch_size) micro-batches - n_micro = 0 - for chunk_start_g in range(0, self._n_groups, self._min_groups): - chunk_end_g = min(chunk_start_g + self._min_groups, self._n_groups) - chunk_size = (chunk_end_g - chunk_start_g) * num_generations - n_micro += -(-chunk_size // batch_size) # ceil div - self._n_micro_batches = n_micro + n_full_chunks, remainder_groups = divmod(self._n_groups, self._min_groups) + full_chunk_size = self._min_groups * num_generations + self._n_micro_batches = n_full_chunks * -(-full_chunk_size // batch_size) + if remainder_groups: + self._n_micro_batches += -(-(remainder_groups * num_generations) // batch_size) def __len__(self): return self._n_micro_batches