dapo_ray_trainer.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""
import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from tqdm import tqdm

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask
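

# DAPO extends the standard Ray PPO trainer with "dynamic sampling": rollout batches are
# generated repeatedly, and prompt groups whose rewards have zero variance are filtered out
# until a full train batch of informative prompts has been collected (see the
# `filter_groups` branch inside `fit` below).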
class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        timing_raw = defaultdict(float)
        batch = None
        num_prompt_in_batch = 0
        num_gen_batches = 0
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # pop those keys for generation
                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    # generate a batch
                    with _timer("gen", timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with _timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(new_batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)
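
                    # NOTE: each prompt now carries a unique `uid` and is repeated rollout.n times
                    # (interleaved), so every response can be grouped with its siblings for the
                    # group-relative advantage and for the dynamic-sampling filter below.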
with _timer("reward", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(reward_tensor)
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(new_batch, return_dict=True)
reward_tensor = reward_result["reward_tensor"]
reward_extra_infos_dict = reward_result["reward_extra_info"]
except Exception as e:
print(f"Error in reward_fn: {e}")
reward_tensor = self.reward_fn(new_batch)
reward_extra_infos_dict = {}
new_batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics) # TODO: This will be cleared if we use multiple genenration batches
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:  # NOTE: when the number of prompts left after filtering is smaller than the train batch size, we move on to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()

                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1]
                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = []
                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
                            if traj_from_prompt_uid in kept_prompt_uids:
                                kept_traj_idxs.append(idx)

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                progress_bar.update(1)
                                continue
                            else:
                                raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many batches. Please check whether your data is too difficult, or set max_num_gen_batches=0 to allow unlimited generation batches.")
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            batch = batch[:traj_bsz]

                    # === Updating ===

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
                    # recompute old_log_probs
                    with _timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer("ref", timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer("adv", timing_raw):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )
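
                    # NOTE: `norm_adv_by_std_in_grpo` (read above with a default of True) controls whether
                    # the group-mean-centered rewards are additionally divided by the group's reward std
                    # (standard GRPO behavior) or left unnormalized.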

                    # update critic
                    if self.use_critic:
                        with _timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual TFLOPs and theoretical TFLOPs
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backends
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1
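

# Typical driver usage (a rough sketch: in the DAPO recipe this wiring lives in the entry
# script, e.g. main_dapo.py, and the exact constructor arguments follow RayPPOTrainer, so
# they may differ between verl versions):
#
#     trainer = RayDAPOTrainer(
#         config=config,
#         tokenizer=tokenizer,
#         role_worker_mapping=role_worker_mapping,
#         resource_pool_manager=resource_pool_manager,
#         ray_worker_group_cls=ray_worker_group_cls,
#         reward_fn=reward_fn,
#         val_reward_fn=val_reward_fn,
#     )
#     trainer.init_workers()
#     trainer.fit()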