Skip to content

Commit ea7aefd

Browse files
authored
fix: gMLP uses full bias instead of truncated bias (#1371)
* fix: correct undefined self.args in TopKTokenChoiceRouter * fix: use correctly sliced bias in gMLP projection
1 parent 19f294d commit ea7aefd

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

megatron/model/gmlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def forward(self, x, attention_mask):
8080
mask = torch.ones(weight.shape[:2], device=device).triu_(1).bool()
8181
weight = weight.masked_fill(mask, 0.0)
8282

83-
gate = F.linear(gate.transpose(2, 1), weight, self.proj.bias).transpose(2, 1)
83+
gate = F.linear(gate.transpose(2, 1), weight, bias).transpose(2, 1)
8484

8585
if self.use_attn:
8686
gate = gate + self.attn(x, attention_mask)

megatron/model/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ def jitter(self, x):
223223
Returns:
224224
torch.Tensor: Jittered input tensor.
225225
"""
226-
low = 1.0 - self.args.moe_jitter_eps
227-
high = 1.0 + self.args.moe_jitter_eps
226+
low = 1.0 - self.jitter_eps
227+
high = 1.0 + self.jitter_eps
228228
noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
229229
return low + noise * (high - low)
230230

0 commit comments

Comments
 (0)