-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathppo_trainer.py
More file actions
557 lines (485 loc) · 28.8 KB
/
ppo_trainer.py
File metadata and controls
557 lines (485 loc) · 28.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
DataCollatorWithPadding,
FeatureExtractionMixin,
GenerationConfig,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainerControl,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available
from trl.core import masked_mean, masked_whiten
from trl.models import create_reference_model
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.ppo_config import PPOConfig
from trl.trainer.utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
exact_div,
first_true_indices,
forward,
generate_model_card,
get_comet_experiment_url,
get_reward,
log_table_to_comet_experiment,
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)
from trl import PPOTrainer
from trl.trainer.ppo_trainer import INVALID_LOGPROB
from typing import List
from evaluation.math_utils import is_correct
from utils import *
import json
class PPOTrainerWithLengthReward(PPOTrainer):
def init_lens_dict(self, w_lr=1.0, type_lr="cosine", rep_ngram_size=3, rep_penalty=0.0, mode="min"):
    """Initialise the per-prompt length history and the reward settings used by `_compute_score`.

    Args:
        w_lr: Weight passed to the length-reward function as ``w_lr``.
        type_lr: Length-reward variant. ``"cosine"`` selects ``compute_len_reward`` and
            ``"linear"`` selects ``compute_len_reward_linear``.
        rep_ngram_size: Passed to ``compute_repetition_penalty_reward`` as ``ngram_size``.
        rep_penalty: Passed to ``compute_repetition_penalty_reward`` as ``max_penalty``.
        mode: How the length history is aggregated, ``"min"`` or ``"mean"``.

    Raises:
        ValueError: If ``type_lr`` or ``mode`` is not one of the supported values.
    """
    # Validate before mutating anything so a bad argument cannot leave the trainer half-configured.
    if type_lr not in ("cosine", "linear"):
        raise ValueError("{} is not a valid type for length reward".format(type_lr))
    # Any other mode would make `_gather_len_dict` return an un-reduced tensor and fail much later.
    if mode not in ("min", "mean"):
        raise ValueError("{} is not a valid mode for the length history".format(mode))

    # Per-prompt length history held by this process.
    self.lens_dict = check_dataset_len_and_init_len_dict(
        self.train_dataset_len, len(self.eval_dataset), self.accelerator.device, to_long=True
    )
    # Fixed settings forwarded verbatim to the length-reward function.
    self.tolerance_ratio = 0.1
    self.consider_readability = False
    self.w_lr = w_lr
    self.accelerator.print("Setting w_lr to {}".format(self.w_lr))

    if type_lr == "cosine":
        self.compute_lr = compute_len_reward
    else:
        self.compute_lr = compute_len_reward_linear

    self.rep_ngram_size = rep_ngram_size
    self.rep_penalty = rep_penalty
    self.accelerator.print(
        "Setting rep_ngram_size to {} and rep_penalty to {}".format(self.rep_ngram_size, self.rep_penalty)
    )

    self.mode = mode
    if self.mode == "mean":
        # Per-prompt number of correct responses folded into the running mean so far.
        self.count_dict = check_dataset_len_and_init_len_dict(
            self.train_dataset_len, len(self.eval_dataset), self.accelerator.device, to_long=True, fill_value=0
        )
    self.accelerator.print("Setting mode to {}".format(self.mode))

    self.use_math_verify = True
    self.accelerator.print("Setting use_math_verify to {}".format(self.use_math_verify))
def _compute_correctness_reward(self, resp, ground_truth):
    """Return 1 if `is_correct` accepts `resp` against `ground_truth`, otherwise 0."""
    verdict = is_correct(resp, ground_truth, use_math_verify=self.use_math_verify)
    if verdict:
        return 1
    return 0
def _gather_len_dict(self):
    """Gather every process's length history and reduce it to a single table.

    In "min" mode the result is the element-wise minimum over processes. In "mean" mode
    each process's entry is weighted by its share of the total count for that prompt, so
    prompts with a total count of 0 end up with the value 0. For any other mode the
    gathered table is returned without being reduced.
    """
    world_size = self.accelerator.num_processes
    table_size = MAX_TRAIN_SET_SIZE + MAX_VALID_SET_SIZE

    stacked_lens = self.accelerator.gather(self.lens_dict)
    assert stacked_lens.size(0) == table_size * world_size
    # One row per process, one column per prompt index.
    stacked_lens = stacked_lens.reshape(world_size, table_size)

    if self.mode == "min":
        return stacked_lens.min(dim=0).values

    if self.mode == "mean":
        stacked_counts = self.accelerator.gather(self.count_dict).reshape(world_size, table_size)
        # Clamp the per-prompt total to at least 1 so prompts nobody solved do not divide by zero.
        totals = torch.clamp(stacked_counts.sum(dim=0, keepdim=True), min=1)
        # `totals` broadcasts along the process dimension.
        weights = stacked_counts / totals
        return (stacked_lens * weights).sum(dim=0)

    return stacked_lens
