Code for A Scalable Measure of Loss Landscape Curvature for Analyzing the Training Dynamics of LLMs.
Dayal Singh Kalra, Jean-Christophe Gagnon-Audet, Andrey Gromov, Ishita Mediratta, Kelvin Niu, Alexander H Miller, Michael Shvartsman
We analyze critical sharpness, a measure of loss landscape curvature that can be computed using forward passes alone. This makes it tractable for large models where Hessian-based methods are infeasible. The method requires approximately 5-6 forward passes given an update direction.
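As a rough illustration of the forward-pass-only idea, the sketch below estimates a critical learning rate on a toy quadratic. It assumes that critical sharpness is defined as 2/η_c, where η_c is the smallest step size along the update direction at which the loss rises above its starting value. That definition, the function name `critical_lr`, and the bracket-then-bisect search are assumptions made for this example, not the repository's implementation. The sketch also spends more loss evaluations than the 5-6 quoted above, since it starts from a cold guess and bisects to high precision.

```python
import numpy as np

def critical_lr(loss_fn, theta, direction, eta0=1e-3, num_bisect=20, max_expand=60):
    """Estimate the smallest eta with loss_fn(theta - eta * direction) > loss_fn(theta).

    Hypothetical helper: uses loss evaluations (forward passes) only.
    """
    base = loss_fn(theta)  # reference loss at the current parameters

    def increases(eta):
        return loss_fn(theta - eta * direction) > base

    lo, hi = 0.0, eta0
    if increases(hi):
        # Initial guess is already too large: halve until the loss no longer rises.
        for _ in range(max_expand):
            if not increases(0.5 * hi):
                lo = 0.5 * hi
                break
            hi = 0.5 * hi
        else:
            raise RuntimeError("could not bracket the critical learning rate")
    else:
        # Initial guess is too small: double until the loss rises.
        for _ in range(max_expand):
            lo, hi = hi, 2.0 * hi
            if increases(hi):
                break
        else:
            raise RuntimeError("could not bracket the critical learning rate")

    # Bisection: lo never increases the loss, hi always does.
    for _ in range(num_bisect):
        mid = 0.5 * (lo + hi)
        if increases(mid):
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

# Toy quadratic loss with a known Hessian, used only to sanity-check the sketch.
H = np.diag([1.0, 10.0])
loss = lambda th: 0.5 * th @ H @ th
theta = np.array([1.0, 1.0])
grad = H @ theta  # gradient-descent update direction

eta_c = critical_lr(loss, theta, grad)
critical_sharpness = 2.0 / eta_c
directional_sharpness = (grad @ H @ grad) / (grad @ grad)
```

On a quadratic loss the two quantities coincide, because the loss along the gradient direction returns to its starting value exactly at η = 2 gᵀg / gᵀHg. Warm-starting `eta0` from a previous estimate and using fewer bisection steps would reduce the number of loss evaluations.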
conda create -n curvature python=3.10
conda activate curvature
pip install torch numpy pandas scipy wandb
pip install cupy-cuda12x # adjust for your CUDA version -- needed for HVP computation
CIFAR-10: Downloaded automatically via torchvision.
FineWeb: Follow the nanoGPT data preparation instructions to tokenize FineWeb-Edu into language/data/fineweb/.
FCN with 4 layers, width 512, trained with SGD at constant learning rate 3e-02.
cd image
# Full batch (batch_size=50000)
python train_fcn_image_sgd_dir_sharp.py \
--batch_size 50000 --lr_peak 3e-02 --num_steps 10000 \
--warmup_steps 0 --stable_steps 10000 --loss_name xent
# Batch size 5000
python train_fcn_image_sgd_dir_sharp.py \
--batch_size 5000 --lr_peak 3e-02 --num_steps 10000 \
--warmup_steps 0 --stable_steps 10000 --loss_name xent
# Batch size 500
python train_fcn_image_sgd_dir_sharp.py \
--batch_size 500 --lr_peak 3e-02 --num_steps 10000 \
--warmup_steps 0 --stable_steps 10000 --loss_name xent
GPT with 12 layers, 768 embedding dim, trained on FineWeb with WSD schedule.
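For orientation, a warmup-stable-decay (WSD) schedule can be sketched as below. The defaults mirror the `--lr_peak`, `--warmup_steps`, `--stable_steps`, and `--num_steps` values in the GPT command. The linear warmup, the linear decay to zero over the remaining steps, and the names `wsd_lr` and `lr_min` are assumptions for illustration; the training script may use a different decay shape or floor.

```python
def wsd_lr(step, lr_peak=3e-4, warmup_steps=2000, stable_steps=6000,
           num_steps=10000, lr_min=0.0):
    """Hypothetical WSD schedule: linear warmup, constant plateau, linear decay."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 to lr_peak.
        return lr_peak * step / warmup_steps
    if step < warmup_steps + stable_steps:
        # Stable phase: hold the peak learning rate.
        return lr_peak
    # Decay phase: assumed linear from lr_peak down to lr_min over the remaining steps.
    decay_steps = max(num_steps - warmup_steps - stable_steps, 1)
    frac = min((step - warmup_steps - stable_steps) / decay_steps, 1.0)
    return lr_peak + frac * (lr_min - lr_peak)

# Sample the schedule at a few points of the 10000-step GPT run.
samples = {s: wsd_lr(s) for s in (0, 1000, 2000, 5000, 9000, 10000)}
```

With `warmup_steps=0` and `stable_steps` equal to `num_steps`, as in the CIFAR-10 commands, this sketch reduces to a constant learning rate.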
cd language
python train_gpt_adam_dir_sharp.py \
--num_layers 12 --num_heads 12 --head_dim 64 \
--lr_peak 3e-04 --num_steps 10000 \
--warmup_steps 2000 --stable_steps 6000 \
--batch_size 16 --gradient_accumulation_steps 64 \
--weight_decay 0.0
Before running on your cluster, update the following:
- SLURM settings: Edit MY_ACCT and MY_QOS in image/submit_job.sh and language/submit_job.sh (and the corresponding write_*_submit_job.sh scripts) to match your cluster's account and QOS.
- Checkpoint paths: The --ckpt_dir argument defaults to ./checkpoints. Modify as needed for your storage setup.
scalable-curvature/
├── image/                              # CIFAR-10 experiments
│   ├── train_fcn_image_sgd_dir_sharp.py
│   ├── train_fcn_image_adamw_dir_sharp.py
│   └── utils/
│       ├── critical_lr_cache_utils.py  # Critical LR estimation
│       ├── sharpness_dir_utils.py      # Directional sharpness (HVP-based)
│       ├── sharpness_cupy_utils.py     # Hessian sharpness (LOBPCG)
│       └── optimizers.py               # Custom SGD/AdamW with virtual_step
└── language/                           # GPT experiments
    ├── train_gpt_adam_dir_sharp.py
    └── utils/
        ├── critical_learning_rate.py
        ├── gpt.py
        └── ...
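To give a feel for what an HVP-based directional sharpness computes, here is a minimal NumPy sketch. It assumes directional sharpness means the normalized quadratic form vᵀHv / vᵀv along a direction v, and it approximates the Hessian-vector product with central finite differences of the gradient. The repository's utilities presumably use automatic differentiation instead; the function names and the toy loss below are hypothetical.

```python
import numpy as np

def hvp_fd(grad_fn, theta, v, eps=1e-4):
    """Approximate H @ v with central finite differences of the gradient."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2.0 * eps)

def directional_sharpness(grad_fn, theta, v):
    """Assumed definition: curvature along v, (v^T H v) / (v^T v)."""
    return (v @ hvp_fd(grad_fn, theta, v)) / (v @ v)

# Toy non-quadratic loss L(theta) = 0.25 * sum(theta^4) + 0.5 * theta^T A theta.
A = np.diag([1.0, 4.0])
grad_fn = lambda th: th**3 + A @ th           # analytic gradient of the toy loss
hessian = lambda th: np.diag(3.0 * th**2) + A  # analytic Hessian, for checking only

theta = np.array([1.0, 0.5])
g = grad_fn(theta)                             # use the gradient as the direction

sharp_fd = directional_sharpness(grad_fn, theta, g)
sharp_exact = (g @ hessian(theta) @ g) / (g @ g)
```

Each Hessian-vector product here costs two gradient evaluations, which is the kind of backward-pass cost the forward-pass-only critical sharpness measure avoids.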
@article{kalra2026scalable,
title={A Scalable Measure of Loss Landscape Curvature for Analyzing the Training Dynamics of LLMs},
author={Kalra, Dayal Singh and Gagnon-Audet, Jean-Christophe and Gromov, Andrey and Mediratta, Ishita and Niu, Kelvin and Miller, Alexander H and Shvartsman, Michael},
journal={arXiv preprint arXiv:2601.16979},
year={2026}
}