MX single node performance tracker #1768

@vkuzo

Description

This issue tracks single-node performance of MX training and inference: a fast gemm, and fast fused casting kernels. When this issue is complete, we will be able to train on a single node (8 GPUs) at SOTA performance with MXFP8, and do inference (performance target TBD) with MXFP8 and MXFP4.

individual components

system overview (for training)

# There are three gemms in a forward + backward of a Linear layer:
#
# 1.       input @ weight_t    = output     (forward pass)
# 2. grad_output @ weight      = grad_input (backward pass)
# 3.     input_t @ grad_output = grad_weight (backward pass)
# 
# in Python pseudocode, we want the following (for mxfp8):

# forward pass

# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass

# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
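
As a quick sanity check on the shapes involved, here are the same three gemms in plain bf16 (the sizes are illustrative, not from this issue):

import torch

M, K, N = 8192, 4096, 2048  # illustrative sizes

x = torch.randn(M, K, dtype=torch.bfloat16)   # input
w = torch.randn(N, K, dtype=torch.bfloat16)   # weight
go = torch.randn(M, N, dtype=torch.bfloat16)  # grad_output

y = x @ w.t()    # (M, N): gemm 1, forward
gi = go @ w      # (M, K): gemm 2, grad_input
gw = go.t() @ x  # (N, K): gemm 3, grad_weight (the transpose of input_t @ grad_output)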

We want:

  1. the mx gemm to be fast
  2. the cast from high precision to mx (to_mx in pseudocode above) to be fast
  3. the cast from high precision to mx to be fused with preceding/subsequent ops where possible

gemm kernel

Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPs for bf16, 4.5 petaFLOPs for fp8/fp6 (2x over bf16), and 9.0 petaFLOPs for fp4 (4x over bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)

kernel         wrapper                  current TFLOPs  peak TFLOPs    notes
mxfp8 cuBLAS   torch._scaled_mm         TBD             4.5 petaFLOPs  in progress, pytorch/pytorch#147548
mxfp8 CUTLASS  torchao.ops.mx_fp8_bf16  TBD             4.5 petaFLOPs  landed, #1637
mxfp4 CUTLASS  torchao.ops.mx_fp4_bf16  TBD             9.0 petaFLOPs  landed, #1661
nvfp4 cuBLAS   torch._scaled_mm         TBD             9.0 petaFLOPs  planned

Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the current TFLOPs column in the table above; a sketch of such a benchmark follows.
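
A minimal sketch of such a benchmark, using a plain bf16 torch.mm as a stand-in until the mx gemm wrappers above can be measured; the shape is illustrative:

import torch
from torch.utils.benchmark import Timer

M, K, N = 8192, 8192, 8192  # a real benchmark should sweep shapes
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

sec_per_iter = Timer("torch.mm(a, b)", globals={"a": a, "b": b}).blocked_autorange().median

# a gemm does 2 * M * K * N floating point ops
achieved_tflops = 2 * M * K * N / sec_per_iter / 1e12
print(f"achieved: {achieved_tflops:.0f} TFLOPs")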

scaling/casting kernels

Our current plan is to generate the scaling/casting kernels with torch.compile, the same approach we use for float8; a simplified sketch is below.
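
As an illustration of the approach (not the torchao implementation), here is a heavily simplified eager-mode to_mx for mxfp8, following the MX spec: blocks of 32 elements share a power-of-two (e8m0-style) scale, elements are fp8 e4m3, and torch.compile is left to generate a fused kernel:

import torch

BLOCK_SIZE = 32   # MX block size
E4M3_EMAX = 8     # largest exponent representable in float8_e4m3fn

def to_mx_simplified(x_hp: torch.Tensor):
    # scale along the last dim; assumes its size is divisible by BLOCK_SIZE
    x = x_hp.reshape(-1, BLOCK_SIZE)
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=2.0**-126)
    # power-of-two shared scale chosen so the block fits into e4m3
    scale = torch.exp2(torch.floor(torch.log2(amax)) - E4M3_EMAX)
    x_mx = (x / scale).to(torch.float8_e4m3fn).reshape(x_hp.shape)
    return x_mx, scale

to_mx_compiled = torch.compile(to_mx_simplified)  # inductor fuses the cast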

e2e training performance

From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).

  • we need a roofline of mx scaling/casting to determine which shapes are expected to see speedups, and we should have a benchmark to compare observed to theoretical performance; a back-of-envelope sketch is after this list
  • [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
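
A back-of-envelope version of that roofline, using the B200 numbers above (8 TB/s, 2.25/4.5 petaFLOPs for bf16/fp8); treating each cast as one extra bf16 read plus fp8 write per gemm operand is an assumption, and scale traffic (~3%) is ignored:

MEM_BW = 8e12          # bytes/sec
BF16_FLOPS = 2.25e15
FP8_FLOPS = 4.5e15

def mxfp8_speedup(M, K, N):
    flop = 2 * M * K * N
    gemm_bf16 = flop / BF16_FLOPS
    gemm_fp8 = flop / FP8_FLOPS
    # both gemm operands: read 2 bytes (bf16), write 1 byte (fp8) per element
    cast = (M * K + K * N) * (2 + 1) / MEM_BW
    return gemm_bf16 / (gemm_fp8 + cast)

for size in (1024, 4096, 16384):
    print(size, round(mxfp8_speedup(size, size, size), 2))

Under this model, small shapes come out cast-bound (speedup below 1), which is why the cast fusion listed above matters.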

e2e inference performance

  • need an inference roofline (a decode-latency sketch is after this list)
  • need to decide where to benchmark
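
For decode, which is memory-bandwidth bound, a first-order roofline is just weight bytes over bandwidth; the 70B parameter count is hypothetical, and KV-cache and scale traffic are ignored:

MEM_BW = 8e12   # bytes/sec, B200 number from above
PARAMS = 70e9   # hypothetical 70B-parameter model

for name, bytes_per_param in (("bf16", 2), ("mxfp8", 1), ("mxfp4", 0.5)):
    ms_per_token = PARAMS * bytes_per_param / MEM_BW * 1e3
    print(f"{name}: ~{ms_per_token:.1f} ms/token lower bound")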
