
Commit 8be56a0

Merge pull request #927 from kvcache-ai/fix-gate-precision

Update gate.py

2 parents: 6ca233c + b453333

1 file changed: +11 −9 lines

ktransformers/operators/gate.py (+11 −9)
```diff
@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
```
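This hunk threads a `routed_scaling_factor` argument through the router. Note that it is inserted before `scoring_func`, so positional callers must be updated in lockstep (the `forward` hunk further down does exactly that). A hedged usage sketch with made-up tensor sizes, mirroring the positional order of the updated call site; 2.5 matches DeepSeek-V3's published config, but any value from the model config applies:

```python
import torch
from ktransformers.operators.gate import grouped_topk  # module path from this commit

# Made-up sizes for illustration; real values come from the model config.
num_tokens, hidden_dim, n_experts = 4, 16, 64
hidden_states = torch.randn(num_tokens, hidden_dim)
logits = torch.randn(num_tokens, n_experts)   # one gating score per routed expert
bias = torch.zeros(n_experts)                 # e_score_correction_bias

topk_ids, topk_weights = grouped_topk(
    hidden_states, logits,
    8,            # top_k
    True,         # renormalize
    8, 4,         # num_expert_group, topk_group
    2.5,          # routed_scaling_factor (new; defaults to 1.0)
    "sigmoid",    # scoring_func
    bias,
)
```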

```diff
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+                                    #float("-inf"))  # [n, e]

     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
```
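This is the precision fix proper: with sigmoid scoring every live score is strictly positive, so filling masked experts with 0.0 keeps them out of the top-k exactly as `float("-inf")` did, while guaranteeing no `-inf` leaks into later sums or gathered weights. A small standalone demonstration of the equivalence, and of the downstream difference:

```python
import torch

scores = torch.tensor([[0.9, 0.4, 0.7, 0.2]])            # pretend sigmoid outputs, all > 0
score_mask = torch.tensor([[True, True, False, False]])  # experts 2-3 masked out

inf_fill = scores.masked_fill(~score_mask, float("-inf"))
zero_fill = scores.masked_fill(~score_mask, 0.0)

# Selection is identical as long as k <= number of unmasked experts:
assert torch.topk(inf_fill, k=2).indices.sort().values.equal(
    torch.topk(zero_fill, k=2).indices.sort().values)

# But only the 0.0 fill keeps downstream arithmetic finite:
print(inf_fill.sum())   # tensor(-inf)
print(zero_fill.sum())  # tensor(1.3000)
```
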
```diff
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                             dim=-1,
                                             sorted=False)

-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor # must multiply the scaling factor
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)

 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
```
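Two behavioral changes land here, and both appear to mirror the `MoEGate` reference in DeepSeek's `modeling_deepseek.py`: renormalization is skipped for top-1 routing (a single weight divided by itself is always 1.0, which would erase the gate magnitude), and the `1e-20` epsilon guards against a zero denominator, now reachable since masked scores are 0.0 rather than `-inf`. The scaling multiply is unconditional, so it applies whether or not renormalization ran. A standalone sketch of the arithmetic, including the all-masked edge case:

```python
import torch

topk, renormalize, routed_scaling_factor = 2, True, 2.5
topk_weights = torch.tensor([[0.8, 0.2],    # ordinary token
                             [0.0, 0.0]])   # degenerate: every selected score masked to 0.0

if topk > 1 and renormalize:
    denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20  # avoids 0/0 -> NaN
    topk_weights = topk_weights / denominator
topk_weights = topk_weights * routed_scaling_factor  # applied even when not renormalizing

print(topk_weights)  # tensor([[2.0000, 0.5000],
                     #         [0.0000, 0.0000]])  -- finite everywhere, no NaN
```
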
```diff
@@ -204,6 +206,7 @@ def __init__(
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear, #orig_module
```
```diff
@@ -219,14 +222,13 @@ def forward(self, hidden_states) -> torch.Tensor:
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
```
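The quantized branch previously computed `self.gate_linear.forward(logits)` with `logits` not yet assigned, which raises `UnboundLocalError`; it now projects `hidden_states`, matching the float path. The `grouped_topk` call also gains `self.routed_scaling_factor` in the slot the new signature expects. A minimal sketch of the corrected forward flow, with hypothetical dimensions standing in for the injected module's config:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

gating_dim, n_routed_experts = 16, 64                  # hypothetical config values
gate_linear = nn.Linear(gating_dim, n_routed_experts)  # stand-in for KTransformersLinear
weight = torch.randn(n_routed_experts, gating_dim)     # stand-in for self.weight
use_quant = True

hidden_states = torch.randn(3, 5, gating_dim)
hidden_states = hidden_states.view(-1, gating_dim)     # flatten tokens, as in forward()
if use_quant:
    logits = gate_linear(hidden_states)                # fixed: feed hidden_states, not logits
else:
    logits = F.linear(hidden_states.type(torch.float32),
                      weight.type(torch.float32), None)
print(logits.shape)  # torch.Size([15, 64])
```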
