Description
Summary
I'm doing some benchmarking with torchtitan on H100s to compare the experimental feature in #778 with other training configurations.
One important comparison is between:
- float8nocompile (handwritten Triton kernels for fp8 conversion of nn.Linear layers), and
- production float8 training with torch.compile applied only to the nn.Linear layers (see the sketch below for what I mean by this)
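
For concreteness, "compiling only the nn.Linear layers" means something like the following minimal sketch; this is an illustration, not the exact code from the PR:

```python
import torch
import torch.nn as nn

def compile_linear_layers_only(model: nn.Module) -> nn.Module:
    # Replace each nn.Linear with a torch.compile-wrapped version,
    # leaving the rest of the model running in eager mode.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, torch.compile(child))
        else:
            compile_linear_layers_only(child)
    return model
```

The result is many small compiled regions with eager-mode ops (such as the loss) in between, rather than one compiled graph over the whole model.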
I've run this comparison using:
- no AC
- full AC
- selective per-op AC (sketched below)
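
"Selective per-op AC" refers to per-operator selective activation checkpointing, which saves only the outputs of selected ops during forward and recomputes the rest in backward. Roughly, it looks like the sketch below; the policy shown is illustrative, not torchtitan's exact save list:

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Illustrative policy: save matmul outputs during forward,
    # recompute everything else during backward.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpointed_block_forward(block, *args):
    # Wrap a transformer block so only the selected ops' activations are saved.
    return checkpoint(
        block,
        *args,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
    )
```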
However, specifically when using selective per-op AC, there is a massive performance drop with production float8 training + torch.compile applied only to the nn.Linear layers, as compared to compiling the full model.
This does not occur when using no AC or full AC.
I would expect some performance degradation from compiling only the nn.Linear layers instead of the full model, but the drop-off is massive (see the benchmark screenshot below): TFLOPS drops from 386.58 down to 125.88!
I looked at the traces for these two runs and found some surprising issues (a sketch of how comparable traces can be captured follows the findings):
- In the forward pass of the slow configuration (compiling only the nn.Linear layers), it seems like aten::cross_entropy_loss is running on the CPU without dispatching any CUDA kernels, causing it to take 267ms vs 71us in the fully compiled version (a ~3800x slowdown).
[screenshot: forward-pass trace, only nn.Linears compiled, aten::cross_entropy_loss taking 267ms]
- In the backward pass of the slow configuration (compiling only the nn.Linear layers), there is an extremely long/slow FSDP::post_backward_reduce call that does not appear in the fully compiled version (or rather, it is orders of magnitude faster).
[screenshot: backward-pass trace, only nn.Linears compiled]

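For reference, here is a minimal sketch of how comparable traces can be captured with torch.profiler; train_step is a hypothetical stand-in for one iteration of the training loop:

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

def capture_trace(train_step, steps=4):
    # Skip one step, warm up one step, then record two active steps
    # and dump the result as Chrome trace JSON for side-by-side comparison.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=2),
        on_trace_ready=lambda p: p.export_chrome_trace(f"trace_{p.step_num}.json"),
    ) as prof:
        for _ in range(steps):
            train_step()  # hypothetical: one forward/backward/optimizer step
            prof.step()
```
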
Steps to reproduce
- Check out [Not for land] Integrate float8nocompile, an experimental feature for high performance (#778)
- Edit train_configs/llama3_8b.toml to run prod float8 + fully compiled model + selective per-op AC on H100s, then run:
NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
- Edit train_configs/llama3_8b.toml to run prod float8 + only linear layers compiled + selective per-op AC on H100s (I don't think the number of GPUs matters), then run:
TORCHTITAN_COMPILE_LINEAR_ONLY=1 NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
cc @vkuzo @soulitzer