forked from modelscope/ms-swift
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdpo_trainer.py
More file actions
95 lines (84 loc) · 4.23 KB
/
dpo_trainer.py
File metadata and controls
95 lines (84 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
from collections import namedtuple
from functools import partial
from megatron.core import mpu
from torch.distributed.nn import all_reduce
from swift.rlhf_trainers import DPOTrainer
from swift.utils import get_current_device, get_logger
from .rlhf_mixin import MegatronRLHFTrainer
# Module-level logger shared by the trainers defined in this file.
logger = get_logger()
class DummyDPOTrainer(DPOTrainer):
    """Lightweight stand-in for swift's ``DPOTrainer``.

    It skips the full trainer initialization and populates only the
    attributes that the inherited ``dpo_loss`` method reads, so that
    ``dpo_loss`` can be reused from Megatron training code.
    """

    def __init__(self, args):
        # ``dpo_loss`` looks up ``self.accelerator.device``; a one-field
        # namedtuple is sufficient to satisfy that access.
        self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device())
        # Fixed alpha-divergence coefficient, exposed both as an attribute
        # and inside the params dict consumed by ``dpo_loss``.
        self.f_alpha_divergence_coef = 1.
        self.f_divergence_params = {'alpha_divergence_coef': self.f_alpha_divergence_coef}
        # Remaining knobs are forwarded straight from the training args.
        self.beta = args.beta
        self.loss_type = args.loss_type
        self.f_divergence_type = args.f_divergence_type
        self.label_smoothing = args.label_smoothing
        self.reference_free = args.reference_free
class MegatronDPOTrainer(MegatronRLHFTrainer):
    """DPO trainer for Megatron models.

    Reference-model and policy-model activations travel through the
    pipeline stacked along dim 0 of a single tensor: the first half
    belongs to the (frozen) reference model, the second half to the
    policy (see ``forward_step`` / ``loss_func``).
    """

    def __init__(self, args, template):
        super().__init__(args, template)
        # Borrow swift's DPOTrainer.dpo_loss without full trainer setup.
        self.dummy_dpo_trainer = DummyDPOTrainer(args)

    def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params):
        """Compute the DPO loss (optionally with an RPO NLL term) and metrics.

        Args:
            output_tensor: stacked [reference; policy] per-token outputs
                along dim 0; the halves are split apart below.
            labels: target token ids; ``-100`` marks positions excluded
                from the loss.
            packed_seq_params: packing metadata when ``padding_free`` is
                used (``None`` otherwise); provides ``num_samples`` and
                ``cu_seqlens_q``.

        Returns:
            Tuple of (scalar loss, metric dict).
        """
        # First half is the reference model's output; detach so no
        # gradients flow into it.
        ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach()
        output_tensor = output_tensor[output_tensor.shape[0] // 2:]
        args = self.args
        # Each preference pair contributes a chosen and a rejected
        # sequence, so pairs = sequences // 2 in the unpacked case.
        num_samples = labels.shape[0] // 2 if packed_seq_params is None else packed_seq_params.num_samples
        logps = self.get_logps(output_tensor, labels, packed_seq_params, num_samples * 2)
        ref_logps = self.get_logps(ref_output_tensor, labels, packed_seq_params, num_samples * 2)
        # Layout of logps: first num_samples entries are chosen, the
        # rest rejected (same for ref_logps).
        loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss(
            logps[:num_samples],
            logps[num_samples:],
            ref_logps[:num_samples],
            ref_logps[num_samples:],
        )
        if args.rpo_alpha:
            # RPO adds an NLL (SFT) term computed on chosen responses only.
            loss_mask = labels != -100
            if args.padding_free:
                # Packed layout: chosen sequences occupy the first
                # num_samples segments. cu_seqlens are divided by the CP
                # size because each context-parallel rank holds an equal
                # slice of every sequence — TODO confirm against get_batch.
                num_tokens = packed_seq_params.cu_seqlens_q[num_samples] // args.context_parallel_size
                loss_mask[:, num_tokens:] = 0
            else:
                loss_mask[num_samples:] = 0
            # Pack [masked token-loss sum, token count] into one tensor so
            # a single all_reduce covers both values.
            nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]])
            if args.context_parallel_size > 1:
                nll_loss = all_reduce(nll_loss, group=mpu.get_context_parallel_group())
            nll_loss = nll_loss[0] / nll_loss[1]
            loss = loss + args.rpo_alpha * nll_loss
        loss = loss.mean()
        metric = {
            'loss': loss.detach().clone(),
            'logps/chosen': logps[:num_samples].mean(),
            'logps/rejected': logps[num_samples:].mean(),
            'rewards/chosen': chosen_rewards.mean(),
            'rewards/rejected': rejected_rewards.mean(),
            'rewards/accuracies': (chosen_rewards > rejected_rewards).float().mean(),
            'rewards/margins': (chosen_rewards - rejected_rewards).mean(),
        }
        if args.rpo_alpha:
            metric['nll_loss'] = nll_loss.detach()
        metric = self._all_reduce_metric(metric)
        # fix megatron-lm bug
        # NOTE(review): presumably compensating for Megatron scaling the
        # loss by the context-parallel world size elsewhere — confirm
        # against the pinned megatron-core version.
        loss = loss / mpu.get_context_parallel_world_size()
        return loss, metric

    def forward_step(self, data_iterator, model):
        """Run the reference and policy forward passes for one batch.

        Returns the [reference; policy] outputs stacked along dim 0 plus a
        ``loss_func`` partial bound to this batch's labels and packing info.
        """
        # Get the batch.
        unwrapped_model = model.module.module
        input_tensor = unwrapped_model.get_input_tensor()
        vp_stage = unwrapped_model.vp_stage
        data = self.get_batch(data_iterator, vp_stage)
        # loss_scale is not consumed by the DPO loss.
        data.pop('loss_scale', None)
        # ref_model
        with torch.no_grad(), self.null_ref_context() as ref_models:
            ref_model = ref_models[vp_stage or 0]
            if input_tensor is not None:
                # Non-first pipeline stages receive the stacked
                # [ref; policy] activations; feed each model its own half.
                ref_model.set_input_tensor(input_tensor[:input_tensor.shape[0] // 2].detach())
            ref_output_tensor = ref_model(**data)
        if input_tensor is not None:
            unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:])
        output_tensor = model(**data)
        # Re-stack so downstream stages / loss_func can split the halves.
        return torch.concat([ref_output_tensor, output_tensor], dim=0), partial(
            self.loss_func, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params'))