Yuxiang Chen, Yifan Liu, Xiaoming Xu, Pengle Zhang, Michael Beyer, Martin Rapp, Jun Zhu, Jianfei Chen
Tsinghua University & Bosch AI Research
This repository contains the official implementation of TetraJet-v2, a method for accurate NVFP4 training for large language models with oscillation suppression and outlier control.
Paper (TetraJet-v2): arXiv, OpenReview
Status:
- 🎉 (2026/05) This work has been accepted as a Spotlight paper at ICML 2026.
- (2026/05) We released an updated version of TetraJet-v2 with kernels and the training recipe.
- (2025/10) We released the first version of TetraJet-v2 on arXiv.
Previous work (TetraJet-MXFP4Training, ICML 2025): arXiv, code
This work (TetraJet-v2) extends our prior low-bit training efforts (TetraJet) from MXFP4 for ViTs to more accurate and robust NVFP4 training for LLMs.
- Practically optimal NVFP4 linear recipe: an end-to-end FP4 training recipe with double-block quantization, aligned activations for correct gradient estimation, and the best backward RHT setting for LLM linear layers.
- ⭐ OsciReset (key algorithmic contribution): a lightweight weight-oscillation suppression algorithm that identifies unstable FP4 weights and resets their master weights to quantization-bin centers. It improves weight-optimization stability during annealing and convergence in large-data, long-horizon low-precision training, and can also transfer to Quantization-Aware Training (QAT) for producing low-precision weights.
- OutControl: an outlier-control recipe that combines backward RHT and mixed FP4+MXFP8 outlier-channel retention for more accurate activation and gradient computation.
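To make the OsciReset idea above concrete, here is a minimal NumPy sketch of "find weights whose FP4 value keeps flipping, then snap their master weights to the quantized value". It is not the repository's implementation (see olmo/quantization_real/oscillation_reset.py for that). The flip-count criterion, the `flip_threshold` parameter, the history window, and the single scalar `scale` (standing in for NVFP4's block-wise scaling) are all illustrative assumptions.

```python
import numpy as np

# Magnitudes representable in FP4 (E2M1); the signed grid mirrors them around zero.
FP4_POS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_GRID = np.concatenate([-FP4_POS[:0:-1], FP4_POS])  # 15 sorted values


def quantize_index(w, scale):
    """Index of the nearest FP4 grid point for each element of w / scale."""
    return np.abs(w[..., None] / scale - FP4_GRID).argmin(axis=-1)


def osci_reset_sketch(master_history, scale, flip_threshold):
    """Hypothetical oscillation reset.

    master_history: array [T, N] holding N master weights over T steps.
    A weight counts as oscillating if its quantized value changed at least
    flip_threshold times inside the window (an assumed criterion).
    Returns (new_master, reset_mask).
    """
    idx = quantize_index(master_history, scale)      # [T, N]
    flips = (idx[1:] != idx[:-1]).sum(axis=0)        # [N]
    reset_mask = flips >= flip_threshold
    new_master = master_history[-1].copy()
    # Snap oscillating master weights onto the grid value they currently
    # quantize to, i.e. the center of their quantization bin.
    new_master[reset_mask] = FP4_GRID[idx[-1][reset_mask]] * scale
    return new_master, reset_mask


# Weight 0 hovers around the 0.5 / 1.0 rounding boundary; weight 1 is stable.
history = np.array([[0.74, 2.1], [0.76, 2.1], [0.74, 2.1], [0.76, 2.1]])
new_master, reset_mask = osci_reset_sketch(history, scale=1.0, flip_threshold=2)
```

In this toy run only the first weight is reset: its master value moves from the rounding boundary to the grid point it currently quantizes to, so a small subsequent update no longer flips its FP4 value.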
TetraJet-v2 improves FP4 pre-training on OLMo2 models up to 370M parameters and reduces the average gap to BF16 by 51.3% over prior FP4 methods, while providing end-to-end speedups over FP8 baselines.
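The random Hadamard transform (RHT) mentioned in the recipe above can be illustrated with a short NumPy sketch. It shows the two generic properties that make RHT useful before low-bit quantization: the rotation is orthogonal, so a matrix product is unchanged when both operands are rotated along the shared dimension, and a single large outlier is spread across many entries. The block size of 16, the tensor shapes, and the variable names are assumptions for illustration, and the sketch does not reproduce where TetraJet-v2 applies the transform in the backward pass.

```python
import numpy as np


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h


def random_hadamard(n, rng):
    """Orthogonal matrix diag(s) @ H / sqrt(n) with random signs s."""
    signs = rng.choice([-1.0, 1.0], size=n)
    return signs[:, None] * hadamard(n) / np.sqrt(n)


rng = np.random.default_rng(0)
n = 16
R = random_hadamard(n, rng)

g = rng.standard_normal((4, n))   # stand-in for one matmul operand
g[0, 3] = 50.0                    # inject a single outlier
x = rng.standard_normal((n, 8))   # stand-in for the other operand

# Rotate both operands along the shared (inner) dimension.
g_rot, x_rot = g @ R, R.T @ x

# The product is preserved, while the largest magnitude in g shrinks.
product_preserved = np.allclose(g_rot @ x_rot, g @ x)
outlier_reduced = np.abs(g_rot).max() < np.abs(g).max()
```

Because the rotated tensors have a flatter magnitude distribution, a low-bit quantizer applied to `g_rot` and `x_rot` wastes less of its range on a single extreme value than one applied to `g` directly.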
- olmo2-training/: OLMo2 training code based on allenai/OLMo, with files not needed for training removed.
  - Main OLMo changes are in olmo/config.py, olmo/model.py, olmo/train.py, scripts/train.py, plus checkpoint/initialization compatibility for quantized layer buffers.
  - NVFP4 linear layers are implemented in olmo/quantization_real/linear.py.
  - Mixed NVFP4+MXFP8 outlier-channel training is implemented in olmo/quantization_real/linear_mix.py and scheduled by olmo/quantization_real/calibrate.py.
  - The oscillation reset algorithm is implemented in olmo/quantization_real/oscillation_reset.py and olmo/quantization_real/oscillation_reset_memeff.py.
- kernels/: TetraJet-v2 NVFP4 kernels.
- scripts/: local and SLURM launch scripts for OLMo2 training.
- NVIDIA Blackwell GPU. The TetraJet-v2 kernels were designed for RTX 5090 / RTX PRO 6000.
- CUDA >= 12.8
- FlashAttention 2.
conda create -y -n tjv2-nvfp4 python=3.12 pip
conda activate tjv2-nvfp4
# Install OLMo and training dependencies.
cd olmo2-training
pip install -e ".[train]"
cd ..
# Install TetraJet-v2 kernels.
cd kernels
pip install -e . --no-build-isolation
Possible Issues:
- Use --no-build-isolation when installing CUDA extension packages after PyTorch is installed.
- If using a prebuilt FlashAttention wheel, make sure its Python, PyTorch, CUDA, and CXX11 ABI tags match your environment.
- Limit CUDA/C++ build parallelism if needed: export MAX_JOBS=4.
The training configs expect OLMo2 preprocessed .npy data. You can use the train and eval/perplexity file lists in the official OLMo config OLMo2-7B-stage1.yaml to download the preprocessed files directly.
After downloading, replace both the training data prefix and the eval data paths in olmo2-training/configs/*/*.yaml with your local data directories.
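The path replacement described above can be scripted. The sketch below uses only the Python standard library and does a plain text substitution in every file matching configs/*/*.yaml. The function name `rewrite_data_prefix` and both prefixes are hypothetical, so check the prefixes actually used in the shipped configs before running it. If the training and eval paths use different prefixes, call it once per prefix.

```python
from pathlib import Path


def rewrite_data_prefix(config_root, old_prefix, new_prefix):
    """Replace old_prefix with new_prefix in every <config_root>/*/*.yaml file.

    This is a plain string substitution, not a YAML-aware edit.
    Returns the list of files that were modified.
    """
    changed = []
    for path in sorted(Path(config_root).glob("*/*.yaml")):
        text = path.read_text()
        updated = text.replace(old_prefix, new_prefix)
        if updated != text:
            path.write_text(updated)
            changed.append(path)
    return changed


# Example call (both prefixes are placeholders, not values from the repository):
# rewrite_data_prefix("olmo2-training/configs", "/old/data/prefix", "/data/olmo2")
```

A text substitution keeps comments and formatting in the YAML files intact, which a load-and-dump round trip through a YAML library would not guarantee.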
Run launch scripts from scripts/:
cd scripts
# Local, 1 node, 8 GPUs.
./local_70m_8gpu.sh TJv2-mix_fp8-osci_reset-mem_eff
# SLURM, 1 node, 8 GPUs.
sbatch slurm_70m_1node.sh TJv2-mix_fp8-osci_reset-mem_eff
# SLURM, 2 nodes, 8 GPUs per node.
sbatch slurm_70m_2nodes.sh TJv2-mix_fp8-osci_reset-mem_eff
Use the corresponding 70m, 150m, or 370m script for each model size. Available config names:
bf16
TJv2-base
TJv2-mix_fp8
TJv2-mix_fp8-osci_reset
TJv2-mix_fp8-osci_reset-mem_eff
If no config is provided, scripts default to bf16. Pass a checkpoint path as the second argument to resume:
./local_70m_8gpu.sh TJv2-base /path/to/checkpoint
Outputs are saved to olmo2-training/outputs/<model_size>/<config_name>. W&B runs offline by default; change WANDB_MODE in scripts/common.sh to sync online.
See kernels/README.md.
This repository is released under the Apache License 2.0.
The olmo2-training/ directory contains code adapted from
allenai/OLMo, which is also licensed under
Apache License 2.0. We retain the upstream license notice and document the
TetraJet-v2 modifications in NOTICE.
If you find this work useful, please consider citing:
@article{chen2025tetrajet,
title={Tetrajet-v2: Accurate nvfp4 training for large language models with oscillation suppression and outlier control},
author={Chen, Yuxiang and Liu, Yifan and Xu, Xiaoming and Zhang, Pengle and Beyer, Michael and Rapp, Martin and Zhu, Jun and Chen, Jianfei},
journal={arXiv preprint arXiv:2510.27527},
year={2025}
}