-
Notifications
You must be signed in to change notification settings - Fork 539
Expand file tree
/
Copy pathgrpo_utils.py
More file actions
410 lines (364 loc) · 18.7 KB
/
grpo_utils.py
File metadata and controls
410 lines (364 loc) · 18.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import enum
import os
from dataclasses import dataclass, field
from typing import Literal
import torch
import torch.distributed as dist
from open_instruct import data_types, logger_utils, model_utils
from open_instruct.utils import (
INVALID_LOGPROB,
calibrate_checkpoint_state_dir,
download_latest_checkpoint_from_gs,
get_beaker_whoami,
)
# Module-level logger for this file, created through the project's logger_utils helper.
logger = logger_utils.setup_logger(__name__)

# Lookup from dtype name to torch dtype. The keys are exactly the values that
# `ExperimentConfig.model_dtype` documents as supported, and each key is also the
# name of the matching attribute on the `torch` module.
TORCH_DTYPES: dict[str, torch.dtype] = {name: getattr(torch, name) for name in ("bfloat16", "float32")}
class GRPOLossType(enum.StrEnum):
dapo = "dapo"
cispo = "cispo"
@dataclass
class ExperimentConfig:
    """Configuration for a GRPO training run.

    Fields documented as "RUNTIME VALUE" are populated at runtime rather than
    supplied by the user. ``__post_init__`` validates option combinations. When a
    GCS bucket is configured, it also derives ``gs_checkpoint_state_dir``. When a
    local checkpoint dir is set, it calls ``download_latest_checkpoint_from_gs``
    (only if a GCS checkpoint dir is known) and then
    ``calibrate_checkpoint_state_dir`` on that local dir.
    """

    # Experiment
    exp_name: str = "grpo"
    """The name of this experiment"""
    seed: int = 1
    """Seed of the experiment"""
    run_name: str | None = None
    """RUNTIME VALUE: A unique name of this run"""

    # Optimizer
    learning_rate: float = 2e-5
    """The initial learning rate for AdamW optimizer."""
    lr_scheduler_type: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "linear"
    """Which scheduler to use"""
    warm_up_steps: int = 0
    """Number of warm up steps for the scheduler"""
    warmup_ratio: float = 0.0
    """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)"""
    weight_decay: float = 0.0
    """Weight decay for AdamW if we apply some."""
    max_grad_norm: float = 1.0
    """Maximum gradient norm for gradient clipping."""
    set_weight_decay_on_bias_and_norm: bool = True
    """Whether to set weight decay on bias and norm layers"""
    fused_optimizer: bool = False
    """Whether to use fused optimizer"""

    # Batch sizes
    per_device_train_batch_size: int = 1
    """The forward batch size per device (local_micro_batch_size)"""
    total_episodes: int = 100000
    """The total number of episodes in the dataset"""
    world_size: int | None = None
    """RUNTIME VALUE: The number of processes (GPUs) to use for training ONLY"""
    num_training_steps: int | None = None
    """RUNTIME VALUE: The number of training_steps to train"""
    local_eval_every: int = 100
    """Run evaluation after this many training steps. This controls in-loop evals, which reuse the generation/reward verifier setup. Set to -1 to disable."""
    save_freq: int = 200
    """How many train steps to save the model"""
    backend_timeout: int = 120
    """Timeout for inference/training backends in minutes. Default is 2 hours (120 min)."""
    model_dtype: str = "bfloat16"
    """Model dtype for training. Supported values: 'bfloat16', 'float32'."""

    # Algorithm
    num_epochs: int = 1
    """the number of epochs to train"""
    num_mini_batches: int = 1
    """Number of minibatches to split a batch into"""
    beta: float = 0.05
    """the beta value of the RLHF objective (KL coefficient)"""
    clip_lower: float = 0.2
    """the lower clip range"""
    clip_higher: float = 0.272
    """the higher clip range. Sometimes we want this to be higher, see DAPO (https://arxiv.org/abs/2503.14476)"""
    truncated_importance_sampling_ratio_cap: float = 2.0
    """The maximum cap for truncated importance sampling ratio (0 means disabled)"""
    kl_estimator: Literal[0, 1, 2, 3] = 2
    """the KL estimator to use"""
    loss_denominator: str = "token"
    """Optional constant denominator for masked_mean; can be "token" or a float value.
    when "token", the loss is divided by the total number of tokens in the batch (standard LM training).
    when a float value, the loss is divided by this value (ideally, max tokens in batch, per Dr GRPO).
    """
    alpha: float = 0.6
    """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
    reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
    used. E.g., [TD3](https://arxiv.org/abs/1802.09477) uses https://github.com/vwxyzjn/cleanrl/blob/dcc289fc6f0bda492fa7360a155262cf826b12a5/cleanrl/td3_continuous_action.py#L269
    """
    ref_policy_update_freq: int | None = None
    """How many training steps to take before updating the reference policy."""
    load_ref_policy: bool = True
    """Whether to load and use a reference policy for KL penalty calculation."""
    loss_fn: GRPOLossType = GRPOLossType.dapo
    """Whether to use DAPO or CISPO loss function."""
    record_entropy: bool = False
    """whether to record the entropy of the policy during training. Uses extra memory."""
    use_vllm_logprobs: bool = False
    """whether to use vLLM's logprobs for training instead of calculating them via forward pass"""
    temperature: float = field(default=1.0, init=False)
    """RUNTIME VALUE: Temperature for sampling, set from streaming_config."""

    # Ray
    single_gpu_mode: bool = False
    """whether to collocate vLLM and actor on the same node (mostly for debugging purposes)"""
    num_learners_per_node: list[int] = field(default_factory=lambda: [1])
    """number of GPU deepspeed learners per node (e.g., --num_learners_per_node 2 4 means 2 learner processes
    on the first node and 4 learner processes on the second node; each process will have 1 GPU)"""
    num_nodes: int = 1
    """Number of nodes for distributed training."""
    sequence_parallel_size: int = 1
    """sequence parallel size - how many GPUs we will parallelize sequences across during training.
    Useful for super-long context lengths."""
    deepspeed_stage: int = 0
    """the deepspeed stage"""
    deepspeed_zpg: int = 8
    """the deepspeed zpg value. Higher values are more memory efficient but slower. Set to 1 to disable zpg, which uses less memory but is significantly slower. Ideally is set to the number of GPUs per node (usually 8, default)."""
    deepspeed_offload_param: bool = False
    """whether to offload parameters to CPU (reduces GPU memory usage)"""
    deepspeed_offload_optimizer: bool = False
    """whether to offload optimizer states to CPU (reduces GPU memory usage)"""
    gather_whole_model: bool = True
    """whether to gather the whole model to broadcast (not doable for 70B but can be faster for 8B)"""
    enable_queue_dashboard: bool = True
    """whether to enable the ActorManager queue monitoring dashboard"""
    queue_dashboard_port: int | None = None
    """optional port for the dashboard server (if None, finds a free port automatically)"""

    # Experiment tracking
    verbose: bool = False
    """If toggled, debug output will be shown"""
    with_tracking: bool = False
    """If toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "open_instruct_internal"
    """The wandb's project name"""
    wandb_entity: str | None = None
    """The entity (team) of wandb's project"""
    push_to_hub: bool = True
    """Whether to upload the saved model to huggingface"""
    hf_entity: str | None = None
    """The user or org name of the model repository from the Hugging Face Hub"""
    hf_repo_id: str | None = None
    """The id of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_revision: str | None = None
    """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_url: str | None = None
    """The url of the saved model in the Hugging Face Hub (will be autoset)"""
    output_dir: str = "output"
    """Where to save the model"""
    cache_dataset_only: bool = False
    """Immediately exit after caching the dataset"""
    keep_last_n_checkpoints: int = 3
    """How many checkpoints to keep in the output directory. -1 for all."""
    checkpoint_state_freq: int = -1
    """How often to save the model checkpoint, optimizer states, and lr scheduler states (in steps)"""
    checkpoint_state_dir: str | None = None
    """Where to save the model checkpoint (if applicable)"""
    gs_checkpoint_state_dir: str | None = None
    """The actual `checkpoint_state_dir` to use (handling the case where gs_bucket_path is provided)"""

    # Ai2 specific settings
    try_launch_beaker_eval_jobs_on_weka: bool = False
    """Whether to launch beaker evaluation jobs after training on weka"""
    try_auto_save_to_beaker: bool = True
    """Whether to try to save the model to Beaker dataset `/output` after training"""
    gs_bucket_path: str | None = None
    """The path to the gs bucket to save the model to"""
    oe_eval_tasks: list[str] | None = None
    """The beaker evaluation tasks to launch"""
    oe_eval_max_length: int = 4096
    """the max generation length for evaluation for oe-eval"""
    oe_eval_beaker_image: str | None = None
    """the docker image for evaluation for oe-eval"""
    oe_eval_gpu_multiplier: int | None = None
    """multiply the gpus used for each oe-eval task"""
    eval_priority: Literal["low", "normal", "high", "urgent"] = "normal"
    """the priority of auto-launched evaluation jobs"""
    eval_workspace: str = "ai2/tulu-3-results"
    """the workspace to launch evaluation jobs on"""
    send_slack_alerts: bool = False
    """Whether to send Slack alerts on training failures"""

    # Evaluation behavior
    eval_on_step_0: bool = False
    """Whether to run local evaluation at training step 0. Defaults to False."""

    def __post_init__(self):
        """Validate option combinations and resolve derived checkpoint paths.

        Raises:
            ValueError: if an option is malformed or two options are incompatible.
        """
        if self.send_slack_alerts and not os.environ.get("SLACK_WEBHOOK_URL"):
            logger.warning(
                "--send_slack_alerts is set but SLACK_WEBHOOK_URL is not in the environment. Slack alerts will not be sent."
            )
        if self.use_vllm_logprobs and self.truncated_importance_sampling_ratio_cap > 0.0:
            raise ValueError(
                "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
                "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
            )
        if self.loss_denominator != "token":
            denominator_error = (
                f"loss_denominator must be a valid float greater than 0 if not 'token', got: {self.loss_denominator}"
            )
            try:
                denominator = float(self.loss_denominator)
            except ValueError as err:
                # Surface the intended message instead of float()'s generic "could not convert".
                raise ValueError(denominator_error) from err
            # `nan <= 0` is False, so check the accepted range directly. This also rejects
            # inf, which would divide the loss down to zero.
            if not 0.0 < denominator < float("inf"):
                raise ValueError(denominator_error)
        if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir is None:
            raise ValueError("`checkpoint_state_dir` must be provided if `checkpoint_state_freq` is greater than 0!")
        # Reject every non-positive frequency, not only the -1 default, as the message states.
        if self.checkpoint_state_dir is not None and self.checkpoint_state_freq <= 0:
            raise ValueError("`checkpoint_state_freq` must be greater than 0 if `checkpoint_state_dir` is provided!")
        if self.gs_checkpoint_state_dir is not None and not self.gs_checkpoint_state_dir.startswith("gs://"):
            raise ValueError(f"`gs_checkpoint_state_dir` must start with 'gs://', got: {self.gs_checkpoint_state_dir}")
        if self.eval_on_step_0 and self.local_eval_every <= 0:
            raise ValueError(
                "`eval_on_step_0` requires `local_eval_every` > 0. "
                "Set `local_eval_every` to a positive value or disable `eval_on_step_0`."
            )
        if self.gs_bucket_path is not None and not self.gs_bucket_path.startswith("gs://"):
            raise ValueError(f"`gs_bucket_path` must start with 'gs://', got: {self.gs_bucket_path}")
        if self.sequence_parallel_size > 1 and self.deepspeed_stage != 3:
            raise ValueError("`sequence_parallel_size` > 1 requires `deepspeed_stage` to be 3!")
        if self.gs_bucket_path is not None and self.gs_checkpoint_state_dir is None:
            if self.checkpoint_state_dir is None:
                raise ValueError("`checkpoint_state_dir` must be provided when using `gs_bucket_path`!")
            # Derive the GCS checkpoint location from the bucket and the local dir name,
            # inserting the value from get_beaker_whoami() as a path component when available.
            checkpoint_dir_name = self.checkpoint_state_dir.rstrip("/")
            beaker_users = get_beaker_whoami()
            if beaker_users is not None:
                self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{beaker_users}/{checkpoint_dir_name}"
            else:
                self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
            # The local checkpoint dir is relocated under /filestore unless it already lives there.
            if not checkpoint_dir_name.startswith("/filestore"):
                self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"
        if self.checkpoint_state_dir is not None:
            if self.gs_checkpoint_state_dir is not None:
                download_latest_checkpoint_from_gs(self.gs_checkpoint_state_dir, self.checkpoint_state_dir)
            calibrate_checkpoint_state_dir(self.checkpoint_state_dir)
        if not self.load_ref_policy and self.beta != 0.0:
            raise ValueError(
                "When load_ref_policy=False, beta must be 0.0. "
                f"Got beta={self.beta}. Set --beta 0.0 or --load_ref_policy to use KL penalty."
            )
