configurability
- [planned] support rowwise/blockwise scaling granularity, configurable separately for each gemm (see the sketch after this list)
- [planned] configure settings for each of the three gemms in linear fwd/bwd separately
- [planned] support more fine-grained configuration of how to apply `Float8Linear` to individual modules
- [planned] inference support (see [RFC] Float8 Inference pytorch-labs/float8_experimental#314)
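
To make the granularity terminology concrete, here is a minimal sketch of casting a 2D tensor to float8 with one scale per row, using only plain PyTorch ops. The helper name `cast_to_fp8_rowwise` and the choice of `torch.float8_e4m3fn` are illustrative assumptions, not the library's actual API.

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def cast_to_fp8_rowwise(x: torch.Tensor):
    # rowwise granularity: one scale per row, derived from the rowwise amax
    amax = x.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12)
    scale = FP8_E4M3_MAX / amax
    # saturated cast: clamp to the representable range before converting
    x_fp8 = (x.float() * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()  # inverse scale for dequantization

x = torch.randn(128, 256)
x_fp8, inv_scale = cast_to_fp8_rowwise(x)
x_approx = x_fp8.float() * inv_scale  # dequantized approximation of x
```

Blockwise scaling would follow the same shape, with one scale per fixed-size tile instead of per row.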
performance
- [in progress] `torch._scaled_mm` support for rowwise scaled float8 gemm
  - [done] eager mode support
  - [planned] torch.compile support, backed by triton/cutlass
- [in progress] optimize torch.compile performance for float8 scaling/casting kernels
  - [fixed behind a flag, off by default] Improve the Inductor generated kernel for the pattern of `output1 = pointwise(input); output2 = transpose(output1)` pytorch#130015 (see the sketch after this list)
  - [planned] Improve inductor codegen for writing out tensor and tensor.t() in the same kernel pytorch#133242
  - [planned] [Inductor] Fusion of Tiled Point-Wise and Reduction Operators pytorch#128063
  - [planned] PT2 should leverage partial reductions to speed up larger reductions pytorch#136267
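
For context on the Inductor items above, the snippet below reproduces the `output1 = pointwise(input); output2 = transpose(output1)` pattern under `torch.compile`. It is a minimal sketch rather than the actual float8 code path: `torch.relu` stands in for the pointwise cast, and `.contiguous()` forces the transposed output to be materialized so Inductor has to write out both layouts.

```python
import torch

def pointwise_then_transpose(x):
    # a pointwise op whose result is also written out transposed, as happens
    # when producing a float8 tensor and its transpose for the backward gemm
    output1 = torch.relu(x)
    output2 = output1.t().contiguous()
    return output1, output2

compiled = torch.compile(pointwise_then_transpose)
y1, y2 = compiled(torch.randn(1024, 1024))
```

Running this on a CUDA device (e.g. with `TORCH_LOGS=output_code`) shows the kernels Inductor generates for the two outputs.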
distributed
- [in progress] integrate with FSDP2, using a 16-bit or 8-bit all-gather with delayed scaling for weights (see the sketch after this list)
  - POC is done, performance optimizations are ongoing
- [planned] verify integration with PP (pipeline parallelism)
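
As a rough sketch of what the FSDP2 integration looks like from the user side, the snippet below assumes the torchao float8 training API (`convert_to_float8_training`, `Float8LinearConfig`, and its `enable_fsdp_float8_all_gather` flag) and a `torchrun` launch; treat these names as assumptions that may differ across versions rather than a confirmed recipe.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
# assumed API surface; check torchao.float8 for the current names
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).cuda()

# swap nn.Linear for Float8Linear, with the float8 all-gather path enabled
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# standard FSDP2 usage: shard each layer, then the root module
for layer in model:
    fully_shard(layer)
fully_shard(model)
```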
other
- weight gradient accumulation in float32
- add `use_fast_accum` (float8 accumulation of gemm) option to the UX (Allow for modifying the scaled_mm compute pytorch-labs/float8_experimental#144); see the sketch after this list
- improve saturated casting performance
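
To make the `use_fast_accum` and saturated-casting items concrete, here is a minimal sketch of a tensorwise-scaled float8 gemm through `torch._scaled_mm`. Note that `_scaled_mm` is a private API whose signature and return type have changed across PyTorch releases, and it requires float8-capable hardware (e.g. H100); the `to_fp8_saturated` helper is an illustrative assumption, not the library's implementation.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def to_fp8_saturated(x: torch.Tensor):
    # tensorwise scale from the global amax, with a saturated (clamped) cast
    scale = FP8_MAX / x.abs().amax().float().clamp(min=1e-12)
    x_fp8 = (x.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()  # inverse scale consumed by the gemm

a = torch.randn(16, 32, device="cuda")
b = torch.randn(64, 32, device="cuda")  # passed as b.t() so mat2 is column-major

a_fp8, a_inv_scale = to_fp8_saturated(a)
b_fp8, b_inv_scale = to_fp8_saturated(b)

# use_fast_accum=True opts into the faster, lower-precision float8 accumulation
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=a_inv_scale,
    scale_b=b_inv_scale,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
```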
copied from pytorch-labs/float8_experimental#187