[mxfp8 training] cuda kernel for unpadding token groups #4021
danielvegamyhre merged 1 commit into main from danielvegamyhre/stack/148
Conversation
stack-info: PR: #4021, branch: danielvegamyhre/stack/148
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4021. Note: links to docs will display an error until the docs builds have completed.
❌ As of commit 1489ddc with merge base f0d0deb: 1 new failure, 2 unrelated failures (broken trunk — jobs that also failed on the merge base; rebase onto the `viable/strict` branch to avoid them).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Landing to unblock pytorch/torchtitan#2255, then taking another shot at ABI stable tomorrow, fyi @drisspg. I tried the conversion for a while using the Claude prompt as Andrew did, and it was not going very smoothly because the …
Stacked PRs:
[mxfp8 training] cuda kernel for unpadding token groups
Context
Summary of changes
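Since the description above is terse, here is a minimal CPU-side sketch of what "unpadding token groups" means conceptually (hypothetical function and parameter names; the actual implementation is a CUDA kernel in torchao): per-expert token groups are padded to an alignment boundary for the grouped GEMM, and unpadding gathers only the valid rows of each group back into a compact layout.

```python
# Hypothetical illustration, NOT the torchao API: gather the valid rows of
# each padded token group into one compact output, skipping the padding rows.
def unpad_token_groups(padded, group_sizes, padded_group_sizes):
    """padded: list of token rows; group_sizes[i] = number of valid rows in
    group i; padded_group_sizes[i] = rows group i occupies in `padded`."""
    out = []
    offset = 0
    for valid, padded_len in zip(group_sizes, padded_group_sizes):
        out.extend(padded[offset:offset + valid])  # keep only the valid rows
        offset += padded_len                       # skip past the padding
    return out

# Example: two groups, each padded to a multiple of 4 rows.
padded = [[1], [2], [3], [0], [4], [5], [0], [0]]  # 3 valid + 1 pad, 2 valid + 2 pad
print(unpad_token_groups(padded, [3, 2], [4, 4]))  # -> [[1], [2], [3], [4], [5]]
```

A CUDA version would do the same gather in parallel, with each thread copying one row from its padded source offset to its compact destination offset.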
`_to_mxfp8_then_scaled_grouped_mm` autograd func
Tests
pytest test/prototype/moe_training/test_kernels.py -k cuda_fuse -s
pytest test/prototype/moe_training/test_training.py
pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py -v -s -k
Benchmarks
Unpadding kernel microbenchmark:
Net speedup of fwd + bwd of `_to_mxfp8_then_scaled_grouped_mm`: