forked from NVIDIA-NeMo/RL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreward_functions.py
More file actions
159 lines (135 loc) · 6.78 KB
/
reward_functions.py
File metadata and controls
159 lines (135 loc) · 6.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
NotRequired,
TypedDict,
TypeVar,
)
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
# Generic type variable bound to torch.Tensor.
# NOTE(review): not referenced in the code visible here; presumably kept for
# annotations elsewhere or for importers of this module — confirm before removing.
Tensor = TypeVar("Tensor", bound=torch.Tensor)
class RewardShapingConfig(TypedDict):
"""Configuration for reward function processing.
This configuration enables custom reward shaping, currently supporting DAPO-style
penalties for responses that exceed the maximum response length threshold.
"""
enabled: bool
# The length of the buffer to penalize responses that exceed the maximum response length threshold.
# Responses of length greater than overlong_buffer_length + max_response_length will
# receive the maximum penalty.
overlong_buffer_length: NotRequired[int]
# The penalty for responses that exceed the maximum response length threshold.
overlong_buffer_penalty: NotRequired[float]
# The maximum response length threshold. Responses exceeding this length will be penalized.
max_response_length: NotRequired[int]
# Stop properly penalty: scale factor for rewards of truncated responses (0-1).
# When set to 0, truncated responses get zero reward.
# When set to 1, no penalty is applied (default behavior).
stop_properly_penalty_coef: NotRequired[float | None]
def apply_reward_shaping(
    batch: BatchedDataDict, cfg: RewardShapingConfig
) -> BatchedDataDict:
    """Apply the configured reward shaping to ``batch["total_reward"]``.

    Two mutually exclusive modes are supported:

    * Stop-properly penalty (``stop_properly_penalty_coef`` is set): the rewards
      of samples flagged in ``batch["truncated"]`` are multiplied by the
      coefficient. DAPO overlong parameters are ignored (a warning is printed).
    * DAPO soft overlong punishment (default), as described in the DAPO paper:
      https://arxiv.org/pdf/2503.14476. The penalty is zero for responses of up
      to ``max_response_length - overlong_buffer_length`` tokens. It then grows
      linearly to ``-overlong_buffer_penalty`` at ``max_response_length`` and is
      capped there for longer responses.

    The function can be extended to support other custom reward logic.

    Args:
        batch: Batch holding ``total_reward`` (presumably a 1-D tensor with one
            entry per sample — assumption) plus ``truncated`` (stop-properly
            mode) or ``message_log`` (DAPO mode).
        cfg: Reward shaping configuration.

    Returns:
        The same ``batch`` object. When shaping is applied, its ``total_reward``
        entry is replaced by a new tensor.

    Raises:
        ValueError: If DAPO shaping is selected but its parameters are missing,
            or ``overlong_buffer_length`` is not positive.
        AssertionError: If a parameter is out of range, a required batch field
            is missing, or batch fields disagree in length.
    """
    if not cfg["enabled"]:
        return batch
    rewards = batch["total_reward"]

    stop_properly_penalty_coef = cfg.get("stop_properly_penalty_coef", None)
    if stop_properly_penalty_coef is not None:
        # Truncation-based shaping takes precedence over DAPO overlong shaping.
        return _apply_stop_properly_penalty(
            batch, rewards, cfg, stop_properly_penalty_coef
        )
    return _apply_dapo_overlong_penalty(batch, rewards, cfg)


