Commit 0e6f7a4

[Feature] Aggregation strategies
ghstack-source-id: 5ce7a2a Pull-Request: #3209
1 parent fd3964f commit 0e6f7a4

File tree

1 file changed: +41, -7 lines

torchrl/objectives/llm/grpo.py

Lines changed: 41 additions & 7 deletions
@@ -86,6 +86,10 @@ class GRPOLoss(LossModule):
             When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
             This stabilizes updates by skipping tokens that drifted too far from the reference distribution
             (see table and description; enables per-token trust region).
+        aggregation (str, optional): loss aggregation strategy for the policy objective.
+            - "token_mean": global masked token mean (weights long sequences more). Default.
+            - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
+            - "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
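
The three strategies differ only in how masked per-token losses are reduced. A toy numeric sketch (not part of the commit) of how "token_mean" and "prompt_mean" diverge once response lengths vary:

import torch

# Hypothetical batch: 2 prompts padded to 4 tokens; prompt 0 has 4 valid tokens, prompt 1 has 1.
per_token_loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                               [4.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, True, True],
                     [True, False, False, False]])

# "token_mean": one global mean over all valid tokens -> long responses carry more weight.
token_mean = per_token_loss[mask].mean()                     # (1+1+1+1+4) / 5 = 1.6

# "prompt_mean": mean per prompt first, then mean across prompts -> equal prompt weight.
per_prompt = (per_token_loss * mask).sum(-1) / mask.sum(-1)  # [1.0, 4.0]
prompt_mean = per_prompt.mean()                              # 2.5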
@@ -147,6 +151,7 @@ def __init__(
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
         kl_mask_threshold: float | None = None,
+        aggregation: str | None = "token_mean",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -167,6 +172,7 @@ def __init__(
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
         self.kl_mask_threshold = kl_mask_threshold
+        self.aggregation = aggregation or "token_mean"
 
         # Determine device and register clip epsilon as buffer
         if device is None:
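
With the stored attribute above, the strategy is chosen at construction time. A rough usage sketch (not from the commit; `actor` stands in for whatever policy module GRPOLoss is normally built with):

from torchrl.objectives.llm.grpo import GRPOLoss

# Equal weight per prompt regardless of response length; omit the argument
# (or pass None) to keep the "token_mean" default.
loss_fn = GRPOLoss(actor, aggregation="prompt_mean")

# No aggregation: forward() keeps per-token losses so a custom reduction
# can be applied downstream.
per_token_loss_fn = GRPOLoss(actor, aggregation="none")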
@@ -397,13 +403,13 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
-        td_out = td_out.named_apply(
-            lambda name, value: _reduce(
-                value, reduction=self.reduction, mask=mask
-            ).squeeze(-1)
-            if name.startswith("loss_")
-            else value,
-        )
+        # Aggregate loss terms according to aggregation strategy
+        for key in list(td_out.keys()):
+            if isinstance(key, tuple) or not isinstance(key, str):
+                continue
+            if key.startswith("loss_"):
+                val = td_out.get(key)
+                td_out.set(key, self._aggregate_loss_value(val, mask))
         if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
             # FIXME: parameterize this
             loss_kl, kl_penalty = self._kl_to_ref(
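
The loop above only reduces entries whose key is a plain string starting with "loss_"; diagnostics such as "ESS" pass through untouched. The same pattern on a plain dict of toy tensors (hypothetical keys, with a simple masked token mean standing in for self._aggregate_loss_value):

import torch

td_out = {
    "loss_objective": torch.rand(2, 4, 1),  # per-token losses, shape [B, T, 1]
    "loss_entropy": torch.rand(2, 4, 1),
    "ESS": torch.tensor(1.8),               # diagnostic, left untouched
}
mask = torch.ones(2, 4, dtype=torch.bool)

for key in list(td_out.keys()):
    if key.startswith("loss_"):
        per_token = td_out[key].squeeze(-1)   # [B, T]
        td_out[key] = per_token[mask].mean()  # masked global token mean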
@@ -447,6 +453,34 @@ def _compute_policy_objective(
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         return -gain, clip_fraction
 
+    def _aggregate_loss_value(
+        self, value: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Aggregate a per-token loss tensor using the configured strategy.
+
+        Supports:
+            - token_mean: masked mean across all tokens (default)
+            - prompt_mean: per-sample masked mean over tokens, then mean across batch
+            - none: return per-token loss with masked-out tokens set to 0
+
+        The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
+        and `mask` has shape [..., T].
+        """
+        if self.aggregation == "none" or self.reduction == "none":
+            mask_exp = expand_as_right(mask, value)
+            return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))
+
+        if self.aggregation == "prompt_mean":
+            # Mean over valid tokens per sample, then mean across batch
+            mask_exp = expand_as_right(mask, value).to(value.dtype)
+            token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
+            token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
+            sample_mean = token_sum / token_count
+            return sample_mean.mean(dim=0, keepdim=False)
+
+        # token_mean (global masked mean)
+        return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
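
For the "none" branch, masked-out positions are zeroed while the per-token shape is preserved, so callers can apply their own reduction. A small sketch of that behavior, assuming expand_as_right is the tensordict utility that right-expands the mask to value's shape (as torchrl imports it):

import torch
from tensordict.utils import expand_as_right

value = torch.tensor([[[1.0], [2.0], [3.0]]])  # [B=1, T=3, 1] per-token losses
mask = torch.tensor([[True, True, False]])     # [B=1, T=3]; last token is padding

mask_exp = expand_as_right(mask, value)        # broadcast mask to [1, 3, 1]
out = torch.where(mask_exp, value, torch.zeros_like(value))
# out == [[[1.0], [2.0], [0.0]]]: shape unchanged, padded position zeroed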
