Next steps for MXFP8 training #4022

@danielvegamyhre

Description

Functionality improvements

Performance improvements

  • [P0] Fused all2all dispatch + padding kernel
    • Needed to avoid the extra copy incurred by the standalone padding kernel described above, which hurt our speedup. With the all-to-all dispatch, the receiver ranks already allocate a buffer for the incoming tokens; if we write those tokens at offsets aligned to multiples of 32, we avoid this expensive extra copy entirely.
    • While we're at it, we can also write incoming tokens grouped by local expert rather than by remote/source rank, eliminating the separate token shuffle kernel step.
  • [P1] Faster 3D weight quantization kernel for the backward pass dgrad computation with RCEIL scaling, writing scales directly in the ((32,4),4) layout for tcgen05 MMA
    • Current: ~5 TB/s; Goal: ~6.4 TB/s
    • Currently writes scales in row-major order, requiring an additional lightweight kernel to convert to the per-group blocked layout
  • [P1] Faster dim0 quantization kernel with RCEIL scaling that writes scales directly in the ((32,4),4) layout for tcgen05 MMA
    • Current: ~5.5 TB/s; Goal: ~6.4 TB/s
    • Currently writes scales in row-major order, requiring an additional lightweight kernel to convert to the per-group blocked layout
      • Status: Not started 🔴
      • Owner: None
  • [P1] dim1 quantization kernel with RCEIL scaling that writes scales directly in the ((32,4),4) layout for tcgen05 MMA
    • Current: ~5.9 TB/s
    • Currently writes scales in row-major order, requiring an additional lightweight kernel to convert to the per-group blocked layout
      • Status: Not started 🔴
      • Owner: None
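The 32-aligned receive-buffer layout for the fused all2all dispatch can be sketched as follows. This is a hypothetical illustration (the name `aligned_offsets` is not from the torchao code): it only shows the offset arithmetic that lets each local expert's token group start at a row offset that is a multiple of 32, so the standalone padding copy becomes unnecessary.

```python
def aligned_offsets(tokens_per_local_expert, align=32):
    """Round each expert's token count up to a multiple of `align` and
    return (padded_counts, start_offsets, total_rows) for the receive
    buffer. Incoming tokens for expert i are then written starting at
    start_offsets[i], already grouped by local expert."""
    padded = [((n + align - 1) // align) * align for n in tokens_per_local_expert]
    offsets, acc = [], 0
    for p in padded:
        offsets.append(acc)
        acc += p
    return padded, offsets, acc

# e.g. three local experts receiving 100, 37, and 64 tokens:
padded, offsets, total = aligned_offsets([100, 37, 64])
print(padded)   # [128, 64, 64]
print(offsets)  # [0, 128, 192]
print(total)    # 256
```

Because every group starts 32-aligned, the per-group scale blocks for the subsequent grouped GEMM line up without any extra repacking pass.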
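For the ((32,4),4) scale layout the kernels above should emit directly, here is a rough NumPy sketch of the row-major-to-blocked rearrangement those extra lightweight kernels currently perform. It assumes the standard 128x4 scale-factor tile used by tcgen05 block-scaled MMA; `to_blocked` is an illustrative reimplementation, not the production kernel.

```python
import numpy as np

def to_blocked(scales):
    """Rearrange a row-major (M, K/32) scale matrix into the blocked
    layout: pad to 128x4 tiles, then within each tile move element
    (r, c) to flat offset (r % 32) * 16 + (r // 32) * 4 + c."""
    rows, cols = scales.shape
    pr = -(-rows // 128) * 128  # pad rows up to a multiple of 128
    pc = -(-cols // 4) * 4      # pad cols up to a multiple of 4
    padded = np.zeros((pr, pc), dtype=scales.dtype)
    padded[:rows, :cols] = scales
    # split into 128x4 tiles: (n_row_blocks, n_col_blocks, 128, 4)
    t = padded.reshape(pr // 128, 128, pc // 4, 4).transpose(0, 2, 1, 3)
    # within each tile: (128, 4) -> (4, 32, 4) -> (32, 4, 4) -> (32, 16)
    t = t.reshape(-1, 4, 32, 4).transpose(0, 2, 1, 3)
    return t.reshape(-1, 32, 16).ravel()
```

A kernel that writes scales in this layout while quantizing would fold this whole pass into the store addressing, which is the point of the P1 items above.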
