reward_model.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from torch import nn


## Note that the following code is modified from
## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
class RewardModel(nn.Module):

    def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
        super().__init__()
        self.config = base_model.config
        self.num_padding_at_beginning = num_padding_at_beginning
        if hasattr(self.config, "word_embed_proj_dim"):
            # `OPT` models use word_embed_proj_dim as final output
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L497
            self.v_head = nn.Linear(self.config.word_embed_proj_dim,
                                    1,
                                    bias=False)
        else:
            # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd`
            self.config.n_embd = self.config.hidden_size if hasattr(
                self.config, "hidden_size") else self.config.n_embd
            self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.rwtranrsformer = base_model
        self.PAD_ID = tokenizer.pad_token_id

    def gradient_checkpointing_enable(self):
        self.rwtranrsformer.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.rwtranrsformer.gradient_checkpointing_disable()
    def forward(self,
                input_ids=None,
                past_key_values=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                use_cache=False):
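        """Pairwise ranking loss over (chosen, rejected) response pairs.

        The first half of the batch is expected to hold the chosen
        sequences and the second half the rejected ones. For each pair the
        loss is softplus(rejected - chosen), i.e. -log sigmoid(chosen -
        rejected), averaged over the tokens from the first position where
        the two sequences diverge up to the end of the longer of the two
        (ignoring trailing padding). The reward at the final token of each
        sequence is also returned as its scalar score.
        """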
        loss = None

        transformer_outputs = self.rwtranrsformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache)

        hidden_states = transformer_outputs[0]
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_mean_scores = []
        rejected_mean_scores = []

        # Split the inputs and rewards into two parts, chosen and rejected
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        seq_len = input_ids.shape[1]

        chosen_ids = input_ids[:bs]  # bs x seq x 1
        rejected_ids = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Compute pairwise loss. Only backprop on the different tokens before padding
        loss = 0
        for i in range(bs):
            chosen_id = chosen_ids[i]
            rejected_id = rejected_ids[i]
            chosen_reward = chosen_rewards[i]
            rejected_reward = rejected_rewards[i]

            c_inds = (chosen_id == self.PAD_ID).nonzero()
            c_ind = c_inds[self.num_padding_at_beginning].item() if len(
                c_inds
            ) > self.num_padding_at_beginning else seq_len  # OPT model pads the first token, so we need to use the second padding token as the end of the sequence
            check_divergence = (chosen_id != rejected_id).nonzero()

            if len(check_divergence) == 0:
                end_ind = rejected_reward.size(-1)
                divergence_ind = end_ind - 1
                r_ind = c_ind
            else:
                # Check if there is any padding otherwise take length of sequence
                r_inds = (rejected_id == self.PAD_ID).nonzero()
                r_ind = r_inds[self.num_padding_at_beginning].item(
                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
                end_ind = max(c_ind, r_ind)
                divergence_ind = check_divergence[0]
            assert divergence_ind > 0
            c_truncated_reward = chosen_reward[divergence_ind:end_ind]
            r_truncated_reward = rejected_reward[divergence_ind:end_ind]
            chosen_mean_scores.append(
                chosen_reward[c_ind - 1])  # use the end score for reference
            rejected_mean_scores.append(rejected_reward[r_ind - 1])

            loss += nn.functional.softplus(
                r_truncated_reward - c_truncated_reward).mean()

        loss = loss / bs
        chosen_mean_scores = torch.stack(chosen_mean_scores)
        rejected_mean_scores = torch.stack(rejected_mean_scores)
        return {
            "loss": loss,
            "chosen_mean_scores": chosen_mean_scores,
            "rejected_mean_scores": rejected_mean_scores,
        }
    def forward_value(self,
                      input_ids=None,
                      attention_mask=None,
                      past_key_values=None,
                      position_ids=None,
                      head_mask=None,
                      inputs_embeds=None,
                      return_value_only=False,
                      prompt_length=0,
                      use_cache=False):
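        """Per-token value estimates for a batch of sequences.

        When `return_value_only` is True, only the raw `values` tensor of
        shape (batch, seq_len) is returned. Otherwise the method also
        collects `chosen_end_scores`: for each sequence, the value at the
        last answer token, i.e. the token right before the first padding
        token that appears after `prompt_length` (or the final token when
        there is no such padding).
        """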
        transformer_outputs = self.rwtranrsformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache)
        hidden_states = transformer_outputs[0]
        values = self.v_head(hidden_states).squeeze(-1)
        if return_value_only:
            return values
        else:
            # [0 0 0 0 prompt, answer, 0 0 0 0 ] for step 3, we have padding at the beginning
            # [prompt, answer, 0, 0, 0, 0] this is normal
            assert prompt_length > 1, "prompt_length must be greater than 1 to help select the end score"
            bs = values.size(0)
            seq_len = input_ids.shape[1]
            chosen_end_scores = [
            ]  # we use this name for consistency with the original forward function
            for i in range(bs):
                input_id = input_ids[i]
                value = values[i]

                c_inds = (input_id[prompt_length:] == self.PAD_ID).nonzero()
                # here we only use the answer part of the sequence so we do not need to care about the padding at the beginning
                c_ind = c_inds[0].item() + prompt_length if len(
                    c_inds) > 0 else seq_len
                chosen_end_scores.append(value[c_ind - 1])
            return {
                "values": values,
                "chosen_end_scores": torch.stack(chosen_end_scores),
            }
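

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). This is a
# minimal example assuming a Hugging Face backbone; the model name
# `facebook/opt-350m`, the prompts, and the hyperparameters below are
# illustrative choices, not values taken from this repository. Following
# the comment in `forward`, OPT-style models carry one extra token at the
# start of each sequence, hence `num_padding_at_beginning=1`. The batch
# layout matches what `forward` expects: chosen responses in the first
# half of the batch, rejected responses in the second half.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "facebook/opt-350m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    reward_model = RewardModel(base_model,
                               tokenizer,
                               num_padding_at_beginning=1)

    chosen = ["Human: Hi!\n\nAssistant: Hello, how can I help you today?"]
    rejected = ["Human: Hi!\n\nAssistant: Go away."]
    batch = tokenizer(chosen + rejected,
                      padding="max_length",
                      max_length=128,
                      truncation=True,
                      return_tensors="pt")

    # Pairwise training loss plus the end-of-sequence score of each half.
    out = reward_model(input_ids=batch["input_ids"],
                       attention_mask=batch["attention_mask"])
    print(out["loss"], out["chosen_mean_scores"], out["rejected_mean_scores"])

    # Inference-style scoring of the chosen sequence alone; `prompt_length`
    # is an arbitrary illustrative value marking where the answer starts.
    scores = reward_model.forward_value(
        input_ids=batch["input_ids"][:1],
        attention_mask=batch["attention_mask"][:1],
        prompt_length=8)
    print(scores["chosen_end_scores"])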