
Commit 7e212cf

feat: megatron support turbo fp8 grouped gemm
1 parent 4e8d1fc commit 7e212cf

File tree

9 files changed: +39 -43 lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ on:
   pull_request:
 
 env:
-  PRIMUS_TURBO_COMMIT: 3acc0fe3271b7e41f0646382311f36bdffca0554 # feat(permute): permute op support to compute tokens_per_expert (#140)
+  PRIMUS_TURBO_COMMIT: 0385cdd615cb4eff28a1cbbf3583fccf95d11fe9 # chore: refactor grouped gemm blockwise python code (#142)
 
 jobs:
   code-lint:

examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,9 @@ modules:
   use_turbo_attention: true
   use_turbo_grouped_mlp: true
 
+  # fp8: e4m3
+  # fp8_recipe: blockwise # tensorwise, blockwise
+
   # deepep
   use_turbo_deepep: false
 

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 31 additions & 24 deletions
@@ -7,7 +7,6 @@
 from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple
 
-import grouped_gemm
 import primus_turbo.pytorch as pt
 import torch
 import torch.nn.functional as F
@@ -731,22 +730,15 @@ def __init__(
             pg_collection,
         )
         args = get_args()
-        grouped_gemm_backend = args.grouped_gemm_backend
-        self.grouped_gemm_backend = grouped_gemm_backend
 
         if args.patch_zero_bubble and args.enable_zero_bubble:
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store
 
             self.grouped_gemm = functools.partial(
-                grouped_gemm_with_weight_gradient_store, gg_backend=grouped_gemm_backend
+                grouped_gemm_with_weight_gradient_store, gg_backend="turbo-gg"
             )
         else:
-            if grouped_gemm_backend == "turbo-gg":
-                self.grouped_gemm = pt.ops.grouped_gemm
-            elif grouped_gemm_backend == "lagacy-gg":
-                self.grouped_gemm = grouped_gemm.ops.gmm
-            else:
-                raise NotImplementedError(f"Grouped gemm backend {grouped_gemm_backend} not implemented")
+            self.grouped_gemm = pt.ops.grouped_gemm
 
         if args.use_turbo_fused_act_with_probs:
             assert self.config.gated_linear_unit, "turbo_fused_act_with_probs only support with GLU."
@@ -802,13 +794,18 @@ def forward(
             w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
             w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
 
-            if self.grouped_gemm_backend == "turbo-gg":
-                tokens_per_expert = tokens_per_expert.cuda()
+            tokens_per_expert = tokens_per_expert.to(w1.device)
             assert w1.is_contiguous(), "w1 must be contiguous"
             assert w2.is_contiguous(), "w2 must be contiguous"
-            fc1_output = self.grouped_gemm(
-                permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, **(gemm_kargs[0])
-            )
+            if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                fc1_output = pt.ops.grouped_gemm_fp8(
+                    permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, config=quant_config
+                )
+            else:
+                fc1_output = self.grouped_gemm(
+                    permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, **(gemm_kargs[0])
+                )
             if self.activation_recompute:
                 if args.use_turbo_fused_act_with_probs:
                     intermediate_parallel = self.activation_checkpoint.checkpoint(
@@ -821,9 +818,15 @@ def forward(
                     intermediate_parallel = self.activation_checkpoint.checkpoint(
                         self.activation_func_with_probs, fc1_output, permuted_probs.unsqueeze(-1)
                     )
-                fc2_output = self.grouped_gemm(
-                    intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
-                )
+                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                    fc2_output = pt.ops.grouped_gemm_fp8(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
+                    )
+                else:
+                    fc2_output = self.grouped_gemm(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
+                    )
                 self.activation_checkpoint.discard_output_and_register_recompute(fc2_output)
             else:
                 if args.use_turbo_fused_act_with_probs:
@@ -834,9 +837,15 @@
                     intermediate_parallel = self.activation_func_with_probs(
                         fc1_output, permuted_probs.unsqueeze(-1)
                     )
-                fc2_output = self.grouped_gemm(
-                    intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
-                )
+                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                    fc2_output = pt.ops.grouped_gemm_fp8(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
+                    )
+                else:
+                    fc2_output = self.grouped_gemm(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
+                    )
         else:
             # No token is allocated for local experts.
             assert torch.count_nonzero(tokens_per_expert) == 0
@@ -925,9 +934,7 @@ def __init__(
             deepep_num_use_cu=args.turbo_deepep_num_cu,
             deepep_num_worst_tokens=num_worst_tokens,
             deepep_use_cuda_num_tokens_per_expert=(
-                args.use_turbo_grouped_mlp
-                and args.moe_use_legacy_grouped_gemm
-                and args.grouped_gemm_backend == "turbo-gg"
+                args.use_turbo_grouped_mlp and args.moe_use_legacy_grouped_gemm
             ),
             deepep_async_finish=True,
             deepep_allocate_on_comm_stream=True,
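To summarize the new forward-path behavior in one place: after this change the layer always uses the Primus-Turbo grouped GEMM, and it routes through `pt.ops.grouped_gemm_fp8` whenever the FP8 global state manager reports that FP8 is enabled. Below is a minimal sketch of that dispatch, assuming the `pt.ops.grouped_gemm` / `pt.ops.grouped_gemm_fp8` signatures used in the diff above; the wrapper function itself is hypothetical and not part of this commit.

```python
import primus_turbo.pytorch as pt


def grouped_gemm_with_optional_fp8(x, w, tokens_per_expert, fp8_enabled, quant_config=None):
    """Hypothetical helper mirroring the dispatch added in this commit:
    use the FP8 grouped GEMM when FP8 is enabled, otherwise fall back to
    the default Primus-Turbo grouped GEMM."""
    # The patched forward keeps tokens_per_expert on the same device as the weights.
    tokens_per_expert = tokens_per_expert.to(w.device)
    if fp8_enabled:
        # In the real code path, quant_config comes from
        # PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config().
        return pt.ops.grouped_gemm_fp8(x, w, tokens_per_expert, trans_b=False, config=quant_config)
    return pt.ops.grouped_gemm(x, w, tokens_per_expert, trans_b=False)
```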

primus/backends/megatron/core/pipeline_parallel/zerobubble/README.md

Lines changed: 1 addition & 9 deletions
@@ -9,16 +9,8 @@ Zero bubbles is a state-of-art technique aiming to reduce the bubble time and me
 
 - Zero bubble patch the gemm OP and the grouped gemm OP for splitting the backward of the inputs and weights, support TE & Primus-turbo backend.
 
-- We suggest to use primus-turbo gemm & grouped gemm to patch the original TE implementation, the following flags is needed to turn on.
-```
-enable_primus_turbo: true
-use_turbo_parallel_linear: true
-use_turbo_grouped_mlp: true
-```
-
-- If it is for MoE model, you can specify group gemm backend by `grouped_gemm_backend: "turbo-gg" # turbo-gg, lagacy-gg`.
 - Some other flags need to be specified
-```
+```
 overlap_grad_reduce: false
 overlap_param_gather: false
 no_persist_layer_norm: true

primus/configs/models/megatron/llama4_17B128E.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,4 +26,4 @@ tokenizer_model: meta-llama/Llama-4-Maverick-17B-128E
 expert_model_parallel_size: 8
 expert_tensor_parallel_size: null # int
 moe_permute_fusion: true
-moe_shared_expert_overlap: true
+moe_shared_expert_overlap: true

primus/configs/models/megatron/llama4_17B16E.yaml

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@ tokenizer_model: meta-llama/Llama-4-Scout-17B-16E
 expert_model_parallel_size: 8
 expert_tensor_parallel_size: null # int
 moe_permute_fusion: true
-moe_shared_expert_overlap: true
+moe_shared_expert_overlap: true

primus/configs/models/megatron/llama4_base.yaml

Lines changed: 1 addition & 3 deletions
@@ -12,9 +12,7 @@ moe_router_topk: 1
 # moe_router_pre_softmax needs to be set to be true for moe_router_topk=1
 moe_router_pre_softmax: true
 moe_router_load_balancing_type: aux_loss
-moe_aux_loss_coeff: 0.001
+moe_aux_loss_coeff: 0.001
 moe_grouped_gemm: true
 moe_use_legacy_grouped_gemm: false
 moe_token_dispatcher_type: alltoall
-
-

primus/configs/modules/megatron/primus_turbo.yaml

Lines changed: 0 additions & 3 deletions
@@ -19,8 +19,5 @@ turbo_deepep_use_comm_stream: false
 # sync-free moe
 turbo_sync_free_moe_stage: 0
 
-# group-gemm
-grouped_gemm_backend: "turbo-gg" # turbo-gg, lagacy-gg
-
 # use turbo fused activation_with_probs to optmize redundant computation
 use_turbo_fused_act_with_probs: false

tests/unit_tests/megatron/transformer/moe/test_token_dispatcher.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@
 def create_args():
     """Setup dummy args."""
     args = SimpleNamespace()
-    args.grouped_gemm_backend = "turbo-gg"
     args.turbo_sync_free_moe_stage = 0
     args.sequence_parallel = False
     args.seq_length = 4096
