"""
Train a local BPE (ByteLevel) tokenizer on the corpus and encode/decode for the GPT.
Subword units get you much closer to real words than character tokens, with no
paid API. Requires: `pip install tokenizers` (see requirements.txt).
"""
from __future__ import annotations
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer

DEFAULT_BPE_VOCAB = 2048


def train_bpe_tokenizer(
    text: str,
    vocab_size: int = DEFAULT_BPE_VOCAB,
    min_frequency: int = 2,
) -> Tokenizer:
    """
    ByteLevel BPE: can encode any UTF-8 string; good for English + punctuation.
    """
    unk = "<unk>"
    tokenizer = Tokenizer(BPE(unk_token=unk))
    tokenizer.pre_tokenizer = ByteLevel()
    tokenizer.decoder = ByteLevelDecoder()
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=[unk],
        # Seed the vocab with all 256 byte-level symbols so any UTF-8 input
        # stays encodable (and <unk> never fires), even for bytes that happen
        # not to appear in the training corpus.
        initial_alphabet=ByteLevel.alphabet(),
    )
    # One string is fine for multi-MB text; use an iterator of chunks if OOM.
    tokenizer.train_from_iterator([text], trainer=trainer)
    return tokenizer
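
# A minimal usage sketch (the corpus path is hypothetical, not part of this
# module's API):
#
#   text = Path("data/corpus.txt").read_text(encoding="utf-8")
#   tz = train_bpe_tokenizer(text, vocab_size=4096)
#   print(tz.get_vocab_size())  # <= 4096, including <unk>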


def bpe_tokenizer_path_for_checkpoint(ckpt: Path) -> Path:
    # e.g. runs/model.pt -> runs/model.tokenizer.json
    return ckpt.with_name(ckpt.stem + ".tokenizer.json")


def save_bpe_tokenizer(tokenizer: Tokenizer, ckpt: Path) -> Path:
    p = bpe_tokenizer_path_for_checkpoint(ckpt)
    p.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(p))
    return p


def load_bpe_tokenizer(ckpt: Path) -> Tokenizer:
    p = bpe_tokenizer_path_for_checkpoint(ckpt)
    if not p.is_file():
        raise FileNotFoundError(
            f"BPE tokenizer not found at {p} (expected next to the checkpoint {ckpt.name})"
        )
    return Tokenizer.from_file(str(p))


def encode_to_ids(tz: Tokenizer, text: str) -> list[int]:
    return tz.encode(text).ids


def decode_from_ids(tz: Tokenizer, ids: list[int]) -> str:
    return tz.decode(ids)
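

if __name__ == "__main__":
    # Smoke test, not part of the module API: train on a tiny in-memory sample,
    # round-trip a string, and exercise save/load. The sample text and the
    # checkpoint path below are illustrative only.
    sample = "The quick brown fox jumps over the lazy dog. " * 200
    tz = train_bpe_tokenizer(sample, vocab_size=512, min_frequency=1)
    ids = encode_to_ids(tz, "quick brown fox")
    print("ids:", ids)
    # ByteLevel() defaults to add_prefix_space=True, so the decoded text may
    # carry a leading space relative to the input.
    print("decoded:", repr(decode_from_ids(tz, ids)))

    ckpt = Path("/tmp/demo.pt")  # hypothetical checkpoint path
    save_bpe_tokenizer(tz, ckpt)
    tz2 = load_bpe_tokenizer(ckpt)
    assert decode_from_ids(tz2, ids) == decode_from_ids(tz, ids)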