Reproduction
When all completions in a group receive the same reward, advantages are zero:
advantages = rewards - mean_grouped_rewards # = 0 for all completions
So the policy loss contribution is zero and there is no learning signal, as expected. However, the KL penalty term is not zero:
per_token_loss = per_token_loss + self.beta * per_token_kl # per_token_kl ≠ 0
per_token_kl has gradients through per_token_logps (the current model's log-probs), while ref_per_token_logps is detached (computed under torch.no_grad()). So zero-std groups still generate gradients that push the model toward the reference policy, even though there's no reward signal guiding the direction. KL regularization should only constrain the model when there's actual reward signal to balance against. Without it, KL gradients from zero-std groups act as undirected noise.
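A minimal sketch of why this happens, using standalone toy tensors rather than the trainer's actual code path (the simplified policy loss omits the importance ratio and clipping; the k3-style KL estimator has the same form as the trainer's per_token_kl):

import torch

# Toy zero-std group: every completion got reward 1.0, so advantages are exactly zero.
rewards = torch.tensor([1.0, 1.0, 1.0])
advantages = rewards - rewards.mean()  # all zeros

# Current-policy log-probs require grad; reference log-probs are detached,
# mimicking the torch.no_grad() computation in the trainer.
per_token_logps = torch.tensor([[-1.2, -0.8], [-1.0, -1.1], [-0.9, -1.3]], requires_grad=True)
ref_per_token_logps = per_token_logps.detach() + 0.1  # pretend the reference differs slightly

# k3-style KL estimator: exp(ref - cur) - (ref - cur) - 1
per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps)
    - (ref_per_token_logps - per_token_logps)
    - 1
)

beta = 0.04
policy_loss = -(advantages.unsqueeze(1) * per_token_logps)  # zero everywhere, zero gradient
loss = (policy_loss + beta * per_token_kl).mean()
loss.backward()

print(per_token_logps.grad)  # non-zero: the KL term alone produces gradients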
This affects both GRPOTrainer and RLOOTrainer (RLOO defaults to beta=0.05).
A simple fix is to zero out completion_mask for zero-std groups, which suppresses both the (already-zero) policy loss and the spurious KL gradients:
is_std_zero_local = is_std_zero[process_slice]
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()
This is orthogonal to any future oversampling/dynamic sampling implementation: even if oversampling is added to backfill zero-std groups with new prompts, any group that exhausts its resampling attempts and still has zero std should be masked for the same reason.
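A minimal sketch of the masking idea on toy tensors (shapes and values are illustrative; is_std_zero_local is assumed to already hold one flag per completion on this process, as in the snippet above):

import torch

# Two groups of 3 completions each, flattened along the batch dimension;
# completion_mask has shape (num_completions, seq_len).
completion_mask = torch.ones(6, 8, dtype=torch.int64)

# Per-completion "reward std == 0" flags: the first group is degenerate.
is_std_zero_local = torch.tensor([True, True, True, False, False, False])

# Zero out every token of completions in zero-std groups, suppressing both
# the (already-zero) policy loss and the spurious KL gradients for those rows.
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()

print(completion_mask.sum(dim=1))  # tensor([0, 0, 0, 8, 8, 8])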
Reproducing Code
import torch
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def constant_reward(completions, **kwargs):
    return [1.0] * len(completions)


dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

for beta in [0.0, 0.04]:
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=constant_reward,
        args=GRPOConfig(
            output_dir=f"/tmp/kl_bug_{beta}",
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            beta=beta,
            max_steps=3,
            report_to="none",
        ),
        train_dataset=dataset,
    )
    params_before = {n: p.clone() for n, p in trainer.model.named_parameters()}
    trainer.train()
    changed = sum(
        not torch.equal(p, trainer.model.get_parameter(n))
        for n, p in params_before.items()
    )
    print(f"\nbeta={beta}: {changed}/{len(params_before)} params changed")
    # beta=0.0 -> 0 params changed (correct: no reward signal, no update)
    # beta=0.04 -> ALL params changed (bug: KL gradients update model with no reward signal)
Output
beta=0.0: 0/27 params changed
beta=0.04: 27/27 params changed
System Info
- Platform: Linux-6.8.0-1043-nvidia-x86_64-with-glibc2.39
- Python version: 3.12.13
- TRL version: 1.2.0.dev0+573ea22
- PyTorch version: 2.11.0+cu130
- accelerator(s): NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200
- Transformers version: 5.5.3
- Accelerate version: 1.13.0
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 8
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 4.8.4
- HF Hub version: 1.11.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.18.9
- Liger-Kernel version: 0.7.0
- PEFT version: not installed
- vLLM version: dev1+gba4a78eb5.cu131
Checklist