def _compute_len_reward(self, seq_len, is_corr, prompt_idx, gathered_len_dict):
    """Look up the prompt's historical length and delegate to the configured length-reward function."""
    reward_kwargs = {
        "history_len": gathered_len_dict[prompt_idx].item(),
        "seq_len": seq_len,
        "is_corr": is_corr,
        "prompt_idx": prompt_idx,
        "consider_readability": self.consider_readability,
        "tolerance_ratio": self.tolerance_ratio,
        "w_lr": self.w_lr,
    }
    return self.compute_lr(**reward_kwargs)
def _compute_score(self, response: torch.Tensor, sequence_length: torch.Tensor, ground_truths: List[str], prompt_idxes: List[int], device) -> torch.Tensor:
    """Score a batch of responses and fold the correct ones into the length history.

    This method decouples score computation from the training loop; override it to
    implement a custom reward function.

    Each score is the sum of the correctness reward, the length reward (computed against
    the history gathered from all processes *before* this batch is recorded) and the
    repetition-penalty reward. Afterwards the lengths of the correct responses update this
    process's ``lens_dict`` (and ``count_dict`` in "mean" mode).

    Args:
        response: Response token ids, one row per sample.
        sequence_length: Per-sample length value, used both as the length fed to the
            length reward and as the slice end for the repetition penalty.
        ground_truths: Reference answer for each sample.
        prompt_idxes: Index of each sample's prompt in the length history.
        device: Device on which the returned score tensor is allocated.

    Returns:
        A 1-D tensor with one score per sample.
    """
    response_text = self.processing_class.batch_decode(response, skip_special_tokens=True)
    assert len(response_text) == sequence_length.size(0) and len(response_text) == len(ground_truths) and len(response_text) == len(prompt_idxes), "response_text:{}; sequence_length: {}; ground_truths: {}; prompt_idxes: {}".format(response_text, sequence_length, ground_truths, prompt_idxes)

    # Plain ints make reliable dict keys even if the collator hands over tensor/numpy scalars.
    prompt_idxes = [int(idx) for idx in prompt_idxes]
    # One host transfer instead of one `.item()` call per row.
    seq_lens = [int(n) for n in sequence_length.tolist()]

    scores = torch.zeros(len(response_text), device=device)
    # Lengths of the correct responses this process produced in this batch, keyed by prompt.
    curr_batch_proc_corr_lens = {prompt_idx: [] for prompt_idx in prompt_idxes}
    gathered_len_dict = self._gather_len_dict()

    for i, (text, ground_truth, prompt_idx, seq_len) in enumerate(
        zip(response_text, ground_truths, prompt_idxes, seq_lens)
    ):
        correctness_reward = self._compute_correctness_reward(text, ground_truth)
        len_reward = self._compute_len_reward(seq_len, correctness_reward >= 1, prompt_idx, gathered_len_dict)
        rep_reward = compute_repetition_penalty_reward(
            resp_tokens=response[i, :seq_len].to(torch.int).tolist(),
            ngram_size=self.rep_ngram_size,
            max_penalty=self.rep_penalty,
            only_start=False,
        )
        scores[i] = correctness_reward + len_reward + rep_reward
        # Only includes the part of history that the current process accesses.
        if correctness_reward >= 1:
            # Credit the sample's own prompt, not whichever prompt an earlier loop ended on.
            curr_batch_proc_corr_lens[prompt_idx].append(seq_len)

    # Iterate unique prompts so a prompt that occurs several times in the batch is updated once.
    for prompt_idx, corr_lens in curr_batch_proc_corr_lens.items():
        if not corr_lens:
            continue
        if self.mode == "min":
            self.lens_dict[prompt_idx] = min(min(corr_lens), self.lens_dict[prompt_idx].item())
        elif self.mode == "mean":
            num_new = len(corr_lens)
            curr_batch_mean = sum(corr_lens) / num_new
            prev_count = self.count_dict[prompt_idx].item()
            # NOTE(review): `lens_dict` is created with to_long=True; if that makes it an integer
            # tensor, the fractional part of the running mean is dropped on assignment - confirm.
            if prev_count <= 0:
                assert self.lens_dict[prompt_idx].item() == MAX_LEN
                self.lens_dict[prompt_idx] = curr_batch_mean
            else:
                assert self.lens_dict[prompt_idx].item() < MAX_LEN
                ratio = num_new / (prev_count + num_new)
                self.lens_dict[prompt_idx] = ratio * curr_batch_mean + (1 - ratio) * self.lens_dict[prompt_idx].item()
            self.count_dict[prompt_idx] = self.count_dict[prompt_idx] + num_new
    return scores
