Description
This issue tracks single-node performance of MX training and inference: fast gemm, fast fused kernels. Once this issue is complete, we can train on a single node (8 GPUs) at SOTA performance with MXFP8, and do inference (details TBD) with MXFP8 and MXFP4.
individual components
system overview (for training)
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# in Python pseudocode, we want the following (for mxfp8):
# forward pass
# inputs are in high precision
x_hp, w_hp = ...
# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)
# backward pass
# inputs are in high precision
x_hp, w_hp, go_hp = ...
# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)
# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
We want:
- the mx gemm to be fast
- the cast from high precision to mx (to_mx in pseudocode above) to be fast
- the cast from high precision to mx to be fused to preceding/subsequent ops where possible
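As a point of reference for what the to_mx cast has to compute, here is a minimal pure-Python sketch for a single block. It assumes a block of 32 elements, a power-of-two shared scale (E8M0), and float8 e4m3 elements. The helpers `to_mx_block`, `round_to_e4m3`, and `dequantize_block` are hypothetical names for illustration, not torchao APIs, and the real implementation may choose the scale differently.

```python
import math

BLOCK_SIZE = 32            # number of elements sharing one scale
E4M3_EMAX = 8              # exponent of the largest e4m3 value: 448 = 1.75 * 2**8
E4M3_MAX = 448.0
E4M3_MIN_NORMAL_EXP = -6   # smallest normal exponent of e4m3


def floor_log2(x):
    """Exact floor(log2(x)) for x > 0, using frexp (x = m * 2**e, 0.5 <= m < 1)."""
    return math.frexp(x)[1] - 1


def round_to_e4m3(v):
    """Emulate a saturating cast to float8 e4m3 (3 mantissa bits)."""
    if v == 0.0:
        return 0.0
    v = max(-E4M3_MAX, min(E4M3_MAX, v))
    # Below the smallest normal exponent the spacing stays fixed (subnormals).
    e = max(floor_log2(abs(v)), E4M3_MIN_NORMAL_EXP)
    quantum = 2.0 ** (e - 3)
    # Python's round() rounds half to even.
    return round(v / quantum) * quantum


def to_mx_block(block):
    """Return (quantized elements, shared exponent) for one block of values."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0.0] * len(block), 0
    # Assumption: the shared exponent maps the block maximum near the top of
    # the e4m3 range; it is clamped to what an E8M0 scale can represent.
    shared_exp = max(-127, min(127, floor_log2(amax) - E4M3_EMAX))
    scale = 2.0 ** shared_exp
    return [round_to_e4m3(v / scale) for v in block], shared_exp


def dequantize_block(elements, shared_exp):
    """Recover approximate high precision values from one MX block."""
    return [v * 2.0 ** shared_exp for v in elements]
```

A production kernel would do this for every block of a tensor along the chosen dim in a single pass, which is why the cast is expected to be limited by memory bandwidth rather than compute.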
gemm kernel
Expected peak TFLOPs on NVIDIA B200, without sparsity: 2.25 petaFLOPs for bf16, 4.5 petaFLOPs for fp8/fp6 (2x from bf16), 9.0 petaFLOPs for fp4 (4x from bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)
kernel | wrapper | current TFLOPs | peak TFLOPs | notes |
---|---|---|---|---|
mxfp8 cuBLAS | torch._scaled_mm | TBD | 4.5 petaFLOPs | in progress, pytorch/pytorch#147548 |
mxfp8 CUTLASS | torchao.ops.mx_fp8_bf16 | TBD | 4.5 petaFLOPs | landed, #1637 |
mxfp4 CUTLASS | torchao.ops.mx_fp4_bf16 | TBD | 9.0 petaFLOPs | landed, #1661 |
nvfp4 cuBLAS | torch._scaled_mm | TBD | 9.0 petaFLOPs | planned |
Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the TFLOP column in the table above.
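The arithmetic such a benchmark would report can be sketched as follows. This is only the bookkeeping around a measured kernel time, not a benchmark harness. The function names and the `PEAK_TFLOPS` table are assumptions for illustration, with the fp8 peak taken as 4.5 petaFLOPs (2x bf16).

```python
# Dense peak throughput on B200 in TFLOPs, without sparsity (assumed values).
PEAK_TFLOPS = {"bf16": 2250.0, "mxfp8": 4500.0, "mxfp4": 9000.0}


def gemm_flops(m, k, n):
    """FLOPs of an (m, k) @ (k, n) gemm: one multiply and one add per MAC."""
    return 2 * m * k * n


def achieved_tflops(m, k, n, seconds):
    """Throughput of a gemm that took `seconds` to run."""
    return gemm_flops(m, k, n) / seconds / 1e12


def percent_of_peak(m, k, n, seconds, kind):
    """Achieved throughput as a percentage of the assumed peak for `kind`."""
    return 100.0 * achieved_tflops(m, k, n, seconds) / PEAK_TFLOPS[kind]


if __name__ == "__main__":
    # Hypothetical measurement: an 8192^3 mxfp8 gemm taking 1 ms.
    print(achieved_tflops(8192, 8192, 8192, 1e-3))
    print(percent_of_peak(8192, 8192, 8192, 1e-3, "mxfp8"))
```

When timing GPU kernels, the measured time needs to include device synchronization; otherwise only the launch overhead is captured.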
scaling/casting kernels
Our current plan is to use torch.compile, same as we are doing with float8.
- we should ensure we can generate a single fused kernel for scaling and casting a tensor to mxfp8. Today, torch.compile generates two kernels: torch.compile cast to mxfp8 should only require one kernel #1769
- once we have a single fused kernel, we should make sure it's bandwidth bound. As of 2025-02-24, the casting to MX code is numerically correct but researchy and has not been optimized for performance. TODO issue.
- the float8_e8m0fnu dtype was added to PyTorch in "add the torch.float8_e8m0fnu dtype to PyTorch" pytorch#147466; we need to update torchao to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue
- we need to ensure torch.compile is good at generating good fused kernels for the custom scale packing layout required by B200s. torch.compile cast to mxfp8 with blocked scales should be performant #1773
- we should ensure the cast across dim0 and dim1 is performant: mx cast to mxfp8 across dim0 and dim1 should be performant #1788
- given an MXLinear (fwd + bwd), we should expect at most six scale+cast kernels: two for each of input, weight, grad_output. The kernels for input and grad_output should be fused with preceding/subsequent ops as appropriate. TODO issue.
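To check whether a fused cast kernel is bandwidth bound, its measured time can be compared against the time needed just to move the bytes. The sketch below assumes a bf16 input, one byte per mxfp8 element, one E8M0 scale byte per 32 elements, and the 8 TB/s B200 figure. The function names are hypothetical.

```python
B200_BANDWIDTH_BYTES_PER_S = 8e12  # assumed peak memory bandwidth
MX_BLOCK_SIZE = 32                 # elements per shared scale


def mx_cast_bytes(numel, in_bytes=2, out_bytes=1, scale_bytes=1):
    """Bytes read plus bytes written by one fused scale+cast kernel."""
    read = numel * in_bytes
    written = numel * out_bytes + (numel // MX_BLOCK_SIZE) * scale_bytes
    return read + written


def min_cast_seconds(numel, bandwidth=B200_BANDWIDTH_BYTES_PER_S):
    """Lower bound on kernel time if memory traffic is the only cost."""
    return mx_cast_bytes(numel) / bandwidth


def bandwidth_fraction(numel, measured_seconds,
                       bandwidth=B200_BANDWIDTH_BYTES_PER_S):
    """Fraction of peak bandwidth achieved; near 1.0 means bandwidth bound."""
    return min_cast_seconds(numel, bandwidth) / measured_seconds
```

For a 16384 x 16384 bf16 tensor this gives a lower bound of roughly 0.1 ms per cast. A kernel that takes several times longer than that is leaving bandwidth unused.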
e2e training performance
From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).
- we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical
- [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
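A first-order version of the roofline mentioned above can be sketched with the peak numbers quoted in this section. It is a deliberately crude model under stated assumptions: every gemm runs at peak, all six casts run as separate bandwidth-bound kernels that each read the bf16 tensor in full, and nothing is fused or overlapped. The function name `mx_linear_speedup` is hypothetical.

```python
def mx_linear_speedup(m, k, n,
                      bf16_flops_per_s=2.25e15,
                      fp8_flops_per_s=4.5e15,
                      bandwidth_bytes_per_s=8e12):
    """Estimated fwd+bwd speedup of an mxfp8 Linear over bf16 for
    input (m, k), weight (n, k), and output (m, n)."""
    # Three gemms per forward + backward, each with 2*m*k*n FLOPs.
    flops = 3 * 2 * m * k * n
    t_bf16 = flops / bf16_flops_per_s
    t_mx_gemm = flops / fp8_flops_per_s

    def cast_bytes(numel):
        # Read bf16, write fp8 elements plus one scale byte per 32 elements.
        return numel * 2 + numel * 1 + numel // 32

    # Two casts (dim0 and dim1) each for input, weight, and grad_output.
    total_cast_bytes = 2 * (cast_bytes(m * k) + cast_bytes(k * n)
                            + cast_bytes(m * n))
    t_cast = total_cast_bytes / bandwidth_bytes_per_s
    return t_bf16 / (t_mx_gemm + t_cast)


if __name__ == "__main__":
    for size in (1024, 4096, 16384):
        print(size, mx_linear_speedup(size, size, size))
```

Under these assumptions small shapes are dominated by cast overhead and come out slower than bf16, while large shapes approach the 2x gemm speedup. Fusing the casts into neighboring ops moves the crossover point toward smaller shapes.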
e2e inference performance
- need an inference roofline
- need to decide where to benchmark