Description
Hello, thanks for the nice reference code! I noticed that the following code tries to match the response tokens, but it may match the instruction tokens instead.
Lines 60 to 63 in aaa0ecb
```python
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
    response_token_ids_start_idx = idx
    break
```
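This happens because the loop breaks as soon as a single token equals `response_token_ids[0]`. `'### Response:\n'` is encoded as `[21017, 18261, 25, 198]`, but `'### Instruction:\n'` is encoded as `[21017, 46486, 25, 198]`. Both start with the same token (21017), so the loop stops at the earlier `### Instruction:\n` header instead of the response header.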
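To make the failure concrete, here is a minimal, self-contained reproduction of the loop from lines 60 to 63. It uses the two header encodings quoted above; the labels row itself is hypothetical, and the IDs 1000, 1001, 2000 and 2001 are made-up placeholders for instruction and response text.

```python
import numpy as np

# Header encodings quoted in this issue; both start with token 21017.
INSTRUCTION_IDS = [21017, 46486, 25, 198]  # '### Instruction:\n'
RESPONSE_IDS = [21017, 18261, 25, 198]     # '### Response:\n'

# Hypothetical labels row: instruction header, two placeholder instruction tokens,
# response header, two placeholder response tokens.
labels = np.array(INSTRUCTION_IDS + [1000, 1001] + RESPONSE_IDS + [2000, 2001])

# Same logic as the quoted loop: stop at the first position equal to the first header token.
candidates = np.where(labels == RESPONSE_IDS[0])[0].tolist()
response_token_ids_start_idx = None
for idx in np.where(labels == RESPONSE_IDS[0])[0]:
    response_token_ids_start_idx = int(idx)
    break

print(candidates)                    # [0, 6]
print(response_token_ids_start_idx)  # 0, i.e. the '### Instruction:\n' header, not 6
```

Both positions 0 and 6 are candidates, and the early `break` always keeps the first one.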
If the intent is indeed to match the response tokens, consider the following snippet instead :)
```python
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
    # `response_token_ids` is `'### Response:\n'`; check that the whole token sequence
    # matches, not just its first token (assumes `input_ids` is a plain list of ints)
    if response_token_ids == examples[i]["input_ids"][idx : idx + len(response_token_ids)]:
        response_token_ids_start_idx = idx
        break
```
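As an alternative to the Python loop, the same "first full match" search can be vectorized. The sketch below is a hypothetical helper (`find_subsequence_start` is not part of the original code) and assumes the labels row is, or has been converted to, a 1-D CPU array; `sliding_window_view` requires NumPy 1.20 or newer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def find_subsequence_start(labels, pattern):
    """Return the index of the first full occurrence of `pattern` in `labels`, or None.

    Hypothetical helper for illustration; not from the original repository.
    """
    labels = np.asarray(labels)
    pattern = np.asarray(pattern)
    # sliding_window_view raises if the window is longer than the input, so guard first.
    if len(pattern) == 0 or len(pattern) > len(labels):
        return None
    # Row k of `windows` is labels[k : k + len(pattern)].
    windows = sliding_window_view(labels, len(pattern))
    # A row is a hit only if every token equals the corresponding pattern token.
    hits = np.flatnonzero((windows == pattern).all(axis=1))
    return int(hits[0]) if hits.size else None


# Hypothetical labels row built from the header encodings quoted in this issue.
labels = [21017, 46486, 25, 198, 1000, 1001, 21017, 18261, 25, 198, 2000, 2001]
print(find_subsequence_start(labels, [21017, 18261, 25, 198]))  # 6
print(find_subsequence_start(labels, [21017, 46486, 25, 198]))  # 0
print(find_subsequence_start(labels, [9999]))                   # None
```

Compared with checking only `response_token_ids[0]`, comparing whole windows means a shared leading token such as 21017 can no longer produce a false match.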
Our related issue: huggingface/trl#445 (comment)