Note
This repository accompanies the preprint Learning to Skip the Middle Layers of Transformers (https://arxiv.org/abs/2506.21103). For pre-trained models, see HuggingFace.
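To fetch a released checkpoint programmatically, something like the following should work. This is a minimal sketch using huggingface_hub; the repo id shown is a placeholder, so substitute the actual model id listed on the HuggingFace page.
from huggingface_hub import snapshot_download

# Placeholder repo id; replace with the actual model id from the HuggingFace page.
local_dir = snapshot_download(repo_id="<org>/<skip-middle-model>")
print(f"Checkpoint files downloaded to {local_dir}")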
We based the underlying Transformer models on the reference implementation of Llama 3 (https://github.com/meta-llama/llama-models/). The key difference is that we used the Sandwich-LN scheme (a.k.a. Peri-LN) in place of Pre-LN. The training codebase is based on the 'nanoGPT speedrun' repository (https://github.com/KellerJordan/modded-nanogpt).
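As a rough illustration of that change (a minimal sketch, not the repository's actual code; the class and argument names are made up for the example), a Pre-LN block normalizes only the sublayer input, while a Sandwich-LN / Peri-LN block also normalizes the sublayer output before adding it back to the residual stream:
import torch.nn as nn

class PreLNBlock(nn.Module):
    """Pre-LN: normalize the sublayer input only."""
    def __init__(self, dim, sublayer):
        super().__init__()
        # nn.RMSNorm requires PyTorch >= 2.4; nn.LayerNorm also works for this sketch.
        self.norm = nn.RMSNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

class SandwichLNBlock(nn.Module):
    """Sandwich-LN / Peri-LN: normalize both the sublayer input and its output."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm_in = nn.RMSNorm(dim)
        self.norm_out = nn.RMSNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        # The extra output norm keeps the scale of what enters the residual stream controlled.
        return x + self.norm_out(self.sublayer(self.norm_in(x)))
In a real layer, sublayer would be the attention or MLP module, and the wrapper is applied once per sublayer.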
Download the dataset:
uv run data/download_fineweb_10B_gpt2.py
Train a model:
python -m projects.skip_middle.train_fineweb ...
python -m torch.distributed.run --standalone --nproc_per_node 4 projects/skip_middle/train_fineweb.py ...
See help.txt for the command-line arguments and config.py for the configuration classes.
Install uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
Create a virtual environment:
uv venv
source .venv/bin/activate
Install packages:
uv pip install -e .
Install PyTorch:
UV_TORCH_BACKEND=auto uv pip install torch
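To sanity-check the environment afterwards (a quick smoke test, not one of the repository's scripts):
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"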