Matching the response tokens #197

@vwxyzjn

Hello, thanks for the nice reference code! I noticed that the following code tries to match the response tokens, but it may match the instruction tokens instead:

response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
    response_token_ids_start_idx = idx
    break

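This happens because the loop breaks as soon as the first token matches. '### Response:\n' is encoded as [21017, 18261, 25, 198] and '### Instruction:\n' is encoded as [21017, 46486, 25, 198]. Both start with 21017, so the loop matches the '### Instruction:\n' header instead of the response header.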
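
A minimal sketch that reproduces the mismatch, using the two token ID lists quoted above. The `labels` array and the filler IDs (101, 102, 103, 201, 202) are made-up placeholders for illustration, not real tokenizer output.

```python
import numpy as np

# Token IDs as quoted in this issue.
instruction_ids = [21017, 46486, 25, 198]     # '### Instruction:\n'
response_token_ids = [21017, 18261, 25, 198]  # '### Response:\n'

# Hypothetical sequence: instruction header, prompt, response header, answer.
labels = np.array(instruction_ids + [101, 102, 103] + response_token_ids + [201, 202])

# Matching on the first token alone finds both headers.
first_token_hits = np.where(labels == response_token_ids[0])[0]

# Breaking on the first hit selects the instruction header, not the response.
first_hit = int(first_token_hits[0])
print(first_token_hits.tolist(), first_hit)
```

Here the first hit is the start of the instruction header, while the response header begins at the second hit.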

To resolve this, assuming you did intend to match the response tokens, consider the following snippet instead :)

            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
                if response_token_ids == examples[i]["input_ids"][idx:idx+len(response_token_ids)]:
                    response_token_ids_start_idx = idx  
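
The same check can be sketched as a standalone helper, which may be easier to test in isolation. The function name `find_response_start` is hypothetical, and it assumes `input_ids` is a plain Python list of ints. Unlike the snippet above, it returns at the first full match rather than keeping the last one.

```python
import numpy as np

def find_response_start(input_ids, response_token_ids):
    """Return the index where response_token_ids begins in input_ids, or None."""
    input_ids = list(input_ids)
    response_token_ids = list(response_token_ids)
    n = len(response_token_ids)
    # Candidate positions: every place the first response token occurs.
    for idx in np.where(np.asarray(input_ids) == response_token_ids[0])[0]:
        idx = int(idx)
        # Accept a candidate only if the whole header matches, not just one token.
        if input_ids[idx:idx + n] == response_token_ids:
            return idx
    return None

# Hypothetical example: filler IDs 101, 102, 201 are placeholders.
example_ids = [21017, 46486, 25, 198, 101, 102, 21017, 18261, 25, 198, 201]
start = find_response_start(example_ids, [21017, 18261, 25, 198])
print(start)
```

Returning `None` when nothing matches leaves it to the caller to decide how to handle examples that lack the response header.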

Our related issue: huggingface/trl#445 (comment)