def _apply_stop_properly_penalty(
    batch: BatchedDataDict,
    rewards: torch.Tensor,
    cfg: RewardShapingConfig,
    stop_properly_penalty_coef: float,
) -> BatchedDataDict:
    """Scale the rewards of truncated samples by ``stop_properly_penalty_coef``."""
    assert 0 <= stop_properly_penalty_coef <= 1, (
        f"stop_properly_penalty_coef must be in [0, 1], got {stop_properly_penalty_coef}"
    )

    # Warn the user that DAPO overlong parameters are ignored in this mode.
    ignored_params = [
        name
        for name in (
            "overlong_buffer_length",
            "overlong_buffer_penalty",
            "max_response_length",
        )
        if cfg.get(name) is not None
    ]
    if ignored_params:
        print(
            f"[WARN] stop_properly_penalty_coef is set, so the following DAPO overlong "
            f"parameters are ignored: {', '.join(ignored_params)}. "
            f"Set stop_properly_penalty_coef=null to use DAPO overlong reward shaping instead.",
            flush=True,
        )

    truncated = batch.get("truncated")
    assert truncated is not None, "truncated field not found in batch"
    # Normalize to a boolean mask on the rewards' device. A non-bool tensor
    # (e.g. 0/1 integers) is rejected by torch.where. Used to index
    # ``rewards[...]`` below, it would select elements by position instead of
    # acting as a mask.
    if isinstance(truncated, list):
        truncated = torch.tensor(truncated, dtype=torch.bool, device=rewards.device)
    else:
        truncated = truncated.to(device=rewards.device, dtype=torch.bool)
    assert len(truncated) == len(rewards), (
        "The number of truncated flags in the batch must match the number of rewards"
    )

    num_truncated = truncated.sum().item()
    if num_truncated == 0:
        print(
            "[INFO] stop properly penalty: no truncated samples (truncation_rate=0)",
            flush=True,
        )
        return batch

    # torch.where builds a new tensor, so `rewards` keeps the original values
    # for the log line below; no clone is needed.
    shaped_rewards = torch.where(
        truncated, rewards * stop_properly_penalty_coef, rewards
    )
    batch["total_reward"] = shaped_rewards
    # Cast before averaging: .mean() is undefined for integer tensors.
    print(
        f"[INFO] stop properly penalty applied: {num_truncated}/{len(truncated)} samples truncated, "
        f"coef={stop_properly_penalty_coef}, "
        f"original_reward_mean={rewards[truncated].float().mean().item():.4f}, "
        f"shaped_reward_mean={shaped_rewards[truncated].float().mean().item():.4f}",
        flush=True,
    )
    return batch


def _apply_dapo_overlong_penalty(
    batch: BatchedDataDict, rewards: torch.Tensor, cfg: RewardShapingConfig
) -> BatchedDataDict:
    """Add the DAPO soft overlong penalty, based on assistant response length."""
    # DAPO reward shaping requires overlong_buffer_length, overlong_buffer_penalty,
    # and max_response_length to be set.
    if (
        cfg.get("overlong_buffer_length") is None
        or cfg.get("overlong_buffer_penalty") is None
        or cfg.get("max_response_length") is None
    ):
        raise ValueError(
            "Reward function is enabled but only DAPO reward shaping is currently supported. Please ensure overlong_buffer_length, overlong_buffer_penalty, and max_response_length are properly configured."
        )

    overlong_buffer_length = cfg["overlong_buffer_length"]
    overlong_buffer_penalty = cfg["overlong_buffer_penalty"]
    max_response_length = cfg["max_response_length"]
    assert overlong_buffer_penalty >= 0, f"{overlong_buffer_penalty=} must be >=0"
    if overlong_buffer_length <= 0:
        # The penalty slope divides by the buffer length.
        raise ValueError(
            f"overlong_buffer_length must be positive, got {overlong_buffer_length}"
        )

    # Responses up to this many tokens receive no penalty.
    expected_response_length = max_response_length - overlong_buffer_length

    message_logs = batch["message_log"]
    assert len(message_logs) == len(rewards), (
        "The number of messages in the batch must match the number of rewards"
    )

    # Compute the per-sample penalty on CPU in float64, then move it to the
    # rewards' device/dtype. This also avoids float64 on devices that lack it.
    response_lengths = torch.tensor(
        [_assistant_response_length(log) for log in message_logs],
        dtype=torch.float64,
    )
    exceed_length = response_lengths - expected_response_length
    # Linear ramp across the buffer window. It is capped at the full penalty
    # past max_response_length (DAPO paper) and never turns into a bonus.
    overlong_reward = torch.clamp(
        -exceed_length / overlong_buffer_length * overlong_buffer_penalty,
        min=-overlong_buffer_penalty,
        max=0.0,
    )

    # An integer reward tensor would silently truncate the fractional penalty,
    # so shape in floating point.
    base_rewards = rewards if rewards.is_floating_point() else rewards.float()
    overlong_reward = overlong_reward.to(
        device=base_rewards.device, dtype=base_rewards.dtype
    )
    # One penalty per sample; reshape so it broadcasts over any trailing dims.
    overlong_reward = overlong_reward.reshape(-1, *([1] * (base_rewards.dim() - 1)))
    batch["total_reward"] = base_rewards + overlong_reward
    return batch


def _assistant_response_length(message_log) -> int:
    """Return the token count of the first assistant message in ``message_log``.

    ``message_log`` is presumably a list of message dicts with ``role`` and
    ``token_ids`` keys (assumption based on the keys read here).
    """
    for message in message_log:
        if message["role"] == "assistant":
            return message["token_ids"].shape[0]
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    raise AssertionError("Assistant response not found during reward shaping")