@@ -145,10 +145,35 @@ def forward(self, x):
145145
146146
147147class MHMoE (nn .Module ):
148- def __init__ (self , dim , num_heads , num_experts , num_layers ):
148+ """
149+ Multi-Head Mixture of Experts (MHMoE) module.
150+
151+ Args:
152+ dim (int): The input dimension.
153+ heads (int): The number of attention heads.
154+ num_experts (int): The number of experts.
155+ num_layers (int): The number of layers.
156+
157+ Attributes:
158+ dim (int): The input dimension.
159+ heads (int): The number of attention heads.
160+ num_experts (int): The number of experts.
161+ num_layers (int): The number of layers.
162+ multi_head_layers (nn.ModuleList): List of multi-head layers.
163+ merge_layers (nn.ModuleList): List of merge layers.
164+
165+ """
166+
167+ def __init__ (
168+ self ,
169+ dim : int ,
170+ heads : int ,
171+ num_experts : int = 6 ,
172+ num_layers : int = 3 ,
173+ ):
149174 super (MHMoE , self ).__init__ ()
150175 self .dim = dim
151- self .num_heads = num_heads
176+ self .heads = heads
152177 self .num_experts = num_experts
153178 self .num_layers = num_layers
154179
@@ -170,102 +195,67 @@ def __init__(self, dim, num_heads, num_experts, num_layers):
170195 nn .init .constant_ (self .merge_layers [i ].bias , 0 )
171196
172197 def forward (self , x , mask ):
198+ """
199+ Forward pass of the MHMoE module.
200+
201+ Args:
202+ x (torch.Tensor): The input tensor.
203+ mask (torch.Tensor): The mask tensor.
204+
205+ Returns:
206+ torch.Tensor: The output tensor.
207+
208+ """
173209 # Loop through each layer
174210 for i in range (self .num_layers ):
175211 x = self .process_layer (x , mask , i )
176212 return x
177213
178- # def process_layer(self, x, mask, layer_index):
179- # batch_size, length, _ = x.size()
180-
181- # # Processed by multi-head layer
182- # x = self.multi_head_layers[layer_index](x)
183- # print(x.shape)
184-
185- # # Using einops to split and rearrange sub-tokens in parallel
186- # x = rearrange(
187- # x,
188- # "b l (h d) -> (b h) l d",
189- # h=self.num_heads,
190- # d=self.dim // self.num_heads,
191- # )
192- # b, s, d = x.shape
193- # print(x.shape)
194-
195- # # Example routing logic (placeholder)
196- # # x, i = NoisyTopkRouter(self.dim, self.num_experts, 2)(x) # Replace with actual routing logic
197- # x, i = TopkRouter(d, self.num_experts, 2)(x)
198- # print(x.shape)
199-
200- # # Sparse Moe
201- # # x, e = NormalSparseMoE(
202- # # dim,
203- # # num_experts=self.num_experts,
204- # # # experts=
205- # # )
206-
207- # # Using einops to merge back to the original token form
208- # x = rearrange(
209- # x,
210- # "(b h) l d -> b l (h d)",
211- # b=batch_size,
212- # h=self.num_heads,
213- # d=self.dim // self.num_heads,
214- # )
215-
216- # # Output processed by merge layer
217- # x = self.merge_layers[layer_index](x)
218-
219- # return x
220-
221214 def process_layer (self , x , mask , layer_index ):
215+ """
216+ Process a single layer of the MHMoE module.
217+
218+ Args:
219+ x (torch.Tensor): The input tensor.
220+ mask (torch.Tensor): The mask tensor.
221+ layer_index (int): The index of the layer.
222+
223+ Returns:
224+ torch.Tensor: The output tensor.
225+
226+ """
222227 batch_size , length , _ = x .size ()
223228
224229 # Processed by multi-head layer
225230 x = self .multi_head_layers [layer_index ](x )
226231
227232 # Correcting the reshaping step
228- # We need to ensure x is reshaped to (batch_size, num_heads , length, dim/num_heads )
233+ # We need to ensure x is reshaped to (batch_size, heads , length, dim/heads )
229234 x = x .view (
230235 batch_size ,
231236 length ,
232- self .num_heads ,
233- self .dim // self .num_heads ,
237+ self .heads ,
238+ self .dim // self .heads ,
234239 )
235240 x = x .permute (
236241 0 , 2 , 1 , 3
237- ).contiguous () # this rearranges to (batch_size, num_heads , length, dim/num_heads )
242+ ).contiguous () # this rearranges to (batch_size, heads , length, dim/heads )
238243 x = x .view (
239- batch_size * self .num_heads ,
244+ batch_size * self .heads ,
240245 length ,
241- self .dim // self .num_heads ,
246+ self .dim // self .heads ,
242247 )
243248 b , s , d = x .shape
244249 print (x .shape )
245250
246- # Simulated expert processing (needs actual implementation)
247- # For now, assume identity transformation
248- # x = x # Replace with actual routing and processing logic
249- # x, i = NoisyTopkRouter(d, self.num_experts, 2)(x)
250- # print(x.shape)
251- # x, e = NormalSparseMoE(
252- # self.dim,
253- # self.num_experts,
254- # )
255- # x = TopkRouter(
256- # d,
257- # self.num_experts,
258- # top_k=4,
259- # )(x)
260- # x = reduce("b h l d -> b l (h d)", x, "mean")
261251 x = SparseMoE (d , self .num_experts , 2 )(x )
262252
263253 # Reshape back to original form after processing
264254 x = x .view (
265255 batch_size ,
266- self .num_heads ,
256+ self .heads ,
267257 length ,
268- self .dim // self .num_heads ,
258+ self .dim // self .heads ,
269259 )
270260 x = x .permute (0 , 2 , 1 , 3 ).contiguous ()
271261 x = x .view (batch_size , length , self .dim )
0 commit comments