This is a running list of planned features for low precision training. As features are completed, we plan to delete them from this list to keep things simple.
MX
pytorch/pytorch
- [in progress] improve torch.compile performance for the to_mx cast across dim1 (see the sketch below): request for faster inductor kernels for blockwise reduction across dim1 -> write pytorch#149982
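To illustrate the access pattern behind that item, here is a minimal sketch of a blockwise max-abs reduction using the MX block size of 32. This is not torchao's `to_mx` implementation, just the kind of reduction the linked inductor issue asks to generate faster kernels for.

```python
import torch

BLOCK_SIZE = 32  # MX spec scaling block size

def blockwise_amax_dim1(x: torch.Tensor) -> torch.Tensor:
    # (M, K) -> (M, K // 32, 32), then max-abs over each 32-element block
    m, k = x.shape
    assert k % BLOCK_SIZE == 0
    return x.abs().reshape(m, k // BLOCK_SIZE, BLOCK_SIZE).amax(dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
amax = torch.compile(blockwise_amax_dim1)(x)  # shape (8192, 256), one amax per block
```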
pytorch/torchao
- [in progress] performance tracker: MX single node performance tracker #1768
float8
performance
- [in progress] optimize torch.compile performance for float8 tensorwise scaling/casting kernels (see the sketch after this list)
- [in progress] ensure that float8 rowwise scaling is performant with TP and async TP: [Async TP] Fuse all-gather-matmuls for float8 rowwise training pytorch#149990
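As a rough sketch of what the tensorwise item above involves (not torchao's actual kernels): compute a single amax over the tensor, derive a scale, and do a saturated cast to float8. The goal of the item is for torch.compile to fuse these steps into efficient kernels.

```python
import torch

F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def cast_to_float8_tensorwise(x: torch.Tensor):
    # one scale for the whole tensor: amax -> scale -> saturated cast
    amax = x.abs().amax().to(torch.float64)
    scale = F8_MAX / torch.clamp(amax, min=1e-12)
    x_f8 = (x.to(torch.float32) * scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)
    return x_f8, scale.to(torch.float32)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
x_f8, scale = torch.compile(cast_to_float8_tensorwise)(x)
```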
distributed
- [planned] verify integration with pipeline parallelism (PP)
new features
- [planned] MoE support: [roadmap/tracker] Low precision MoE training #2147
- [planned] float8 SDPA (priority TBD)
ecosystem
- [in progress] enable TP in torchtune's float8 integration: TP + FP8 + Compile: metadata error torchtune#2682
other
- [planned] expose float8 training via the `quantize_` API (see the sketch at the end of this list)
- [planned] migrate `torchao.float8` code to `torchao.quantization` for better unification with the rest of torchao, in a BC-preserving way: [wip] unification of torchao.float8 with the rest of torchao #894
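For context on the last two items: today, float8 training is enabled through a module-swap helper in `torchao.float8`. The sketch below shows that current entry point; the `quantize_`-based path is only indicated as a hypothetical, and the config name is made up for illustration.

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# today: swap nn.Linear modules in-place for float8 training
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
convert_to_float8_training(model)

# planned direction (hypothetical names, for illustration only):
# from torchao.quantization import quantize_
# quantize_(model, Float8TrainingConfig())  # Float8TrainingConfig does not exist yet
```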