def train(self):
    """Run PPO training, scoring rollouts with `_compute_score` instead of a reward model.

    Every update draws one batch from ``self.dataloader``, samples responses from the
    policy, computes reference log-probs, value estimates and scores, turns them into
    KL-penalised rewards and GAE advantages, and then performs ``args.num_ppo_epochs``
    passes of clipped policy/value optimisation over shuffled mini-batches. Metrics are
    logged after every update.

    When ``self.state.episode % self.train_dataset_len < args.batch_size`` a checkpoint is
    saved, and at the end of that update the gathered length history is written by the
    main process to ``history_lens_dict_step_<update>.json`` in ``args.output_dir``.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    ref_policy = self.ref_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle through the dataloader forever; the update loop decides when to stop.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    accelerator.print("Train dataset length: {}".format(self.train_dataset_len))
    start_time = time.time()
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches * args.num_mini_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Loop-invariant, so look it up once per update rather than once per rollout chunk.
            unwrapped_value_model = accelerator.unwrap_model(model).value_model

            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                torch.cuda.empty_cache()

                if ref_policy is None:
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                torch.cuda.empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. score the truncated responses with the rule-based reward
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                score = self._compute_score(
                    postprocessed_response,
                    sequence_length,
                    data["ground_truth"][i : i + args.local_rollout_forward_batch_size],
                    data["prompt_idx"][i : i + args.local_rollout_forward_batch_size],
                    device,
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            torch.cuda.empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            sequence_lengths_p1 = sequence_lengths + 1
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            kl = logprobs - ref_logprobs
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            rewards[[actual_start, actual_end]] += scores

            # 5. whiten rewards
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            torch.cuda.empty_cache()

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                torch.cuda.empty_cache()

        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            metrics["resp_len"] = self.accelerator.gather_for_metrics(sequence_lengths.float().mean()).mean().item()
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        # True on the update whose episode counter lands within one batch past a multiple of the dataset length.
        is_at_epoch_end = self.state.episode % self.train_dataset_len < args.batch_size
        if is_at_epoch_end:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        torch.cuda.empty_cache()
        gc.collect()

        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            torch.cuda.empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        torch.cuda.empty_cache()

        if is_at_epoch_end:
            # The gather is a collective call, so every process takes part; only the main one writes.
            gathered_len_dict = self._gather_len_dict()
            if self.accelerator.is_main_process:
                dump_path = os.path.join(args.output_dir, "history_lens_dict_step_{}.json".format(update))
                with open(dump_path, "w", encoding="utf-8") as tgt_json:
                    # One `tolist()` transfer instead of one `.item()` call per prompt.
                    json.dump(dict(enumerate(gathered_len_dict.tolist())), tgt_json, indent=4)

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
# https://github.com/huggingface/trl/issues/2122
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save only the policy sub-module by temporarily swapping it in as ``self.model``.

    Args:
        output_dir: Forwarded to ``Trainer.save_model``.
        _internal_call: Forwarded to ``Trainer.save_model``.
    """
    backup_model = self.model
    self.model = self.model.policy  # save only the policy
    try:
        Trainer.save_model(self, output_dir, _internal_call)
    finally:
        # Restore the full model even if saving fails, so the trainer is not left
        # holding only the policy.
        self.model = backup_model
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Delegate to the parent ``_save``, keeping only policy weights when DeepSpeed is enabled.

    Under DeepSpeed, entries whose name starts with ``policy.`` are kept with that prefix
    removed, and every other entry is dropped.

    Args:
        output_dir: Forwarded to the parent ``_save``.
        state_dict: Optional mapping of parameter names to tensors. ``None`` is passed
            through unchanged.
    """
    # `state_dict` defaults to None, so only filter when there is a mapping to filter.
    if self.is_deepspeed_enabled and state_dict is not None:
        policy_prefix = 'policy.'
        state_dict = {
            name.removeprefix(policy_prefix): param
            for name, param in state_dict.items()
            if name.startswith(policy_prefix)
        }
    super()._save(output_dir, state_dict)