
Commit 0fa49db

Fix ernie moe (#42535)
* fix
* style
1 parent bf3f0ae commit 0fa49db


2 files changed: +62, -94 lines


src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py

Lines changed: 32 additions & 45 deletions
@@ -318,51 +318,41 @@ def forward(self, hidden_states):
 
 
 class Ernie4_5_MoeExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
+
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
         self.hidden_dim = config.hidden_size
         self.intermediate_dim = config.moe_intermediate_size
-        self.use_bias = config.use_bias
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
 
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
-        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
-        if self.use_bias:
-            self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
-            self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-        else:
-            self.gate_up_proj_bias = None
-            self.down_proj_bias = None
-
     def forward(
-        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        if selected_experts.numel() == 0:
-            return final_hidden_states
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
-            expert_idx = int(expert_idx.item())
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            gate_inputs = F.linear(
-                current_state,
-                self.gate_up_proj[expert_idx],
-                None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
-            )
-            gate, up = gate_inputs.chunk(2, dim=-1)
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
-            current_hidden_states = F.linear(
-                current_hidden_states,
-                self.down_proj[expert_idx],
-                None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
-            )
-            current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -383,14 +373,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
 
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
+            router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
+            router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
+            router_top_value = router_top_value / torch.clamp(
+                router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return routing_weights, selected_experts
+            router_scores = router_top_value
+            router_scores = router_scores.to(hidden_states.dtype)
+        return router_logits, router_scores, router_indices
 
 
 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -413,8 +403,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        routing_weights, selected_experts = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
 
         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -489,7 +479,7 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0),
+        "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, index=0),
        "hidden_states": Ernie4_5_MoeDecoderLayer,
        "attentions": Ernie4_5_MoeAttention,
    }
@@ -505,9 +495,6 @@ def _init_weights(self, module):
         elif isinstance(module, Ernie4_5_MoeExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
-            if module.gate_up_proj_bias is not None:
-                init.zeros_(module.gate_up_proj_bias)
-                init.zeros_(module.down_proj_bias)
 
 
 @auto_docstring
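
In short, the experts module now receives the routed token indices and weights directly (top_k_index, top_k_weights), stores its expert weights as bias-free 3D parameters created with torch.empty (initialization happens in _init_weights), and builds the expert/token dispatch mask under torch.no_grad(). Below is a minimal, self-contained sketch of that dispatch pattern, not the library code; the toy sizes, random weights, and the use of F.silu as the activation are assumptions for illustration.

# Minimal sketch of the new expert dispatch (illustrative only; toy sizes,
# random weights, and F.silu as the activation are assumptions).
import torch
import torch.nn.functional as F

num_tokens, hidden_dim, intermediate_dim, num_experts, top_k = 6, 8, 16, 4, 2

hidden_states = torch.randn(num_tokens, hidden_dim)                # flattened tokens
top_k_weights = torch.rand(num_tokens, top_k)                      # per-token routing scores
top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))   # chosen expert ids

# Stacked 3D expert weights, one slice per expert (no per-expert bias anymore)
gate_up_proj = torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)
down_proj = torch.randn(num_experts, hidden_dim, intermediate_dim)

final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
    # expert_mask[e, k, t] == 1 iff token t routed to expert e in top-k slot k
    expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
    expert_idx = expert_idx[0]
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    current_state = hidden_states[token_idx]
    # fused gate/up projection, then split in half
    gate, up = F.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1)
    current = F.silu(gate) * up
    current = F.linear(current, down_proj[expert_idx])
    current = current * top_k_weights[token_idx, top_k_pos, None]
    final_hidden_states.index_add_(0, token_idx, current.to(final_hidden_states.dtype))

print(final_hidden_states.shape)  # torch.Size([6, 8])

Each gate_up_proj[expert_idx] slice acts as a fused gate/up projection; chunking its output recovers the gate and up activations before the down projection, mirroring the loop body in the patch above.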

src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py

Lines changed: 30 additions & 49 deletions
@@ -20,7 +20,6 @@
 from torch import nn
 
 from ... import initialization as init
-from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import MoeModelOutputWithPast
@@ -31,6 +30,7 @@
 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
 from ..mixtral.modeling_mixtral import (
+    MixtralExperts,
     MixtralForCausalLM,
     MixtralPreTrainedModel,
 )
@@ -98,52 +98,36 @@ def forward(self, hidden_states):
         return hidden_states + self.e_score_correction_bias.squeeze()
 
 
-class Ernie4_5_MoeExperts(nn.Module):
+class Ernie4_5_MoeExperts(MixtralExperts):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
-        self.hidden_dim = config.hidden_size
         self.intermediate_dim = config.moe_intermediate_size
-        self.use_bias = config.use_bias
-        self.act_fn = ACT2FN[config.hidden_act]
-
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
-        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
-        if self.use_bias:
-            self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
-            self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-        else:
-            self.gate_up_proj_bias = None
-            self.down_proj_bias = None
 
     def forward(
-        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        if selected_experts.numel() == 0:
-            return final_hidden_states
-
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
-            expert_idx = int(expert_idx.item())
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            gate_inputs = F.linear(
-                current_state,
-                self.gate_up_proj[expert_idx],
-                None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
-            )
-            gate, up = gate_inputs.chunk(2, dim=-1)
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
-            current_hidden_states = F.linear(
-                current_hidden_states,
-                self.down_proj[expert_idx],
-                None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
-            )
-            current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -164,14 +148,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
 
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
+            router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
+            router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
+            router_top_value = router_top_value / torch.clamp(
+                router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return routing_weights, selected_experts
+            router_scores = router_top_value
+            router_scores = router_scores.to(hidden_states.dtype)
+        return router_logits, router_scores, router_indices
 
 
 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -194,8 +178,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        routing_weights, selected_experts = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
 
         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -231,7 +215,7 @@ class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
     # Not supporting multi-token prediction (MTP) atm
     _keys_to_ignore_on_load_unexpected = ["mtp"]
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0),
+        "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, index=0),
         "hidden_states": Ernie4_5_MoeDecoderLayer,
         "attentions": Ernie4_5_MoeAttention,
     }
@@ -245,9 +229,6 @@ def _init_weights(self, module):
         elif isinstance(module, Ernie4_5_MoeExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
-            if module.gate_up_proj_bias is not None:
-                init.zeros_(module.gate_up_proj_bias)
-                init.zeros_(module.down_proj_bias)
 
 
 @auto_docstring
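
On the router side, the gate now returns three values: the full softmax distribution (recorded as "router_logits" via OutputRecorder(Ernie4_5_MoeTopKRouter, index=0)), the normalized top-k scores cast to the hidden-state dtype, and the top-k expert indices. A rough, standalone sketch of that contract follows; the toy sizes, the zero correction bias standing in for moe_statics, and the norm_min value are assumptions for illustration, not the library module.

# Rough sketch of the new gate contract (assumed toy sizes; the zero
# e_score_correction_bias stands in for moe_statics, norm_min is illustrative).
import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts, top_k, norm_min = 5, 8, 4, 2, 1e-12

hidden_states = torch.randn(num_tokens, hidden_dim)
gate_weight = torch.randn(num_experts, hidden_dim)
e_score_correction_bias = torch.zeros(num_experts)

# router_logits is reassigned to the softmax probabilities, as in the patch
router_logits = F.linear(hidden_states.float(), gate_weight)
router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)

# top-k over the bias-corrected scores, then clamp-normalize the kept weights
router_top_value, router_indices = torch.topk(router_logits + e_score_correction_bias, top_k, dim=-1)
router_top_value = router_top_value / torch.clamp(
    router_top_value.sum(dim=-1, keepdim=True), min=norm_min
)
router_scores = router_top_value.to(hidden_states.dtype)

print(router_logits.shape, router_scores.shape, router_indices.shape)
# torch.Size([5, 4]) torch.Size([5, 2]) torch.Size([5, 2])

The sparse MoE block then discards the first return value and calls self.experts(hidden_states, top_k_index, top_k_weights) with the last two, as shown in the diffs above.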
