[Bug] Unexpected performance drop with float8 training + compiling only nn.Linear layers + using selective per op AC #786

Open
@danielvegamyhre

Summary

I'm doing some benchmarking with torchtitan on H100s to compare the experimental feature in #778 with other training configurations.

One important comparison is between:

  • float8nocompile (handwritten triton kernels for fp8 conversion of nn.Linear layers), and
  • production float8 training + using torch.compile on only the nn.Linear layers
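For readers unfamiliar with the second configuration, here is a minimal sketch of what "compiling only the nn.Linear layers" can look like. It is not the code from #778: the helper name `compile_linear_layers` and its `compile_fn` parameter are hypothetical, and plain `nn.Linear` modules stand in for the float8 linear modules used in the real runs.

```python
import torch
import torch.nn as nn


def compile_linear_layers(model: nn.Module, compile_fn=torch.compile) -> int:
    """Replace every nn.Linear submodule with compile_fn(submodule), in place.

    Hypothetical helper for illustration only. Returns the number of layers
    that were wrapped. All other modules keep running in eager mode.
    """
    wrapped = 0
    # Snapshot the children so that reassigning attributes is safe.
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            # torch.compile returns a wrapper module; compilation itself
            # is deferred until the first forward call.
            setattr(model, name, compile_fn(child))
            wrapped += 1
        else:
            # Recurse into containers and composite blocks.
            wrapped += compile_linear_layers(child, compile_fn)
    return wrapped
```

Calling `compile_linear_layers(model)` wraps each linear layer separately, whereas `torch.compile(model)` captures the whole model. In the linear-only case everything between the linear layers, including the loss computation, stays in eager mode.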

I've run this comparison under three activation checkpointing (AC) settings:

  • no AC
  • full AC
  • selective per op AC

However, specifically with selective per op AC, performance drops massively when production float8 training is combined with torch.compile on only the nn.Linear layers, compared to using torch.compile on the full model.

This does not occur when using no AC or full AC.

I would expect some performance degradation from compiling only nn.Linear instead of the full model, but the drop-off is massive (see the benchmark screenshot below): TFLOPS drops from 386.58 to 125.88!

[Screenshot of benchmark results, 2025-01-10]

I looked at the traces for these 2 runs, and found some surprising issues:

  1. In the forward pass of the slow configuration (only nn.Linear layers compiled), aten::cross_entropy_loss appears to run on the CPU without dispatching any CUDA kernels. It takes 267ms vs 71us in the fully compiled version (~3800x slowdown).

Only nn.Linears compiled (267ms):
[trace screenshot: 267ms]

Full model compiled (71us):
[trace screenshot: 71us]

  2. In the backward pass of the slow configuration (only nn.Linear layers compiled), there is an extremely slow FSDP::post_backward_reduce call. In the fully compiled version this call is orders of magnitude faster, so it does not stand out in the trace.

Only nn.Linears compiled:

[trace screenshot: backward-reduce-slow]
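For scale, the figures quoted in this report work out as follows. This is plain arithmetic on the numbers reported here and involves no new measurements.

```python
# Throughput reported for selective per op AC (TFLOPS).
full_compile_tflops = 386.58
linear_only_tflops = 125.88

# How many times faster the fully compiled run is, and the relative drop.
throughput_ratio = full_compile_tflops / linear_only_tflops
relative_drop = 1 - linear_only_tflops / full_compile_tflops

# Duration of aten::cross_entropy_loss in the two traces, in seconds.
slow_loss_s = 267e-3  # only nn.Linear layers compiled
fast_loss_s = 71e-6   # full model compiled
loss_slowdown = slow_loss_s / fast_loss_s

print(f"throughput ratio: {throughput_ratio:.2f}x")
print(f"relative drop:    {relative_drop:.1%}")
print(f"loss slowdown:    {loss_slowdown:.0f}x")
```

The overall throughput falls by roughly a factor of three (about two thirds), while the loss op alone is slower by a factor in the thousands.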

Steps to reproduce

  1. Check out PR #778 ([Not for land] Integrate float8nocompile, an experimental feature for high performance).
  2. Edit train_configs/llama3_8b.toml to run prod float8 + fully compiled model + selective per op AC on H100s:

NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh

  3. Edit train_configs/llama3_8b.toml to run prod float8 + only linear layers compiled + selective per op AC on H100s (I don't think the number of GPUs matters):

TORCHTITAN_COMPILE_LINEAR_ONLY=1 NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh

cc @vkuzo @soulitzer
