 from torch import nn
 
 from ... import initialization as init
-from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import MoeModelOutputWithPast
 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
 from ..mixtral.modeling_mixtral import (
+    MixtralExperts,
     MixtralForCausalLM,
     MixtralPreTrainedModel,
 )
@@ -98,52 +98,36 @@ def forward(self, hidden_states):
         return hidden_states + self.e_score_correction_bias.squeeze()
 
 
-class Ernie4_5_MoeExperts(nn.Module):
+class Ernie4_5_MoeExperts(MixtralExperts):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
-        self.hidden_dim = config.hidden_size
         self.intermediate_dim = config.moe_intermediate_size
-        self.use_bias = config.use_bias
-        self.act_fn = ACT2FN[config.hidden_act]
-
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
-        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
-        if self.use_bias:
-            self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
-            self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-        else:
-            self.gate_up_proj_bias = None
-            self.down_proj_bias = None
 
     def forward(
-        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        if selected_experts.numel() == 0:
-            return final_hidden_states
-
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
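+        # Build a (num_experts, top_k, num_tokens) assignment mask without tracking gradients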
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
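+            # Indices of experts that received at least one token in this batch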
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
-            expert_idx = int(expert_idx.item())
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            gate_inputs = F.linear(
-                current_state,
-                self.gate_up_proj[expert_idx],
-                None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
-            )
-            gate, up = gate_inputs.chunk(2, dim=-1)
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
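+            # gate_up_proj packs the gate and up projections in one matrix; split the output into both halves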
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
-            current_hidden_states = F.linear(
-                current_hidden_states,
-                self.down_proj[expert_idx],
-                None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
-            )
-            current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
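+            # Scale by the routing weight of each (token, expert) pair and accumulate into the output buffer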
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -164,14 +148,14 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
 
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
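+            # Softmax in float32, then top-k over the bias-corrected probabilities (moe_statics adds the score correction bias)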
+            router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
+            router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
+            router_top_value = router_top_value / torch.clamp(
+                router_top_value.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return routing_weights, selected_experts
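+            # Raw logits are returned alongside the normalized scores and indices (e.g. for router-logit recording)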
+            router_scores = router_top_value
+            router_scores = router_scores.to(hidden_states.dtype)
+        return router_logits, router_scores, router_indices
 
 
 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -194,8 +178,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        routing_weights, selected_experts = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
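+        # The gate now returns (router_logits, top_k_weights, top_k_index); the raw logits are not needed here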
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
 
         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -231,7 +215,7 @@ class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
     # Not supporting multi-token prediction (MTP) atm
     _keys_to_ignore_on_load_unexpected = ["mtp"]
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0),
218+ "router_logits" : OutputRecorder (Ernie4_5_MoeTopKRouter , index = 0 ),
235219 "hidden_states" : Ernie4_5_MoeDecoderLayer ,
236220 "attentions" : Ernie4_5_MoeAttention ,
237221 }
@@ -245,9 +229,6 @@ def _init_weights(self, module):
         elif isinstance(module, Ernie4_5_MoeExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
-            if module.gate_up_proj_bias is not None:
-                init.zeros_(module.gate_up_proj_bias)
-                init.zeros_(module.down_proj_bias)
 
 
 @auto_docstring