@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
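A call using the new keyword might look like the sketch below; the gating_output and topk parameter names, shapes, and values are assumptions for illustration, not taken from this patch.

# Hedged sketch: toy shapes, assumed names for the positional parameters not shown in the hunk.
import torch

hidden_states = torch.randn(4, 16)   # [num_tokens, hidden_dim]
gating_output = torch.rand(4, 8)     # [num_tokens, num_experts], sigmoid-style scores
topk_ids, topk_weights = grouped_topk(hidden_states, gating_output, topk=2,
                                      renormalize=True, num_expert_group=4,
                                      topk_group=2, routed_scaling_factor=2.5,
                                      scoring_func="sigmoid")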
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+    # float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
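With sigmoid scoring the scores lie in (0, 1), so filling masked experts with 0.0 still keeps them out of the top-k. The toy snippet below (not part of the patch) illustrates how a float("-inf") fill can propagate NaN through a later renormalization, while a 0.0 fill stays finite.

# Toy illustration, independent of the patch: 0.0 vs float("-inf") as the mask fill value.
import torch

scores = torch.tensor([[0.9, 0.4, 0.7, 0.1]])            # sigmoid scores in (0, 1)
score_mask = torch.tensor([[True, True, False, False]])

masked_zero = scores.masked_fill(~score_mask, 0.0)
masked_inf = scores.masked_fill(~score_mask, float("-inf"))

w = torch.topk(masked_inf, k=3, dim=-1)[0]                # k exceeds the unmasked experts
print(w / w.sum(dim=-1, keepdim=True))                    # -inf / -inf -> nan

w = torch.topk(masked_zero, k=3, dim=-1)[0]
print(w / (w.sum(dim=-1, keepdim=True) + 1e-20))          # stays finite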
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                             dim=-1,
                                             sorted=False)
 
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor
 
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)
 
 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
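A worked toy example of the new normalization path, with illustrative values (the scaling factor value is assumed, not from the patch); note the scaling is applied even when renormalize is False or topk == 1.

# Toy walk-through of the renormalize + routed_scaling_factor step.
import torch

routed_scaling_factor = 2.5                          # assumed example value
topk_weights = torch.tensor([[0.60, 0.20]])          # selected scores for one token, topk = 2

denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights = topk_weights / denominator            # -> [[0.75, 0.25]]
topk_weights = topk_weights * routed_scaling_factor  # -> [[1.875, 0.625]]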
@@ -204,6 +206,7 @@ def __init__(
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear, #orig_module
@@ -219,14 +222,13 @@ def forward(self, hidden_states) -> torch.Tensor:
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
         if device is None: device = self.device