
[WIP] N*G Triton group gemm for MoE #960

Open · wants to merge 3 commits into base: main
Conversation

@lessw2020 (Contributor) commented Mar 14, 2025

This PR adds a Triton group GEMM with full backward-pass support, for integration with MoE training.
The forward pass is adapted from the FBGEMM experimental kernels:
https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

1 - Numerics with BF16 have been verified on sample sizes and the core DeepSeek v3 shapes.

2025-03-13 16:40:08,294 - INFO - Gradient shapes - grad_x: torch.Size([1024, 256]), grad_w: torch.Size([10240, 256])
2025-03-13 16:40:08,294 - INFO - Running PyTorch reference implementation
2025-03-13 16:40:08,591 - INFO - Comparing gradients with PyTorch reference
2025-03-13 16:40:08,602 - INFO - Maximum gradient error - grad_x: 0.125, grad_w: 0.125
2025-03-13 16:40:08,641 - INFO - ✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)
2025-03-13 16:40:08,641 - INFO - ✓ SUCCESS! grad_W matches the PyTorch reference (allclose check passed)
2025-03-13 16:40:08,641 - INFO - Gradients allclose check - grad_x: True, grad_w: True
2025-03-13 16:40:08,641 - INFO - ✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)
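The Triton kernel itself lives in the linked FBGEMM file. As a rough illustration of what the check above is validating, here is a hypothetical pure-PyTorch reference grouped GEMM (function names and the `m_sizes` layout are assumptions for illustration, not the PR's actual code). Each group of rows of `x` is multiplied by its own expert's weight slice, matching the log's shapes (`grad_x`: 1024x256, `grad_w`: 10240x256):

```python
import torch

def grouped_gemm_ref(x, w, m_sizes):
    """Reference grouped GEMM (illustrative): x is (M, K); w stacks G expert
    weight matrices of shape (N, K) along dim 0; m_sizes gives the number of
    rows of x routed to each expert."""
    G = len(m_sizes)
    N = w.shape[0] // G
    outs, start = [], 0
    for g, m in enumerate(m_sizes):
        x_g = x[start:start + m]            # (m, K) tokens routed to expert g
        w_g = w[g * N:(g + 1) * N]          # (N, K) weights for expert g
        outs.append(x_g @ w_g.t())          # (m, N)
        start += m
    return torch.cat(outs, dim=0)           # (M, N)

# Shapes chosen to mirror the log: grad_x (1024, 256), grad_w (10240, 256).
torch.manual_seed(0)
M, K, G, N = 1024, 256, 4, 2560
x = torch.randn(M, K, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(G * N, K, dtype=torch.bfloat16, requires_grad=True)
m_sizes = [M // G] * G

out = grouped_gemm_ref(x, w, m_sizes)
out.sum().backward()                        # autograd supplies grad_x and grad_w
assert x.grad.shape == (M, K) and w.grad.shape == (G * N, K)
```

A kernel under test would be compared against this reference with `torch.allclose` (with BF16-appropriate tolerances), which is what the SUCCESS lines in the log report.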

2 - Todos:
a - fp8 support
b - TMA (was removed to focus on numerics)
c - WS (warp specialization)
d - Perf and auto-tuning

3 - Integration: ready now for BF16, though we may want to do perf work first.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) Mar 14, 2025
@lessw2020 lessw2020 self-assigned this Mar 14, 2025
@lessw2020 lessw2020 changed the title [WIP] Triton group gemm for MoE [WIP] N*G Triton group gemm for MoE Mar 15, 2025