9 | 9 |
10 | 10 | import grouped_gemm |
11 | 11 | import primus_turbo.pytorch as pt |
12 | | -import primus_turbo.pytorch.ops.activation as turbo_moe_activation |
13 | 12 | import torch |
14 | 13 | import torch.nn.functional as F |
15 | 14 | import transformer_engine as te |
37 | 36 | ScalingStrategy, |
38 | 37 | check_fp8_support, |
39 | 38 | ) |
40 | | -from primus_turbo.pytorch.ops.moe.tokens_per_expert_to_mask import ( |
41 | | - tokens_per_expert_to_mask as turbo_tokens_per_expert_to_mask, |
42 | | -) |
43 | 39 | from torch import Tensor |
44 | 40 | from transformer_engine.pytorch.fp8 import ( |
45 | 41 | DelayedScaling, |
@@ -756,17 +752,17 @@ def __init__( |
756 | 752 | assert self.config.gated_linear_unit, "turbo_fused_act_with_probs only support with GLU." |
757 | 753 |
758 | 754 | if self.config.activation_func == F.silu: |
759 | | - turbo_fused_act_with_probs = turbo_moe_activation.swiglu_with_probs |
| 755 | + turbo_fused_act_with_probs = pt.ops.swiglu_with_probs |
760 | 756 | elif self.config.activation_func == F.gelu: |
761 | | - turbo_fused_act_with_probs = turbo_moe_activation.geglu_with_probs |
| 757 | + turbo_fused_act_with_probs = pt.ops.geglu_with_probs |
762 | 758 | else: |
763 | 759 | raise ValueError("Activation function must be silu or gelu when using GroupedMLP.") |
764 | 760 |
765 | 761 | def _activation_func_with_probs(x, probs, tokens_per_experts): |
766 | 762 | assert x.ndim == 2 |
767 | 763 | assert probs.ndim == 1 |
768 | 764 | num_tokens = x.shape[0] |
769 | | - row_mask = turbo_tokens_per_expert_to_mask(tokens_per_experts, num_tokens) |
| 765 | + row_mask = pt.ops.tokens_per_expert_to_mask(tokens_per_experts, num_tokens) |
770 | 766 | return turbo_fused_act_with_probs(x, probs, row_mask) |
771 | 767 |
772 | 768 | self.activation_func_with_probs = _activation_func_with_probs |
|
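Reviewer note: the hunks above drop the direct imports of primus_turbo.pytorch.ops.activation and primus_turbo.pytorch.ops.moe.tokens_per_expert_to_mask and call the public pt.ops entry points instead. Below is a minimal sketch of the resulting call pattern, assuming the pt.ops functions keep the signatures used in this diff (swiglu_with_probs(x, probs, row_mask) and tokens_per_expert_to_mask(tokens_per_experts, num_tokens)); the shapes, dtypes, and device are illustrative assumptions, not taken from this PR.

import torch
import primus_turbo.pytorch as pt

# Illustrative setup (assumed, not from the PR): 8 routed tokens, two experts,
# and a gated-linear-unit input of width 2 * 16 so the fused op can split the
# gate/up halves.
x = torch.randn(8, 2 * 16, device="cuda", dtype=torch.bfloat16)
probs = torch.rand(8, device="cuda", dtype=torch.float32)  # per-token routing probabilities
tokens_per_experts = torch.tensor([5, 3], device="cuda")   # token counts per expert (sums to 8)

# Mirror _activation_func_with_probs: derive a per-row validity mask from the
# expert token counts, then apply the fused SwiGLU-with-probs activation.
assert x.ndim == 2 and probs.ndim == 1
row_mask = pt.ops.tokens_per_expert_to_mask(tokens_per_experts, x.shape[0])
out = pt.ops.swiglu_with_probs(x, probs, row_mask)  # assumed output shape (8, 16) if the GLU halves the last dim

The pt.ops.geglu_with_probs branch selected for F.gelu follows the same pattern.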