Commit 03a36b1

feat(turbo): Add turbo RMSNorm patch (#263)
1 parent 1e2e1b1 commit 03a36b1

File tree

3 files changed: +25 -0 lines changed


primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 16 additions & 0 deletions
@@ -1068,3 +1068,19 @@ def combine_postprocess(self, hidden_states: torch.Tensor):
         """
         hidden_states = self.deepep_dispatcher._post_combine(hidden_states)
         return hidden_states.view(self.hidden_shape)
+
+
+class PrimusTurboRMSNorm(te.pytorch.RMSNorm):
+    def __init__(self, *args, **kwargs):
+        assert "device" in kwargs
+        assert "dtype" in kwargs or "params_dtype" in kwargs, "device and dtype must be provided"
+        super().__init__(*args, **kwargs)
+        self.rms_norm_func = pt.modules.RMSNorm(
+            normalized_shape=kwargs["hidden_size"],
+            eps=self.eps,
+            device=kwargs["device"],
+            dtype=kwargs["dtype"] if "dtype" in kwargs else kwargs["params_dtype"],
+        )
+
+    def forward(self, x):
+        return self.rms_norm_func(x)
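
A minimal usage sketch of the new wrapper (illustrative only, not part of the diff above; the hidden_size, eps, device, and dtype values are assumptions, and keyword arguments are required because of the asserts in __init__):

import torch

from primus.backends.megatron.core.extensions.primus_turbo import PrimusTurboRMSNorm

# "device" plus either "dtype" or "params_dtype" must be passed as kwargs;
# "hidden_size" is forwarded to te.pytorch.RMSNorm and reused as normalized_shape.
norm = PrimusTurboRMSNorm(
    hidden_size=4096,
    eps=1e-5,
    device="cuda",
    params_dtype=torch.bfloat16,
)

x = torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16)
y = norm(x)  # forward() dispatches to the primus_turbo RMSNorm implementation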

primus/configs/modules/megatron/primus_turbo.yaml

Lines changed: 3 additions & 0 deletions
@@ -24,3 +24,6 @@ grouped_gemm_backend: "turbo-gg" # turbo-gg, lagacy-gg
 
 # use turbo fused activation_with_probs to optmize redundant computation
 use_turbo_fused_act_with_probs: false
+
+# layer norm
+use_turbo_rms_norm: false

primus/modules/trainer/megatron/trainer.py

Lines changed: 6 additions & 0 deletions
@@ -192,6 +192,7 @@ def patch_pt_replace_te(self, args):
         from primus.backends.megatron.core.extensions.primus_turbo import (
             PrimusTurboColumnParallelLinearTorch,
             PrimusTurboDeepEPTokenDispatcher,
+            PrimusTurboRMSNorm,
         )
         from primus.backends.megatron.core.extensions.transformer_engine_spec_provider import (
             PrimusTurboSpecProvider,
@@ -223,6 +224,11 @@ def patch_pt_replace_te(self, args):
         token_dispatcher.MoEFlexTokenDispatcher = PrimusTurboDeepEPTokenDispatcher
         moe_layer.MoEFlexTokenDispatcher = PrimusTurboDeepEPTokenDispatcher
 
+        if args.use_turbo_rms_norm:
+            import transformer_engine as te
+
+            te.pytorch.RMSNorm = PrimusTurboRMSNorm
+
     def patch_fp8_context(self):
         from megatron.core import fp8_utils
         from megatron.core.ssm import mamba_block
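
For context, a short sketch of what the module-attribute patch does (illustrative only, not part of the diff above): after the reassignment, code that looks up te.pytorch.RMSNorm at layer-construction time gets the turbo wrapper instead, which is why the patch must run before the model is built.

import torch
import transformer_engine as te

from primus.backends.megatron.core.extensions.primus_turbo import PrimusTurboRMSNorm

original_rms_norm = te.pytorch.RMSNorm  # keep a reference in case the patch needs undoing
te.pytorch.RMSNorm = PrimusTurboRMSNorm

# Lookups made after this point resolve to the wrapper.
norm = te.pytorch.RMSNorm(hidden_size=4096, eps=1e-5, device="cuda", params_dtype=torch.bfloat16)
assert isinstance(norm, PrimusTurboRMSNorm)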
