Skip to content

Commit 407b161

Browse files
author
Kye
committed
[CLEANUP]
1 parent 3a8791c commit 407b161

4 files changed

Lines changed: 63 additions & 73 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ from mh_moe.main import MHMoE
1616

1717
# Define model parameters
1818
dim = 512
19-
num_heads = 8
19+
heads = 8
2020
num_experts = 4
2121
num_layers = 3
2222

2323
# Create MHMoE model instance
24-
model = MHMoE(dim, num_heads, num_experts, num_layers)
24+
model = MHMoE(dim, heads, num_experts, num_layers)
2525

2626
# Generate dummy input
2727
batch_size = 10

example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
# Define model parameters
55
dim = 512
6-
num_heads = 8
6+
heads = 8
77
num_experts = 4
88
num_layers = 3
99

1010
# Create MHMoE model instance
11-
model = MHMoE(dim, num_heads, num_experts, num_layers)
11+
model = MHMoE(dim, heads, num_experts, num_layers)
1212

1313
# Generate dummy input
1414
batch_size = 10

mh_moe/main.py

Lines changed: 58 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,35 @@ def forward(self, x):
145145

146146

147147
class MHMoE(nn.Module):
148-
def __init__(self, dim, num_heads, num_experts, num_layers):
148+
"""
149+
Multi-Head Mixture of Experts (MHMoE) module.
150+
151+
Args:
152+
dim (int): The input dimension.
153+
heads (int): The number of attention heads.
154+
num_experts (int): The number of experts.
155+
num_layers (int): The number of layers.
156+
157+
Attributes:
158+
dim (int): The input dimension.
159+
heads (int): The number of attention heads.
160+
num_experts (int): The number of experts.
161+
num_layers (int): The number of layers.
162+
multi_head_layers (nn.ModuleList): List of multi-head layers.
163+
merge_layers (nn.ModuleList): List of merge layers.
164+
165+
"""
166+
167+
def __init__(
168+
self,
169+
dim: int,
170+
heads: int,
171+
num_experts: int = 6,
172+
num_layers: int = 3,
173+
):
149174
super(MHMoE, self).__init__()
150175
self.dim = dim
151-
self.num_heads = num_heads
176+
self.heads = heads
152177
self.num_experts = num_experts
153178
self.num_layers = num_layers
154179

@@ -170,102 +195,67 @@ def __init__(self, dim, num_heads, num_experts, num_layers):
170195
nn.init.constant_(self.merge_layers[i].bias, 0)
171196

172197
def forward(self, x, mask):
198+
"""
199+
Forward pass of the MHMoE module.
200+
201+
Args:
202+
x (torch.Tensor): The input tensor.
203+
mask (torch.Tensor): The mask tensor.
204+
205+
Returns:
206+
torch.Tensor: The output tensor.
207+
208+
"""
173209
# Loop through each layer
174210
for i in range(self.num_layers):
175211
x = self.process_layer(x, mask, i)
176212
return x
177213

178-
# def process_layer(self, x, mask, layer_index):
179-
# batch_size, length, _ = x.size()
180-
181-
# # Processed by multi-head layer
182-
# x = self.multi_head_layers[layer_index](x)
183-
# print(x.shape)
184-
185-
# # Using einops to split and rearrange sub-tokens in parallel
186-
# x = rearrange(
187-
# x,
188-
# "b l (h d) -> (b h) l d",
189-
# h=self.num_heads,
190-
# d=self.dim // self.num_heads,
191-
# )
192-
# b, s, d = x.shape
193-
# print(x.shape)
194-
195-
# # Example routing logic (placeholder)
196-
# # x, i = NoisyTopkRouter(self.dim, self.num_experts, 2)(x) # Replace with actual routing logic
197-
# x, i = TopkRouter(d, self.num_experts, 2)(x)
198-
# print(x.shape)
199-
200-
# # Sparse Moe
201-
# # x, e = NormalSparseMoE(
202-
# # dim,
203-
# # num_experts=self.num_experts,
204-
# # # experts=
205-
# # )
206-
207-
# # Using einops to merge back to the original token form
208-
# x = rearrange(
209-
# x,
210-
# "(b h) l d -> b l (h d)",
211-
# b=batch_size,
212-
# h=self.num_heads,
213-
# d=self.dim // self.num_heads,
214-
# )
215-
216-
# # Output processed by merge layer
217-
# x = self.merge_layers[layer_index](x)
218-
219-
# return x
220-
221214
def process_layer(self, x, mask, layer_index):
215+
"""
216+
Process a single layer of the MHMoE module.
217+
218+
Args:
219+
x (torch.Tensor): The input tensor.
220+
mask (torch.Tensor): The mask tensor.
221+
layer_index (int): The index of the layer.
222+
223+
Returns:
224+
torch.Tensor: The output tensor.
225+
226+
"""
222227
batch_size, length, _ = x.size()
223228

224229
# Processed by multi-head layer
225230
x = self.multi_head_layers[layer_index](x)
226231

227232
# Correcting the reshaping step
228-
# We need to ensure x is reshaped to (batch_size, num_heads, length, dim/num_heads)
233+
# We need to ensure x is reshaped to (batch_size, heads, length, dim/heads)
229234
x = x.view(
230235
batch_size,
231236
length,
232-
self.num_heads,
233-
self.dim // self.num_heads,
237+
self.heads,
238+
self.dim // self.heads,
234239
)
235240
x = x.permute(
236241
0, 2, 1, 3
237-
).contiguous() # this rearranges to (batch_size, num_heads, length, dim/num_heads)
242+
).contiguous() # this rearranges to (batch_size, heads, length, dim/heads)
238243
x = x.view(
239-
batch_size * self.num_heads,
244+
batch_size * self.heads,
240245
length,
241-
self.dim // self.num_heads,
246+
self.dim // self.heads,
242247
)
243248
b, s, d = x.shape
244249
print(x.shape)
245250

246-
# Simulated expert processing (needs actual implementation)
247-
# For now, assume identity transformation
248-
# x = x # Replace with actual routing and processing logic
249-
# x, i = NoisyTopkRouter(d, self.num_experts, 2)(x)
250-
# print(x.shape)
251-
# x, e = NormalSparseMoE(
252-
# self.dim,
253-
# self.num_experts,
254-
# )
255-
# x = TopkRouter(
256-
# d,
257-
# self.num_experts,
258-
# top_k=4,
259-
# )(x)
260-
# x = reduce("b h l d -> b l (h d)", x, "mean")
261251
x = SparseMoE(d, self.num_experts, 2)(x)
262252

263253
# Reshape back to original form after processing
264254
x = x.view(
265255
batch_size,
266-
self.num_heads,
256+
self.heads,
267257
length,
268-
self.dim // self.num_heads,
258+
self.dim // self.heads,
269259
)
270260
x = x.permute(0, 2, 1, 3).contiguous()
271261
x = x.view(batch_size, length, self.dim)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "mh-moe"
7-
version = "0.0.1"
7+
version = "0.0.2"
88
description = "Paper - Pytorch"
99
license = "MIT"
1010
authors = ["Kye Gomez <kye@apac.ai>"]

0 commit comments

Comments
 (0)