
[Bug] Zero-std reward groups produce spurious KL gradients when beta > 0 in GRPO/RLOO #5588

@SwayamInSync

Reproduction

When all completions in a group receive the same reward, advantages are zero:

advantages = rewards - mean_grouped_rewards  # = 0 for all completions
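
For concreteness, here is a toy version of that grouping logic (made-up shapes mirroring the line above, not the trainer's exact code):

import torch

# 2 prompts x 3 completions each, all completions scored identically
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
num_generations = 3
mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations)
advantages = rewards - mean_grouped_rewards
print(advantages)  # tensor([0., 0., 0., 0., 0., 0.])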

So the policy loss contribution is zero and there is no learning signal, as expected. However, the KL penalty term is not zero:

per_token_loss = per_token_loss + self.beta * per_token_kl  # per_token_kl ≠ 0

per_token_kl carries gradients through per_token_logps (the current model's log-probs), while ref_per_token_logps is detached (computed under torch.no_grad()). So zero-std groups still generate gradients that push the model toward the reference policy, even though there is no reward signal guiding the direction. KL regularization should only constrain the model when there is an actual reward signal to balance against; without one, the KL gradients from zero-std groups act as undirected noise.
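
A minimal sketch of this gradient path, assuming the k3 KL estimator exp(ref - cur) - (ref - cur) - 1 (toy values, not trainer code):

import torch

# Current policy log-probs (with grad) vs. detached reference log-probs
per_token_logps = torch.tensor([-2.0, -3.0], requires_grad=True)
ref_per_token_logps = torch.tensor([-1.5, -2.5])  # detached, as under torch.no_grad()

# k3 estimator (assumption; see the KL computation in grpo_trainer.py)
per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps)
    - (ref_per_token_logps - per_token_logps)
    - 1
)
loss = (0.04 * per_token_kl).sum()  # beta * KL; the policy term is zero here
loss.backward()
print(per_token_logps.grad)  # nonzero -> the model is pulled toward the reference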

This affects both GRPOTrainer and RLOOTrainer (RLOO defaults to beta=0.05).

A simple fix is to zero out completion_mask for zero-std groups, which suppresses both the (already-zero) policy loss and the spurious KL gradients:

is_std_zero_local = is_std_zero[process_slice]  # zero-std flags for this process's completions
completion_mask = completion_mask * (~is_std_zero_local).unsqueeze(1).int()  # drop every token of zero-std groups
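
A toy end-to-end check of the masking with made-up shapes (is_std_zero is reconstructed here for illustration; the multi-process process_slice is omitted):

import torch

num_generations = 2
rewards = torch.tensor([1.0, 1.0, 0.0, 2.0])  # group 0 constant, group 1 varied
std_grouped = rewards.view(-1, num_generations).std(dim=1)
is_std_zero = torch.isclose(std_grouped, torch.zeros_like(std_grouped))
is_std_zero = is_std_zero.repeat_interleave(num_generations)  # per completion

completion_mask = torch.ones(4, 4, dtype=torch.int)  # 4 completions x 4 tokens
completion_mask = completion_mask * (~is_std_zero).unsqueeze(1).int()
print(completion_mask[:, 0])  # first column: [0, 0, 1, 1] -> zero-std group fully masked

Since both the policy loss and the KL term are multiplied by completion_mask, the masked group contributes no gradient at all.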

This is orthogonal to any future oversampling/dynamic sampling implementation: even if oversampling is added to backfill zero-std groups with new prompts, any group that exhausts its resampling attempts and still has zero std should be masked for the same reason.

Reproduction Code

import torch
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def constant_reward(completions, **kwargs):
    return [1.0] * len(completions)

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

for beta in [0.0, 0.04]:
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=constant_reward,
        args=GRPOConfig(
            output_dir=f"/tmp/kl_bug_{beta}",
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            beta=beta,
            max_steps=3,
            report_to="none",
        ),
        train_dataset=dataset,
    )

    params_before = {n: p.clone() for n, p in trainer.model.named_parameters()}
    trainer.train()
    changed = sum(
        not torch.equal(p, trainer.model.get_parameter(n))
        for n, p in params_before.items()
    )
    print(f"\nbeta={beta}: {changed}/{len(params_before)} params changed")
    # beta=0.0 -> 0 params changed (correct: no reward signal, no update)
    # beta=0.04 -> ALL params changed (bug: KL gradients update model with no reward signal)

Output

beta=0.0: 0/27 params changed
beta=0.04: 27/27 params changed

System Info

- Platform: Linux-6.8.0-1043-nvidia-x86_64-with-glibc2.39
- Python version: 3.12.13
- TRL version: 1.2.0.dev0+573ea22
- PyTorch version: 2.11.0+cu130
- accelerator(s): NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200, NVIDIA B200
- Transformers version: 5.5.3
- Accelerate version: 1.13.0
- Accelerate config: 
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 8
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 4.8.4
- HF Hub version: 1.11.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.18.9
- Liger-Kernel version: 0.7.0
- PEFT version: not installed
- vLLM version: dev1+gba4a78eb5.cu131

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
