@@ -65,19 +65,22 @@ def __init__(self, model: nn.Module):
         self.experts = Qwen3MoeMLP(model.experts)
 
     def get_masked_routing_weights(self, router_logits):
-        # routing_weights: (batch * sequence_length, n_experts)
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+        if self.norm_topk_prob:
+            selected_weights, selected_experts = torch.topk(router_logits, k=self.top_k, dim=-1)
+            selected_weights = torch.nn.functional.softmax(selected_weights, dim=1, dtype=torch.float)
+            masked_routing_weights = torch.zeros_like(router_logits, dtype=torch.float32)
+            masked_routing_weights.scatter_(1, selected_experts, selected_weights)
+        else:
+            # routing_weights: (batch * sequence_length, n_experts)
+            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+
+            # selected_experts: (batch * sequence_length, top_k)
+            selected_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
+            mask = torch.zeros_like(routing_weights, dtype=torch.float32)
+            un_mask = torch.ones_like(selected_experts, dtype=torch.float32)
+            mask.scatter_(1, selected_experts, un_mask)
+            masked_routing_weights = routing_weights * mask
 
-        # selected_experts: (batch * sequence_length, top_k)
-        selected_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
-        mask = torch.zeros_like(routing_weights, dtype=torch.float32)
-        un_mask = torch.ones_like(selected_experts, dtype=torch.float32)
-        mask.scatter_(1, selected_experts, un_mask)
-
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= selected_weights.sum(dim=-1, keepdim=True)
-
-        masked_routing_weights = routing_weights * mask
 
         ## get size per expert
         expert = router_logits.shape[1]
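
The rewritten norm_topk_prob branch above is mathematically equivalent to the removed "softmax over all experts, keep the top-k, renormalize" formulation: restricting the softmax to the top-k logits cancels the full partition function. Below is a minimal standalone sketch (not part of the commit; tensor sizes, seed, and variable names other than those in the diff are illustrative assumptions) that checks the equivalence numerically.

# Standalone check, assuming the shapes used in the diff:
# router_logits is (batch * sequence_length, n_experts).
import torch

torch.manual_seed(0)
n_tokens, n_experts, top_k = 4, 8, 2
router_logits = torch.randn(n_tokens, n_experts)

# New branch: top-k over raw logits, softmax only over the selected logits.
selected_weights, selected_experts = torch.topk(router_logits, k=top_k, dim=-1)
selected_weights = torch.nn.functional.softmax(selected_weights, dim=-1, dtype=torch.float)
new_weights = torch.zeros_like(router_logits, dtype=torch.float32)
new_weights.scatter_(1, selected_experts, selected_weights)

# Old formulation: softmax over all experts, keep the top-k, renormalize by their sum.
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
old_selected, old_experts = torch.topk(routing_weights, k=top_k, dim=-1)
old_selected = old_selected / old_selected.sum(dim=-1, keepdim=True)
old_weights = torch.zeros_like(router_logits, dtype=torch.float32)
old_weights.scatter_(1, old_experts, old_selected)

print(torch.allclose(new_weights, old_weights, atol=1e-6))  # expected: True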