
Commit ee3e99f

fix(tokenizer): add <eos> in tokenizer and sequences (#63)
* fix(tokenizer): add <eos> in tokenizer and sequences
* update training result
1 parent aa0526e commit ee3e99f

File tree

4 files changed (+54, -10 lines):

- tests/test_gpt_dataset.py
- toynlp/gpt/README.md
- toynlp/gpt/config.py
- toynlp/gpt/dataset.py


tests/test_gpt_dataset.py

Lines changed: 38 additions & 0 deletions
```python
from types import SimpleNamespace

import torch

from toynlp.gpt.dataset import split_text_into_contexts


class DummyTokenizer:
    def __init__(self) -> None:
        self._vocab: dict[str, int] = {"<pad>": 0, "<eos>": 1}

    def encode(self, text: str) -> SimpleNamespace:
        ids = [self._vocab.setdefault(char, len(self._vocab)) for char in text]
        return SimpleNamespace(ids=ids)

    def token_to_id(self, token: str) -> int | None:
        return self._vocab.get(token)


def test_split_text_includes_eos_and_pads_last_chunk() -> None:
    tokenizer = DummyTokenizer()
    contexts = split_text_into_contexts(["abcd"], max_length=3, tokenizer=tokenizer)

    assert len(contexts) == 2
    expected_first = torch.tensor([2, 3, 4], dtype=torch.long)
    expected_second = torch.tensor([5, 1, 0], dtype=torch.long)
    assert torch.equal(contexts[0], expected_first)
    assert torch.equal(contexts[1], expected_second)


def test_split_text_inserts_single_eos_per_document() -> None:
    tokenizer = DummyTokenizer()
    texts = ["alpha", "<eos>should_be_literal"]
    contexts = split_text_into_contexts(texts, max_length=4, tokenizer=tokenizer)

    eos_id = tokenizer.token_to_id("<eos>")
    stacked = torch.stack(contexts)
    eos_count = int((stacked == eos_id).sum().item())
    assert eos_count == len(texts)
```
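
The expected tensors in the first test follow directly from DummyTokenizer's character-to-id mapping: `a`-`d` get ids 2-5, `<eos>` is 1 and `<pad>` is 0. The standalone sketch below (a reimplementation of the chunking rule the test encodes, not an import of the toynlp code) walks through that arithmetic:

```python
# Standalone sketch: reproduce the chunking rule the first test encodes.
# <pad> = 0 and <eos> = 1 (as in DummyTokenizer); characters get fresh ids.
import torch

vocab = {"<pad>": 0, "<eos>": 1}
ids = [vocab.setdefault(ch, len(vocab)) for ch in "abcd"]  # a,b,c,d -> 2,3,4,5
ids.append(vocab["<eos>"])                                 # -> [2, 3, 4, 5, 1]

max_length = 3
chunks = []
for start in range(0, len(ids), max_length):
    chunk = ids[start : start + max_length]
    chunk.extend([vocab["<pad>"]] * (max_length - len(chunk)))
    chunks.append(torch.tensor(chunk, dtype=torch.long))

print(chunks[0])  # tensor([2, 3, 4])
print(chunks[1])  # tensor([5, 1, 0])  <- <eos> then one <pad>
```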

toynlp/gpt/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ Performance comparison:
 | Metric | Original GPT | This Implementation |
 |:--------:|:---------------:|:-------------------:|
 | Perplexity| 18.4 | 24.3|
-| SST2 Accuracy | 91.3% | **92.69%** |
+| SST2 Accuracy | 91.3% | **92.04%** |


 ### The dataset is around 800M words(1B tokens)
```

toynlp/gpt/config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@ class GPTConfig:
     # model configs
     vocab_size: int = 40478  # paper: (BPE) vocabulary with 40,478 merges
     special_tokens: list[str] = field(
-        default_factory=lambda: ["<unk>", "<pad>"],
+        default_factory=lambda: ["<unk>", "<pad>", "<eos>"],
     )
     # model arch configs
     max_seq_length: int = 512  # paper setting: 128, 512
```
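
This commit does not include the tokenizer-training script, so how `special_tokens` is consumed is not shown here. Assuming the project uses the Hugging Face `tokenizers` library (which the `encode(...).ids` and `token_to_id` calls in dataset.py suggest), registering the new `<eos>` token would look roughly like the sketch below; the corpus, vocab size, and pre-tokenizer are illustrative placeholders, not the project's actual configuration:

```python
# Hypothetical sketch of registering the special tokens with the Hugging Face
# `tokenizers` library; corpus and trainer settings are assumptions.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

special_tokens = ["<unk>", "<pad>", "<eos>"]  # mirrors GPTConfig.special_tokens

tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=40478, special_tokens=special_tokens)
tokenizer.train_from_iterator(["a tiny toy corpus"], trainer=trainer)

# Once <eos> is in the vocabulary, token_to_id("<eos>") returns a real id
# instead of None, which the updated split_text_into_contexts relies on.
print(tokenizer.token_to_id("<eos>"))
```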

toynlp/gpt/dataset.py

Lines changed: 14 additions & 8 deletions
```diff
@@ -9,16 +9,22 @@

 def split_text_into_contexts(texts: list[str], max_length: int, tokenizer: Tokenizer) -> list[torch.Tensor]:
     contexts = []
-    # print(f"len texts: {len(texts)}")
+    eos_id = tokenizer.token_to_id("<eos>")
+    pad_id = tokenizer.token_to_id("<pad>")
+    if eos_id is None or pad_id is None:
+        msg = "Missing required special tokens <eos> or <pad> in tokenizer vocabulary"
+        raise ValueError(msg)
+
     for text in texts:
-        # print(f"Processing text of length {len(text)}")
         token_ids = tokenizer.encode(text).ids
-        for i in range(len(token_ids) // max_length + 1):
-            start_idx = i * max_length
-            end_idx = (i + 1) * max_length
-            # print(f"i: {i}, start_idx: {start_idx}, end_idx: {end_idx}, len(token_ids): {len(token_ids)}")
-            if end_idx < len(token_ids):
-                contexts.append(torch.tensor(token_ids[start_idx:end_idx], dtype=torch.long))
+        token_ids.append(eos_id)
+
+        for start_idx in range(0, len(token_ids), max_length):
+            chunk = token_ids[start_idx : start_idx + max_length]
+            if len(chunk) < max_length:
+                chunk.extend([pad_id] * (max_length - len(chunk)))
+            contexts.append(torch.tensor(chunk, dtype=torch.long))
+
     return contexts

```
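
Besides appending `<eos>`, the rewrite fixes a token-loss bug: the old `if end_idx < len(token_ids)` guard always discarded a document's final window, so trailing tokens (or, when the length was an exact multiple of `max_length`, a whole full chunk) never reached training. The standalone sketch below reconstructs both versions from the diff and runs them on a made-up 7-token document; the ids are illustrative only:

```python
# Side-by-side reconstruction (from the diff above) of the old and new chunking,
# run on a toy 7-token document with max_length=3. Standalone sketch, not toynlp code.
import torch

token_ids = [10, 11, 12, 13, 14, 15, 16]   # pretend BPE ids for one document
max_length, eos_id, pad_id = 3, 1, 0

# Old logic: the final window always failed `end_idx < len(token_ids)` and was dropped.
old = []
for i in range(len(token_ids) // max_length + 1):
    start_idx, end_idx = i * max_length, (i + 1) * max_length
    if end_idx < len(token_ids):
        old.append(torch.tensor(token_ids[start_idx:end_idx], dtype=torch.long))

# New logic: append <eos>, keep every window, pad the last one.
ids = [*token_ids, eos_id]
new = []
for start_idx in range(0, len(ids), max_length):
    chunk = ids[start_idx : start_idx + max_length]
    chunk.extend([pad_id] * (max_length - len(chunk)))
    new.append(torch.tensor(chunk, dtype=torch.long))

print(old)  # [tensor([10, 11, 12]), tensor([13, 14, 15])] -- token 16 lost, no <eos>
print(new)  # [..., tensor([16, 1, 0])] -- tail kept, <eos> and <pad> appended
```

Padding the tail with `<pad>` keeps every chunk the same length, which is what lets the test stack the contexts with `torch.stack`.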
