dankit/language-model-pretraining

Language Model Training

A 451M-parameter transformer (416M after tying the embedding weights) trained on 10B tokens of FineWeb-Edu. The model was trained on 8xA100 GPUs, and after one pass over the 10B tokens it is starting to show some signs of learned language. I chose this model size because I have gotten a lot of use out of xlm-roberta, a 560M-parameter encoder-only model, and I wanted to see how a decoder-only model of similar size compares. Setting up each stage taught me a lot: transformer -> data collection/processing -> training loop -> distributed data parallel -> memory optimizations, etc.

The saved checkpoint is available at: https://huggingface.co/dhlak/416m-gpt

The checkpoint is close to Chinchilla-optimal (currently at ~24 tokens per parameter), which leaves plenty of room to keep training a model this small.
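
The ~24x figure follows directly from the numbers quoted above. A minimal sketch of the arithmetic; the 20 tokens-per-parameter reference point is the commonly quoted Chinchilla rule of thumb and is an assumption here, not something measured in this repository:

```python
# Tokens-per-parameter ratio for this run, using the figures quoted above.
TRAINING_TOKENS = 10e9   # one pass over 10B tokens of FineWeb-Edu
PARAMS_TIED = 416e6      # parameter count after weight tying
PARAMS_UNTIED = 451e6    # parameter count before weight tying

# Assumption: ~20 tokens/param is the usual Chinchilla rule of thumb.
CHINCHILLA_TOKENS_PER_PARAM = 20


def tokens_per_param(tokens: float, params: float) -> float:
    """Return how many training tokens were seen per model parameter."""
    return tokens / params


ratio_tied = tokens_per_param(TRAINING_TOKENS, PARAMS_TIED)
ratio_untied = tokens_per_param(TRAINING_TOKENS, PARAMS_UNTIED)

print(f"tied:   {ratio_tied:.1f} tokens/param")
print(f"untied: {ratio_untied:.1f} tokens/param")
print(f"relative to rule of thumb: {ratio_tied / CHINCHILLA_TOKENS_PER_PARAM:.2f}x")
```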

As inference costs dominate, it is common to see companies overtrain smaller models to pack in more knowledge.

| Model | Total Params | Training Tokens | Tokens / Param |
|---|---|---|---|
| LFM2-350M | 350M | 10T | ~28,571 |
| LFM2.5-350M | 350M | 28T | ~80,000 |
| Llama 3 8B | 8B | 15T | ~1,875 |
| Llama 3.1 405B | 405B | 15T | ~37 |
| Kimi K2 / K2.5 | 1T | 15.5T | ~15.5 |
| DeepSeek V3 | 671B | 14.8T | ~22 |
| Qwen3 235B-A22B | 235B | 36T | ~153 |

Setup

One-liner (installs deps, downloads data, trains):

./launch_training.sh

Otherwise, step by step:

1. Prepare Data

pip install -r requirements.txt

Download pre-tokenized dataset (~20GB):

python data_pipeline.py download

Or tokenize from scratch (slower):

python data_pipeline.py prepare

2. Train

Single GPU:

python training.py

Multi-GPU (8x):

torchrun --standalone --nproc_per_node=8 training.py
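
For context, torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every worker process, which lets one script serve both the single-GPU and multi-GPU commands. A minimal sketch of how a script could read them; the helper name `read_ddp_env` is hypothetical, and this is not necessarily how training.py does it:

```python
import os


def read_ddp_env() -> dict:
    """Read the rank variables that torchrun sets for each worker process.

    When the script is started with plain `python`, the variables are
    absent and the defaults describe a single-process run.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return {
        "ddp": world_size > 1,      # more than one process means distributed
        "rank": rank,               # global index of this process
        "local_rank": local_rank,   # index on this node, used to pick the GPU
        "world_size": world_size,   # total number of processes
        "is_master": rank == 0,     # only rank 0 should log and checkpoint
    }


print(read_ddp_env())
```

Restricting logging and checkpoint writes to the `is_master` process is a common convention that avoids eight processes writing the same file.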

Checkpoints save to checkpoints/ every 2000 steps.
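
To pick up the most recent checkpoint for the evaluation or chat steps, the step number can be parsed out of the file name. A minimal sketch, assuming every checkpoint follows the `checkpoint_step_<N>.pt` naming used in the commands in this README; the helper `latest_checkpoint` is hypothetical and not part of the repository:

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: checkpoints are named checkpoint_step_<N>.pt.
CHECKPOINT_PATTERN = re.compile(r"checkpoint_step_(\d+)\.pt$")


def latest_checkpoint(directory: str = "checkpoints") -> Optional[Path]:
    """Return the checkpoint with the highest step number, or None."""
    root = Path(directory)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for path in root.iterdir():
        match = CHECKPOINT_PATTERN.match(path.name)
        # Compare steps as integers so step 10000 beats step 2000.
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


print(latest_checkpoint())
```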

3. Evaluate

python eval_hellaswag.py --checkpoint checkpoints/checkpoint_step_2000.pt

4. Chat

python chat.py checkpoints/checkpoint_step_2000.pt

Commands: /bench, /compare, /quit

You can find my training metrics in "final_metrics.json". Plotted metrics: [figure: Final metrics]

Initially, I had a few optimization issues and bugs in DDP and my training loop, which is why token throughput and MFU were unstable early on. I fixed several of these bugs, and the improvement is visible from step 4000 onward. After step 4000, the main culprit was checkpointing: token throughput often dropped for 1-2 steps after each checkpoint.
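
MFU (model FLOPs utilization) is usually estimated as achieved model FLOPs per second divided by the hardware peak. A minimal sketch, assuming the standard approximation of 6 FLOPs per parameter per token (attention FLOPs ignored), bf16 training, and the A100's 312 TFLOPS dense bf16 peak. The throughput value is a made-up placeholder, not a measurement from this run, and the repository's own MFU calculation may differ:

```python
# Assumption: dense bf16 peak of a single A100, in FLOPs per second.
A100_PEAK_BF16_FLOPS = 312e12


def estimate_mfu(
    tokens_per_second: float,
    n_params: float,
    n_gpus: int,
    peak_flops_per_gpu: float = A100_PEAK_BF16_FLOPS,
) -> float:
    """Estimate MFU as achieved FLOPs/s over aggregate peak FLOPs/s.

    Uses the ~6 * N FLOPs-per-token approximation for a forward plus
    backward pass, which ignores the attention-score FLOPs.
    """
    achieved_flops_per_second = 6.0 * n_params * tokens_per_second
    return achieved_flops_per_second / (n_gpus * peak_flops_per_gpu)


# Hypothetical aggregate throughput across all 8 GPUs, for illustration only.
mfu = estimate_mfu(tokens_per_second=400_000, n_params=416e6, n_gpus=8)
print(f"MFU: {mfu:.1%}")
```

Because the estimate is proportional to token throughput, a short stall such as a checkpoint write shows up as a dip in both metrics at the same step.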

The HellaSwag score increased steadily over time; however, the result deserves some scrutiny, since the evaluation implementation may be flawed or the dataset may have leaked into the training data.
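
One way to sanity-check the evaluation is to separate the scoring rule from the model. HellaSwag gives a context and four candidate endings, and a common scoring choice is to pick the ending with the highest length-normalized log-probability. The sketch below assumes that rule and uses made-up log-probabilities; it is not confirmed to match what eval_hellaswag.py does:

```python
from typing import Sequence, Tuple


def pick_ending(ending_token_logprobs: Sequence[Sequence[float]]) -> int:
    """Return the index of the ending with the highest mean token log-prob.

    Each inner sequence holds the model's log-probabilities for the tokens
    of one candidate ending (assumed non-empty), conditioned on the context.
    """
    scores = [sum(lps) / len(lps) for lps in ending_token_logprobs]
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(examples: Sequence[Tuple[Sequence[Sequence[float]], int]]) -> float:
    """Fraction of examples where the picked ending matches the label."""
    correct = sum(pick_ending(lps) == label for lps, label in examples)
    return correct / len(examples)


# Made-up log-probabilities for two examples with four endings each.
examples = [
    ([[-2.0, -3.0], [-0.5, -1.0, -0.6], [-4.0], [-2.5, -2.5]], 1),
    ([[-1.0, -1.2], [-3.0, -3.5], [-0.9, -2.0], [-5.0]], 2),
]
print(f"accuracy: {accuracy(examples):.2f}")
```

With four endings per example, a model that has learned nothing should score about 25%, which gives a baseline for judging early checkpoints.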

Credits go to Andrej Karpathy for all his educational content.

About

450M parameter dense transformer (before weight tying), inspired by the GPT and Llama papers and trained on 10B tokens of web data
