
Commit 825438d

feat(gpt): sft (#65)
1 parent ee3e99f commit 825438d

9 files changed: +553 additions, −19 deletions

justfile

Lines changed: 3 additions & 0 deletions

@@ -81,5 +81,8 @@ gpt-tokenize:
 gpt-train *args:
     uv run python toynlp/gpt/train.py {{args}}
 
+gpt-sft *args:
+    uv run python toynlp/gpt/sft.py {{args}}
+
 gpt-eval:
     uv run python toynlp/gpt/evaluation.py
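With `{{args}}` forwarding in place, the new recipe can presumably be invoked as `just gpt-sft` (optionally with extra flags), running `uv run python toynlp/gpt/sft.py` the same way the existing `gpt-train` recipe launches training.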

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ dependencies = [
     "huggingface-hub>=0.28.1",
     "numpy>=2.2.2",
     "pyyaml>=6.0",
+    "rich>=14.1.0",
     "safetensors>=0.5.3",
     "tokenizers>=0.21.0",
     "torch>=2.5.1",

toynlp/gpt/README.md

Lines changed: 43 additions & 0 deletions

@@ -59,6 +59,41 @@ Epoch 758/1000 - Train Loss: 0.0003, Train Perplexity: 1.0003, LR: 0.000100,
 ====================================================================================================
 ```
 
+## Supervised Fine-Tuning (LoRA)
+
+### The results
+
+```
+╭─ Conversation ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│ Human: Why is the sky blue?                                                                                                                    │
+│                                                                                                                                                │
+│ Assistant: The sky blue is because it is a reflection of the sun's reflection. The sky is a reflection of the sun's reflection. The sun's      │
+│ reflection is the reflection of the sun's reflection. The sky is a reflection of the sun's reflection.                                         │
+╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Conversation ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│ Human: What is the capital of France?                                                                                                          │
+│                                                                                                                                                │
+│ Assistant: The capital of France is Paris.                                                                                                     │
+╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Conversation ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│ Human: What are the three primary colors?                                                                                                      │
+│                                                                                                                                                │
+│ Assistant: The three primary colors are red, blue, and yellow.                                                                                 │
+╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Conversation ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│ Human: Tell me a joke about computers.                                                                                                         │
+│                                                                                                                                                │
+│ Assistant: Why don't computers use the same jokes on the world? Because they're too big!                                                       │
+╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+╭─ Conversation ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+│ Human: Write a three lines poem.                                                                                                               │
+│                                                                                                                                                │
+│ Assistant: In the stillness of the night,                                                                                                      │
+│ A peace that never fades,                                                                                                                      │
+│ A peace that never fades.                                                                                                                      │
+╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+```
+
 
 ## The mistakes that I made
 
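The `sft.py` introduced by this commit is not shown in the excerpt above, so as a rough sketch of the LoRA technique the heading names — freeze the pretrained weight, train only a low-rank update — an adapter could look like the following (class name, rank, and scaling are illustrative assumptions, not the repo's actual code):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical sketch: y = W x + (alpha / r) * B A x, with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pretrained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        # A is small random, B is zero, so training starts from the pretrained model.
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)

# Wrap one projection as a demo; in practice each attention projection
# of each transformer block would be wrapped the same way.
layer = LoRALinear(nn.Linear(768, 768))
print(layer(torch.randn(2, 10, 768)).shape)  # torch.Size([2, 10, 768])
```

Only `lora_a` and `lora_b` receive gradients, so the optimizer updates a small fraction of the parameters while the pretrained weights stay intact.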
@@ -132,3 +167,11 @@ Now the mask starts at `(batch, 1, 1, seq_len)` and broadcasting preserves the c
 Rows 2 and 3 still attend to the earlier valid tokens, so the logits stay finite and the model trains normally.
 
 **Lessons learned.** Masks are just tensors, so broadcast semantics matter. Printing the exact shapes before and after each operation (or writing a quick unit test) is a cheap way to catch mistakes that otherwise only show up hours into training.
+
+
+### We didn't add a special token for the end of sentence
+
+This makes the supervised fine-tuning task harder, because the model has to predict the end of a sentence on its own.
+
+To keep the SFT going, we temporarily used `___` as the end-of-sentence token.
+We have since added an `<eos>` token to the GPT tokenizer and model and retrained it.
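For reference, with the `tokenizers` library the project already depends on, registering such a token could look roughly like this (the file path is a placeholder, not the repo's actual layout):

```python
from tokenizers import Tokenizer

# Placeholder path: the repo's real tokenizer file may live elsewhere.
tokenizer = Tokenizer.from_file("gpt_tokenizer.json")

# Register <eos> as a special token so the tokenizer treats it atomically.
tokenizer.add_special_tokens(["<eos>"])
print("eos id:", tokenizer.token_to_id("<eos>"))
print("vocab size:", tokenizer.get_vocab_size())
```

Growing the vocabulary this way also adds a row to the model's embedding (and output projection) matrices, which is presumably why retraining was needed rather than just re-tokenizing.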

toynlp/gpt/config.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ class GPTConfig:
     # dataset configs
     dataset_path: str = "lucadiliello/bookcorpusopen"
     dataset_name: str | None = None
-    batch_size: int = 24  # paper setting: 64
+    batch_size: int = 8  # paper setting: 64
     num_workers: int = 8
     shuffle: bool = True
     # tokenizer configs

toynlp/gpt/model.py

Lines changed: 4 additions & 0 deletions

@@ -193,3 +193,7 @@ def _get_mask(self, input_token_ids: torch.Tensor) -> torch.Tensor:
     total_params = sum(p.numel() for p in model.parameters())
     print(f"Total model parameters: {total_params}")
     print(model)
+
+    # named modules
+    for name, module in model.named_modules():
+        print(name, "->", module)
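One plausible use for this `named_modules()` dump (an assumption on my part; the commit does not say) is picking LoRA target layers. Filtering the same iterator down to `nn.Linear` modules makes the candidates explicit:

```python
import torch.nn as nn

# Sketch: `model` is the GPT instance constructed above. Print only the
# Linear submodules, the usual attachment points for LoRA adapters.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(name, "->", module)
```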
