-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathtest_tokenizer_special_tokens.py
More file actions
52 lines (42 loc) · 1.77 KB
/
test_tokenizer_special_tokens.py
File metadata and controls
52 lines (42 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
from transformers import AutoTokenizer
import transformer_lens.loading_from_pretrained as loading
from transformer_lens import HookedTransformer, HookedTransformerConfig
# Small models that can be fetched without credentials; always tested.
PUBLIC_MODEL_TESTING_LIST = ["gpt2-small", "opt-125m", "pythia-70m"]

# Broader sweep over the distinct small model families, used only when an
# HF_TOKEN environment variable is present.
FULL_MODEL_TESTING_LIST = [
    "solu-1l",
    "gpt2-small",
    "gpt-neo-125M",
    "opt-125m",
    "opt-30b",
    "stanford-gpt2-small-a",
    "pythia-70m",
]

# A non-empty HF_TOKEN selects the full sweep; an unset or empty value falls
# back to the public subset.
if os.environ.get("HF_TOKEN"):
    MODEL_TESTING_LIST = FULL_MODEL_TESTING_LIST
else:
    MODEL_TESTING_LIST = PUBLIC_MODEL_TESTING_LIST
def test_d_vocab_from_tokenizer():
    """Check vocab sizing and BOS handling for each tokenizer in MODEL_TESTING_LIST.

    For every model name, a tiny randomly initialised ``HookedTransformer`` is
    built around that model's tokenizer, with no ``d_vocab`` given in the config.
    The test then checks that:

    * every token id produced for the test string fits inside the model's
      ``cfg.d_vocab`` (the vocab size is not passed in, so it has to come from
      the tokenizer);
    * ``to_tokens`` adds exactly one token, the BOS id, when BOS prepending is on;
    * the first token is not the BOS id when ``prepend_bos=False``.
    """
    test_string = "a fish."
    for model_name in MODEL_TESTING_LIST:
        # "solu-1l" is special-cased to a dedicated digit-splitting GPT-NeoX
        # tokenizer; every other model uses its official HF name.
        if model_name == "solu-1l":
            tokenizer_name = "NeelNanda/gpt-neox-tokenizer-digits"
        else:
            tokenizer_name = loading.get_official_model_name(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Build a fresh config for every model. The config object is handed to
        # the model, and the model may fill in unset fields (such as the vocab
        # size) on that same object. Sharing one config across iterations could
        # therefore let a value inferred for one tokenizer leak into the next.
        cfg = HookedTransformerConfig(
            n_layers=1, d_mlp=10, d_model=10, d_head=5, n_heads=2, n_ctx=20, act_fn="relu"
        )
        model = HookedTransformer(cfg=cfg, tokenizer=tokenizer)

        # to_tokens on a single string yields a batch of one row, so [0, 0]
        # picks the first token of that row.
        tokens_with_bos = model.to_tokens(test_string)
        tokens_without_bos = model.to_tokens(test_string, prepend_bos=False)
        bos_token_id = model.tokenizer.bos_token_id

        # The embedding table must cover every id the tokenizer can emit.
        assert tokens_with_bos.max().item() < model.cfg.d_vocab, (
            f"[{model_name}] token id exceeds d_vocab={model.cfg.d_vocab} "
            f"inferred from tokenizer {tokenizer_name!r}"
        )
        # Enabling BOS must add exactly one token, and that token must be BOS.
        assert tokens_with_bos.shape[-1] == tokens_without_bos.shape[-1] + 1, (
            f"[{model_name}] BOS Token not added when expected"
        )
        assert tokens_with_bos[0, 0].item() == bos_token_id, (
            f"[{model_name}] first token is not the BOS token when prepend_bos=True"
        )
        # Disabling BOS must leave no BOS at the start of the sequence.
        assert tokens_without_bos[0, 0].item() != bos_token_id, (
            f"[{model_name}] BOS token is present when it shouldn't be"
        )