[mxfp8 training] cuda kernel for unpadding token groups#4021

Merged
danielvegamyhre merged 1 commit into main from danielvegamyhre/stack/148 on Mar 10, 2026
Conversation


@danielvegamyhre danielvegamyhre commented Mar 6, 2026

Stacked PRs:


[mxfp8 training] cuda kernel for unpadding token groups

Context

  • The prior PR in the stack adds support for dynamic per-group padding in preparation for the MXFP8 grouped GEMM.
  • Next, we need a fast kernel to remove this padding from the forward output, and from the input gradient once it has been computed in the backward pass.

Summary of changes

  • Add a CUDA kernel that performs this dynamic per-group unpadding operation
  • Add tests and benchmarks
  • Integrate into _to_mxfp8_then_scaled_grouped_mm autograd func
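
For intuition, here is a minimal Python sketch of what per-group unpadding does (the function name and buffer layout are assumptions for illustration, not taken from the PR): each group occupies a padded span of rows in the input buffer, only the first `group_sizes[i]` rows of each span are valid, and the valid rows are copied into a compact output.

```python
# Hypothetical reference implementation of per-group unpadding.
# `padded` is a list of rows; group i occupies padded_group_sizes[i]
# consecutive rows, of which only the first group_sizes[i] are valid.
def unpad_token_groups(padded, group_sizes, padded_group_sizes):
    out = []
    src_row = 0
    for valid, span in zip(group_sizes, padded_group_sizes):
        out.extend(padded[src_row : src_row + valid])  # keep valid rows
        src_row += span                                # skip the padding rows
    return out

# Two groups with 2 and 1 valid rows, each padded to 4 rows:
rows = [[i] for i in range(8)]
print(unpad_token_groups(rows, [2, 1], [4, 4]))  # → [[0], [1], [4]]
```

The CUDA kernel presumably parallelizes this copy across rows/threadblocks; the sketch only captures the indexing.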

Tests

  • pytest test/prototype/moe_training/test_kernels.py -k cuda_fuse -s
  • pytest test/prototype/moe_training/test_training.py
  • pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py -v -s -k

Benchmarks

Unpadding kernel microbenchmark:

  num_tokens    dim    num_groups    torch_us    cuda_us    torch_mem_bw_gbps    cuda_mem_bw_gbps  cuda_vs_torch
------------  -----  ------------  ----------  ---------  -------------------  ------------------  ---------------
       16384   1536             1     95.9431    15.0122              1049.2              6705.41  6.39x
       16384   1536             4    101.101     15.7422               996.64             6400.72  6.42x
       16384   1536             8    108.321     16.289                933.84             6210.02  6.65x
       16384   1536            16    132.617     16.5108               764.24             6138.51  8.03x
       16384   2048             1     93.282     20.7243              1438.84             6476.34  4.50x
       16384   2048             4    100.678     20.7693              1335.75             6474.95  4.85x
       16384   2048             8    109.496     21.0035              1229.37             6408.96  5.21x
       16384   2048            16    136.428     21.0374               991.49             6429.81  6.49x
       16384   5120             1    125.727     55.6127              2668.84             6033.59  2.26x
       16384   5120             4    127.386     55.9834              2636.65             5999.5   2.28x
       16384   5120             8    131.976     56.0214              2552.4              6012.97  2.36x
       16384   5120            16    151.393     56.1884              2235.85             6024.26  2.69x
       16384   7168             1    143.205     76.4491              3280.35             6144.77  1.87x
       16384   7168             4    149.303     76.9112              3152.51             6119.78  1.94x
       16384   7168             8    155.439     76.856               3031.01             6130.15  2.02x
       16384   7168            16    169.923     76.7989              2788.86             6170.55  2.21x
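
The bandwidth columns appear consistent with counting one bf16 read plus one bf16 write of the full `num_tokens × dim` tensor (2 bytes per element); this formula is inferred from the numbers, not taken from the benchmark script:

```python
# Reproduce the mem_bw_gbps columns from the table above
# (assumption: bf16 elements, one full read + one full write).
def mem_bw_gbps(num_tokens, dim, time_us, bytes_per_elem=2):
    bytes_moved = num_tokens * dim * bytes_per_elem * 2  # read + write
    return bytes_moved / 1e9 / (time_us * 1e-6)

print(round(mem_bw_gbps(16384, 1536, 15.0122), 1))  # → 6705.4 (table: 6705.41)
```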

Net speedup of fwd + bwd of _to_mxfp8_then_scaled_grouped_mm:

M,N,K,G                  recipe                             bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  -------------------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(32768, 8192, 5120, 1)   MXFP8TrainingRecipe.MXFP8_RCEIL            7482.1               5084.26  1.472x                         2039.87           1388.51  1.469x
(32768, 8192, 5120, 2)   MXFP8TrainingRecipe.MXFP8_RCEIL            6908                 5118.98  1.349x                         2210.74           1420.38  1.556x
(128000, 8192, 5120, 1)  MXFP8TrainingRecipe.MXFP8_RCEIL           27650.2              20061.1   1.378x                         9600.61           5530.1   1.736x
(128000, 8192, 5120, 2)  MXFP8TrainingRecipe.MXFP8_RCEIL           30024.9              19885     1.51x                          8401.94           5495.78  1.529x
(32768, 2048, 7168, 4)   MXFP8TrainingRecipe.MXFP8_RCEIL            2436.16              2622.46  0.929x                          748.112           798.72  0.937x
(32768, 2048, 7168, 8)   MXFP8TrainingRecipe.MXFP8_RCEIL            2360.35              2647.97  0.891x                          760.992           859.68  0.885x
(128000, 2048, 7168, 4)  MXFP8TrainingRecipe.MXFP8_RCEIL            9487.97              9288.78  1.021x                         3095.71           2957.38  1.047x
(128000, 2048, 7168, 8)  MXFP8TrainingRecipe.MXFP8_RCEIL           10594.5               9267.17  1.143x                         3037.22           2988.03  1.016x
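
For reference, the speedup columns appear to be simply the bf16 time divided by the scaled time, e.g. for the first row:

```python
# Speedup derivation for row (32768, 8192, 5120, 1), MXFP8_RCEIL.
bf16_fwd_bwd_us = 7482.1
scaled_fwd_bwd_us = 5084.26
speedup = bf16_fwd_bwd_us / scaled_fwd_bwd_us
print(f"{speedup:.3f}x")  # → 1.472x, matching the table
```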

stack-info: PR: #4021, branch: danielvegamyhre/stack/148
pytorch-bot bot commented Mar 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4021

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 1489ddc with merge base f0d0deb:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/148 branch from 8736765 to ea52e6b Compare March 6, 2026 22:43
@meta-cla bot added the CLA Signed label Mar 6, 2026
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/147 to main March 10, 2026 04:35
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/148 branch from 2173ea9 to 1489ddc Compare March 10, 2026 04:35
@danielvegamyhre
Contributor Author

Landing to unblock pytorch/torchtitan#2255, then taking another shot at ABI-stable tomorrow, fyi @drisspg.

I tried the conversion for a while using the same Claude prompt as Andrew, and it was not going smoothly because the _C_mxfp8 extension also includes mxfp8 quantization kernels that ABI-stable doesn't support yet, due to lack of support for fp8 dtypes. Mixing both ABI-stable and non-ABI-stable code seems like it should be possible, but I was getting build issues; will debug tomorrow.

@danielvegamyhre danielvegamyhre merged commit 7f82891 into main Mar 10, 2026
46 of 49 checks passed

Labels

CLA Signed · module: training · quantize_ api · training flow · mx
