@@ -120,13 +120,13 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights
         hidden_states = hidden_states.repeat(num_experts, 1)
         hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)

-        gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+        gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(hidden_states.dtype)) + self.gate_up_proj_bias[..., None, :].to(hidden_states.dtype)
         gate, up = gate_up[..., ::2], gate_up[..., 1::2]
         gate = gate.clamp(min=None, max=self.limit)
         up = up.clamp(min=-self.limit, max=self.limit)
         glu = gate * torch.sigmoid(gate * self.alpha)
-        next_states = torch.bmm(((up + 1.0) * glu), self.down_proj)
-        next_states = next_states + self.down_proj_bias[..., None, :]
+        next_states = torch.bmm(((up + 1.0) * glu), self.down_proj.to(hidden_states.dtype))
+        next_states = next_states + self.down_proj_bias[..., None, :].to(hidden_states.dtype)
         next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
         next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
         next_states = next_states.sum(dim=0)
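The change casts each expert parameter to the activation dtype right before it is used, since `torch.bmm` requires both operands to share a dtype and raises a runtime error when, say, activations arrive in reduced precision while the projection weights are stored in float32. Below is a minimal standalone sketch of the patched computation; the sizes and the `alpha`/`limit` constants are illustrative placeholders, not the model's actual configuration.

```python
import torch

# Illustrative sizes and constants -- not the model's real configuration.
num_experts, seq, hidden_size, intermediate = 4, 3, 8, 16
alpha, limit = 1.702, 7.0

# Activations in reduced precision, parameters left in float32: the dtype
# mismatch that torch.bmm rejects unless the parameters are cast first.
hidden_states = torch.randn(num_experts, seq, hidden_size, dtype=torch.bfloat16)
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * intermediate)  # float32
gate_up_proj_bias = torch.randn(num_experts, 2 * intermediate)          # float32
down_proj = torch.randn(num_experts, intermediate, hidden_size)         # float32
down_proj_bias = torch.randn(num_experts, hidden_size)                  # float32

# Same computation as the patched forward: cast each parameter to the
# activation dtype at the point of use.
gate_up = torch.bmm(hidden_states, gate_up_proj.to(hidden_states.dtype))
gate_up = gate_up + gate_up_proj_bias[..., None, :].to(hidden_states.dtype)
gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # de-interleave gate/up halves
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)          # clamped, SiLU-style gating
next_states = torch.bmm((up + 1.0) * glu, down_proj.to(hidden_states.dtype))
next_states = next_states + down_proj_bias[..., None, :].to(hidden_states.dtype)
print(next_states.shape, next_states.dtype)  # torch.Size([4, 3, 8]) torch.bfloat16
```

Casting at the point of use keeps the stored parameters untouched (useful when they are kept in a different precision than the activations) at the cost of a per-call conversion.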