Skip to content

Commit fcb951c

Browse files
authored
Prevent special tokens from being added to the yes/no token IDs; improve code readability and fix a minor yes/no argument-ordering bug (#645)
1 parent f98ccf4 commit fcb951c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

docs/lm_head_to_classifier/convert_lm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def as_no_id_yes_id(
7676
yes: str = "1",
7777
) -> tuple[int, int]:
7878
tokenizer = AutoTokenizer.from_pretrained(model_name)
79-
no_id_yes_id = [tokenizer(no).input_ids, tokenizer(yes).input_ids]
79+
no_id_yes_id = [tokenizer(no, add_special_tokens=False).input_ids, tokenizer(yes, add_special_tokens=False).input_ids]
8080
assert len(no_id_yes_id[0]) == 1
8181
assert len(no_id_yes_id[1]) == 1
8282
return no_id_yes_id[0][0], no_id_yes_id[1][0]
@@ -87,7 +87,7 @@ def only_yes_id(
8787
) -> tuple[int]:
8888
"""Get the id of the yes token."""
8989
tokenizer = AutoTokenizer.from_pretrained(model_name)
90-
yes_id = tokenizer(yes).input_ids
90+
yes_id = tokenizer(yes, add_special_tokens=False).input_ids
9191
assert len(yes_id) == 1
9292
return (yes_id[0],)
9393

@@ -102,7 +102,7 @@ def upload_and_convert(
102102
if not uses_no_and_yes:
103103
no_id_yes_id = only_yes_id(model_name, yes)
104104
else:
105-
no_id_yes_id = as_no_id_yes_id(model_name, yes, no)
105+
no_id_yes_id = as_no_id_yes_id(model_name, no, yes)
106106
split_name = model_name.split("/")[1]
107107
model_cls = convert_to_sequence_classifier(f"{model_name}", no_id_yes_id)
108108
model_cls = model_cls.to(torch.float16)

0 commit comments

Comments
 (0)