def compute_grpo_loss(
new_logprobs: torch.Tensor,
ratio: torch.Tensor,
advantages: torch.Tensor,
ref_logprobs: torch.Tensor | None,
config: ExperimentConfig,
tis_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if config.loss_fn == GRPOLossType.dapo:
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - config.clip_lower, 1.0 + config.clip_higher)
elif config.loss_fn == GRPOLossType.cispo:
# cispo: directly clip ratio, no lower bound.
# reinforce loss, so multiply by new logprobs
pg_losses = -advantages * torch.clamp(ratio.detach(), max=1.0 + config.clip_higher) * new_logprobs
pg_losses2 = pg_losses
else:
raise ValueError(f"Invalid loss function: {config.loss_fn}")
if tis_weights is not None:
pg_losses = pg_losses * tis_weights
pg_losses2 = pg_losses2 * tis_weights
pg_loss_max = torch.max(pg_losses, pg_losses2)
if ref_logprobs is not None:
# We want the KL loss to backpropagate through the model.
# We also clamp the KL loss to avoid numerical instability.
# https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae
ref_logprobs_diff = (new_logprobs - ref_logprobs).clamp(-40.0, 40.0)
kl_all = model_utils.estimate_kl(ref_logprobs_diff, ratio)
kl = kl_all[config.kl_estimator]
else:
kl = torch.zeros_like(pg_loss_max)
return pg_losses, pg_losses2, pg_loss_max, kl
def forward_for_logprobs(
    model: torch.nn.Module,
    query_responses: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    pad_token_id: int,
    temperature: float,
    return_entropy: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Run ``model`` once and return the log-prob of each next token.

    Logits are divided by ``temperature`` and the final position is dropped, so
    the logit at position ``i`` scores token ``i + 1`` of ``query_responses``.
    Pad tokens among the targets are replaced by id 0 before gathering so the
    index stays in range; those positions carry meaningless values and are
    expected to be masked out by the caller.

    Returns:
        ``(logprobs, entropy)``. ``entropy`` is None unless ``return_entropy``
        is set, and when present it is computed without gradient tracking.
    """
    outputs = model(input_ids=query_responses, attention_mask=attention_mask, position_ids=position_ids)
    # Some models return an object exposing ``.logits``, others the raw tensor.
    raw_logits = getattr(outputs, "logits", outputs)
    shifted_logits = (raw_logits / temperature)[:, :-1]

    targets = query_responses[:, 1:].clone().to(shifted_logits.device)
    targets = targets.masked_fill(targets == pad_token_id, 0)
    token_logprobs = model_utils.log_softmax_and_gather(shifted_logits, targets)

    if not return_entropy:
        return token_logprobs, None
    # Entropy is only a monitoring statistic, so no gradient flows through it.
    with torch.no_grad():
        token_entropy = model_utils.entropy_from_logits(shifted_logits)
    return token_logprobs, token_entropy
def _mask_non_response(logprob_BT: torch.Tensor, response_mask_BT: torch.Tensor) -> torch.Tensor:
    """Overwrite log-probs at non-response positions with INVALID_LOGPROB."""
    # Log-prob column i scores token i + 1, so compare against the mask shifted by one.
    shifted_mask = response_mask_BT.to(logprob_BT.device)[:, 1:].bool()
    return torch.masked_fill(logprob_BT, ~shifted_mask, INVALID_LOGPROB)


def compute_logprobs(
    model: torch.nn.Module,
    data_BT: data_types.CollatedBatchData,
    pad_token_id: int,
    temperature: float,
    use_grad: bool = False,
    batch_size: int | None = None,
) -> list[torch.Tensor]:
    """Compute log probabilities for all samples in batch.

    Samples are processed in chunks of ``batch_size`` (default 1). A chunk whose
    ``query_responses`` entries all share one shape is concatenated along dim 0
    and run in a single forward pass. Otherwise, each sample is run on its own.
    Positions outside each sample's response mask are set to INVALID_LOGPROB.

    Args:
        model: the policy model passed to ``forward_for_logprobs``.
        data_BT: collated batch providing ``query_responses``, ``attention_masks``,
            ``position_ids`` and ``response_masks`` lists indexed per sample.
        pad_token_id: token id treated as padding in the targets.
        temperature: logits are divided by this before the log-softmax.
        use_grad: run under ``torch.enable_grad()`` instead of ``torch.no_grad()``.
        batch_size: number of samples per forward pass; None means 1.

    Returns:
        One masked log-prob tensor per sample, in input order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        batch_size = 1
    if batch_size < 1:
        # range() would raise an opaque error for 0 and silently yield nothing for negatives.
        raise ValueError(f"batch_size must be a positive integer, got: {batch_size}")

    logprobs_BT: list[torch.Tensor] = []
    num_samples = len(data_BT.query_responses)
    context = torch.enable_grad() if use_grad else torch.no_grad()
    with context:
        for start_idx in range(0, num_samples, batch_size):
            batch_indices = range(start_idx, min(start_idx + batch_size, num_samples))
            query_responses = [data_BT.query_responses[i] for i in batch_indices]

            if len({tuple(t.shape) for t in query_responses}) != 1:
                # Shapes differ, so the chunk cannot be concatenated; run samples one by one.
                for i in batch_indices:
                    single_logprobs, _ = forward_for_logprobs(
                        model,
                        data_BT.query_responses[i],
                        data_BT.attention_masks[i],
                        data_BT.position_ids[i],
                        pad_token_id,
                        temperature,
                        False,
                    )
                    logprobs_BT.append(_mask_non_response(single_logprobs, data_BT.response_masks[i]))
                continue

            batch_logprobs, _ = forward_for_logprobs(
                model,
                torch.cat(query_responses, dim=0),
                torch.cat([data_BT.attention_masks[i] for i in batch_indices], dim=0),
                torch.cat([data_BT.position_ids[i] for i in batch_indices], dim=0),
                pad_token_id,
                temperature,
                False,
            )
            # Undo the concatenation so each sample gets back its own rows.
            sample_sizes = [t.shape[0] for t in query_responses]
            split_logprobs = torch.split(batch_logprobs, sample_sizes, dim=0)
            for i, logprob_BT in zip(batch_indices, split_logprobs):
                logprobs_BT.append(_mask_non_response(logprob_BT, data_BT.response_masks[i]))
    return logprobs_BT
def calculate_token_counts(
    accumulation_steps: int,
    data_BT: data_types.CollatedBatchData,
    device: torch.device,
    process_group: dist.ProcessGroup | None = None,
) -> dict[int, float]:
    """Compute total token counts per accumulation group, all-reduced across DP ranks.

    Copied from grpo_fast.py to share logic with olmo_core_train_modules.py.

    Each sample contributes ``mask[:, 1:].sum()`` from its response mask. The
    per-sample counts are summed elementwise across ``process_group`` with
    ``all_reduce``. Sample ``i`` is then assigned to group ``i // accumulation_steps``,
    and each group is keyed by the index of its first sample.

    Args:
        accumulation_steps: number of consecutive samples per group; must be >= 1.
        data_BT: collated batch whose ``response_masks`` are read.
        device: device the counts are moved to before the collective.
        process_group: group for ``all_reduce``; None uses the default group.

    Returns:
        Mapping from group start index to the all-reduced token count. Empty
        when this rank has no samples.

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got: {accumulation_steps}")

    local_counts = [mask[:, 1:].sum().float() for mask in data_BT.response_masks]
    if not local_counts:
        # NOTE(review): returning here skips the all_reduce below. This is only safe if
        # every rank in process_group is empty too; otherwise the other ranks' collective
        # would have no partner. TODO confirm callers guarantee this.
        return {}

    counts_tensor = torch.stack(local_counts).to(device)
    dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM, group=process_group)

    accumulation_counts: dict[int, float] = {}
    # A single device-to-host transfer instead of one `.item()` sync per sample.
    for i, count in enumerate(counts_tensor.tolist()):
        key = int((i // accumulation_steps) * accumulation_steps)
        accumulation_counts[key] = accumulation_counts.get(key, 0.0) + count
    return accumulation_counts