Commit 952c3b7

fix norm order
1 parent 59b4fe2 commit 952c3b7

1 file changed: +15 additions, -12 deletions

src/optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py

Lines changed: 15 additions & 12 deletions
@@ -65,19 +65,22 @@ def __init__(self, model: nn.Module):
         self.experts = Qwen3MoeMLP(model.experts)
 
     def get_masked_routing_weights(self, router_logits):
-        # routing_weights: (batch * sequence_length, n_experts)
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+        if self.norm_topk_prob:
+            selected_weights, selected_experts = torch.topk(router_logits, k=self.top_k, dim=-1)
+            selected_weights = torch.nn.functional.softmax(selected_weights, dim=1, dtype=torch.float)
+            masked_routing_weights = torch.zeros_like(router_logits, dtype=torch.float32)
+            masked_routing_weights.scatter_(1, selected_experts, selected_weights)
+        else:
+            # routing_weights: (batch * sequence_length, n_experts)
+            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+
+            # selected_experts: (batch * sequence_length, top_k)
+            selected_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
+            mask = torch.zeros_like(routing_weights, dtype=torch.float32)
+            un_mask = torch.ones_like(selected_experts, dtype=torch.float32)
+            mask.scatter_(1, selected_experts, un_mask)
+            masked_routing_weights = routing_weights * mask
 
-        # selected_experts: (batch * sequence_length, top_k)
-        selected_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
-        mask = torch.zeros_like(routing_weights, dtype=torch.float32)
-        un_mask = torch.ones_like(selected_experts, dtype=torch.float32)
-        mask.scatter_(1, selected_experts, un_mask)
-
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= selected_weights.sum(dim=-1, keepdim=True)
-
-        masked_routing_weights = routing_weights * mask
 
         ## get size per expert
         expert = router_logits.shape[1]
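
For context, the following is a minimal standalone sketch (not part of the commit) of the masked-routing computation as it reads after this change. The branch structure and tensor ops mirror the diff; the plain-function signature, the top_k/norm_topk_prob arguments, and the test tensors are illustrative assumptions. The assertion checks that taking the softmax over only the top-k logits (the new norm_topk_prob path) matches "softmax over all experts, then renormalize the selected mass", which is what the old code computed via routing_weights /= selected_weights.sum(...) followed by masking.

import torch

def get_masked_routing_weights(router_logits, top_k, norm_topk_prob):
    # router_logits: (batch * sequence_length, n_experts)
    if norm_topk_prob:
        # New order after this commit: pick the top-k logits first, then
        # softmax over only those k values, so the selected weights already
        # sum to 1. scatter_ writes them back into a dense
        # (tokens, n_experts) tensor; non-selected experts keep weight 0.
        selected_weights, selected_experts = torch.topk(router_logits, k=top_k, dim=-1)
        selected_weights = torch.softmax(selected_weights, dim=-1, dtype=torch.float)
        masked = torch.zeros_like(router_logits, dtype=torch.float32)
        masked.scatter_(1, selected_experts, selected_weights)
    else:
        # Softmax over all experts, then zero out the non-selected ones.
        routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
        _, selected_experts = torch.topk(routing_weights, k=top_k, dim=-1)
        mask = torch.zeros_like(routing_weights, dtype=torch.float32)
        mask.scatter_(1, selected_experts,
                      torch.ones_like(selected_experts, dtype=torch.float32))
        masked = routing_weights * mask
    return masked

# Sanity check of the equivalence described above (sizes are illustrative).
torch.manual_seed(0)
logits = torch.randn(4, 8)  # 4 tokens, 8 experts
normed = get_masked_routing_weights(logits, top_k=2, norm_topk_prob=True)
plain = get_masked_routing_weights(logits, top_k=2, norm_topk_prob=False)
assert torch.allclose(normed, plain / plain.sum(dim=-1, keepdim=True), atol=1e-6)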
