
Commit b34e019

fix(deepep): fix MoE overlap error with sync-free MoE stages 2 and 3.
Parent: 5f40d8c

2 files changed: 7 additions, 10 deletions

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 3 additions & 7 deletions
@@ -9,7 +9,6 @@
 
 import grouped_gemm
 import primus_turbo.pytorch as pt
-import primus_turbo.pytorch.ops.activation as turbo_moe_activation
 import torch
 import torch.nn.functional as F
 import transformer_engine as te
@@ -37,9 +36,6 @@
     ScalingStrategy,
     check_fp8_support,
 )
-from primus_turbo.pytorch.ops.moe.tokens_per_expert_to_mask import (
-    tokens_per_expert_to_mask as turbo_tokens_per_expert_to_mask,
-)
 from torch import Tensor
 from transformer_engine.pytorch.fp8 import (
     DelayedScaling,
@@ -756,17 +752,17 @@ def __init__(
         assert self.config.gated_linear_unit, "turbo_fused_act_with_probs only support with GLU."
 
         if self.config.activation_func == F.silu:
-            turbo_fused_act_with_probs = turbo_moe_activation.swiglu_with_probs
+            turbo_fused_act_with_probs = pt.ops.swiglu_with_probs
         elif self.config.activation_func == F.gelu:
-            turbo_fused_act_with_probs = turbo_moe_activation.geglu_with_probs
+            turbo_fused_act_with_probs = pt.ops.geglu_with_probs
         else:
             raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
 
         def _activation_func_with_probs(x, probs, tokens_per_experts):
             assert x.ndim == 2
             assert probs.ndim == 1
             num_tokens = x.shape[0]
-            row_mask = turbo_tokens_per_expert_to_mask(tokens_per_experts, num_tokens)
+            row_mask = pt.ops.tokens_per_expert_to_mask(tokens_per_experts, num_tokens)
             return turbo_fused_act_with_probs(x, probs, row_mask)
 
         self.activation_func_with_probs = _activation_func_with_probs
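
The functional change in this file is just the import path: the fused MoE activation ops are now reached through the public pt.ops namespace instead of the deep primus_turbo.pytorch.ops.* module paths. For readers unfamiliar with the fused op, below is a rough plain-PyTorch sketch of what swiglu_with_probs computes; the gate/value split order, the probs broadcasting, and the mask semantics are assumptions for illustration, not the Primus Turbo kernel's actual implementation.

    import torch
    import torch.nn.functional as F

    def swiglu_with_probs_sketch(x: torch.Tensor, probs: torch.Tensor, row_mask: torch.Tensor) -> torch.Tensor:
        # Illustrative only; the fused pt.ops kernel may differ in detail.
        # Split the gated-MLP projection into gate/value halves, apply SiLU
        # gating, scale each row by its router probability, and zero rows
        # that the mask marks as padding.
        gate, value = x.chunk(2, dim=-1)
        out = F.silu(gate) * value
        out = out * probs.unsqueeze(-1).to(out.dtype)
        return out * row_mask.unsqueeze(-1).to(out.dtype)

The row mask itself comes from pt.ops.tokens_per_expert_to_mask(tokens_per_experts, num_tokens), which, judging by its name and its use here, flags which rows of the grouped-GEMM activation buffer hold real tokens so padded rows contribute nothing.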

primus/modules/trainer/megatron/utils.py

Lines changed: 4 additions & 3 deletions
@@ -430,10 +430,11 @@ def validate_args_on_rocm(args):
     # sync-free MoE
     if args.turbo_sync_free_moe_stage > 0:
         assert args.enable_primus_turbo, "Please set `enable_primus_turbo=True` to enable sync-free MoE."
-        assert (
-            args.turbo_sync_free_moe_stage > 1 and args.moe_use_legacy_grouped_gemm
-        ), "Sync-Free MoE require PrimusTurboGroupedMLP, please set `moe_use_legacy_grouped_gemm=True`"
 
+        if args.turbo_sync_free_moe_stage > 1 and not args.moe_use_legacy_grouped_gemm:
+            raise ValueError(
+                "Sync-Free MoE stage 2 or 3 require PrimusTurboGroupedMLP, please set `moe_use_legacy_grouped_gemm=True`"
+            )
         options = _get_sync_free_moe_options(args.turbo_sync_free_moe_stage)
         print_rank_last(
             f"========== Enable Sync-Free MoE Stage {args.turbo_sync_free_moe_stage} (Auto-Enabled Options) =========="
