
Commit 12a0615

[prototype] Expert Parallel

ghstack-source-id: b4d3f46f9519f4a478fca22b5665bf72bfe01409
Pull Request resolved: #714

1 parent 4d182a1 · commit 12a0615

6 files changed (+851, -12 lines)

torchtitan/config_manager.py (+7)

@@ -348,6 +348,13 @@ def __init__(self):
             default=1,
             help="Context parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_mode",
+            type=str,
+            default="none",
+            choices=["none", "tp", "tp2ep", "dp2ep"],
+            help="Expert Parallel mode",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
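
Since the flag name contains a dot, argparse stores it under the literal dotted key. A minimal standalone sketch of the behavior (this is not the project's parser wiring, just the add_argument call above reproduced in isolation):

    # Hypothetical standalone parser mirroring the new flag.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--experimental.expert_parallel_mode",
        type=str,
        default="none",
        choices=["none", "tp", "tp2ep", "dp2ep"],
        help="Expert Parallel mode",
    )

    args = parser.parse_args(["--experimental.expert_parallel_mode", "tp2ep"])
    # the dotted name is kept verbatim as the attribute key
    print(getattr(args, "experimental.expert_parallel_mode"))  # tp2ep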

torchtitan/models/llama/model.py (+65, -8)

@@ -14,6 +14,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+
 from torchtitan.models.norms import build_norm

@@ -35,6 +36,13 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "rmsnorm"

+    # MoE args
+    enable_moe: bool = True
+    num_experts: int = 8
+    capacity_factor: float = 1.0
+    use_shared_expert: bool = True
+    auto_scale_hidden_dim: bool = True
+

 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
@@ -284,12 +292,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
         self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
+        self.enable_moe = model_args.enable_moe
+
+        if not self.enable_moe:
+            self.feed_forward = FeedForward(
+                dim=model_args.dim,
+                hidden_dim=4 * model_args.dim,
+                multiple_of=model_args.multiple_of,
+                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+            )
+        else:
+            from torchtitan.models.llama.moe_layers import (
+                ExpertChoiceTopKRouter,
+                GroupedExperts,
+                MoE,
+            )
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = model_args.capacity_factor + int(
+                    model_args.use_shared_expert
+                )
+
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if model_args.ffn_dim_multiplier is not None:
+                hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            num_experts = model_args.num_experts
+            self.moe = MoE(
+                experts=GroupedExperts(
+                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
+                ),
+                router=ExpertChoiceTopKRouter(
+                    gate=nn.Linear(dim, num_experts, bias=False),
+                    dim=dim,
+                    num_experts=num_experts,
+                    capacity_factor=model_args.capacity_factor,
+                ),
+                shared_expert=(
+                    GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
+                    if model_args.use_shared_expert
+                    else None
+                ),
+            )
+
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
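
With auto_scale_hidden_dim on, the routed experts' hidden size is shrunk so that the activated FFN compute stays close to the dense model's. A worked example of the arithmetic above, using illustrative values (dim=4096, multiple_of=256, no ffn_dim_multiplier; these numbers are assumptions, not from the commit):

    # Hypothetical sizes; mirrors the hidden_dim computation in __init__ above.
    dim = 4096
    multiple_of = 256
    capacity_factor = 1.0
    use_shared_expert = True

    hidden_dim = int(2 * (4 * dim) / 3)                          # 10922, usual llama FFN sizing
    hidden_dim_denom = capacity_factor + int(use_shared_expert)  # 2.0
    hidden_dim = int(hidden_dim / hidden_dim_denom)              # 5461, split across routed + shared
    hidden_dim += -hidden_dim % multiple_of                      # round up to 5632
    print(hidden_dim)  # 5632

With capacity_factor=1.0 plus a shared expert, the denominator is 2, so the routed and shared experts each get roughly half the dense hidden size.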

@@ -322,14 +373,20 @@ def forward(

         """
         h = x + self.attention(self.attention_norm(x), freqs_cis)
-        out = h + self.feed_forward(self.ffn_norm(h))
+        if not self.enable_moe:
+            out = h + self.feed_forward(self.ffn_norm(h))
+        else:
+            out = h + self.moe(self.ffn_norm(h))
         return out

     def init_weights(self):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
+        if not self.enable_moe:
+            self.feed_forward.init_weights(self.weight_init_std)
+        else:
+            self.moe.init_weights(self.weight_init_std)


 class Transformer(nn.Module):

torchtitan/models/llama/moe_layers.py (+217, new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch import nn


class GroupedExperts(nn.Module):
    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
    is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
        swiglu (bool): Whether to use gated linear units. Default is True.
        activation (Callable): Activation function to use. Default is F.silu.
    """

    def __init__(
        self,
        *,
        dim_in: int,
        dim_out: int,
        num_experts: int = 1,
        swiglu: bool = True,
        activation: Callable = F.silu,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = activation

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).

        Returns:
            torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
        """
        # Expert Choice (EC) forward
        # x shape (num_experts, tokens_per_expert, dim_in)
        h = self.act_fn(torch.bmm(x, self.gate_proj))
        if self.up_proj is not None:
            h = h * torch.bmm(x, self.up_proj)
        # out shape (num_experts, tokens_per_expert, dim_in)
        out = torch.bmm(h, self.down_proj)
        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
        if self.up_proj is not None:
            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)
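
A quick shape check of the batched layout, assuming the class above is in scope (all sizes are made up):

    # Illustrative shape check for GroupedExperts; values are hypothetical.
    import torch

    num_experts, tokens_per_expert, dim, hidden = 8, 16, 64, 128
    x = torch.randn(num_experts, tokens_per_expert, dim)

    experts = GroupedExperts(dim_in=dim, dim_out=hidden, num_experts=num_experts)
    with torch.no_grad():
        experts.init_weights(init_std=0.02)  # parameters start as torch.empty
        out = experts(x)
    print(out.shape)  # torch.Size([8, 16, 64])

Each expert multiplies only its own (tokens_per_expert, dim) slice, so a single bmm per projection covers all experts and replaces a Python loop.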
class RouterOutput(NamedTuple):
    """Router output for Expert Choice routing.

    routed_input (torch.Tensor): tokens grouped together by expert indices with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    token_indices (torch.Tensor): token indices for routed_input with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    """

    routed_input: torch.Tensor
    token_indices: torch.Tensor


class ExpertChoiceTopKRouter(nn.Module):
    """This class implements expert choice routing. Each expert selects its top-k tokens
    based on the router scores. Refer to more details in https://arxiv.org/abs/2202.09368

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        capacity_factor (float): Capacity factor determines how many tokens each expert can choose:
            expert capacity = (number of tokens * capacity factor) / number of experts.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
    """

    def __init__(
        self,
        *,
        gate: nn.Module,
        dim: int,
        num_experts: int,
        capacity_factor: float,
        use_sigmoid: bool = True,
    ):
        super().__init__()
        self.gate = gate
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid

    def forward(self, x: torch.Tensor) -> RouterOutput:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            routed_input (torch.Tensor): input tokens grouped together by expert indices with shape
                ``(num_experts*tokens_per_expert, dim)``.
            token_indices (torch.Tensor): token indices for routed_input, expanded to shape
                ``(num_experts*tokens_per_expert, dim)`` so it can index along dim 0.
        """
        # scores shape (num_experts, bs*slen)
        scores = self.gate(x).transpose(0, 1)
        # By default, we perform sigmoid and softmax in float32 to avoid loss explosion.
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
        # round up to a multiple of 8 for hardware-friendly shapes
        tokens_per_expert += -tokens_per_expert % 8
        # Take the smaller of tokens_per_expert and the number of tokens
        tokens_per_expert = min(tokens_per_expert, x.shape[0])

        # top_scores shape (num_experts, tokens_per_expert)
        top_scores, selected_token_indices = torch.topk(
            scores, k=tokens_per_expert, dim=1
        )

        # token_indices shape (num_experts*tokens_per_expert, dim)
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, self.dim)
        # routed_input shape (num_experts*tokens_per_expert, dim)
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return RouterOutput(
            routed_input,
            token_indices,
        )

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
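
The capacity arithmetic in forward() rounds each expert's token budget up to a multiple of 8 and caps it at the total token count. A worked example with assumed sizes:

    # Hypothetical sizes, mirroring the tokens_per_expert computation above.
    num_tokens = 100       # bs * slen
    capacity_factor = 1.0
    num_experts = 8

    tokens_per_expert = int(num_tokens * capacity_factor / num_experts)  # 12
    tokens_per_expert += -tokens_per_expert % 8                          # round up to 16
    tokens_per_expert = min(tokens_per_expert, num_tokens)               # still 16
    print(tokens_per_expert)  # each expert picks its top 16 tokens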
class MoE(nn.Module):
    """This class implements the MoE (Mixture of Experts) layer. A Mixture of Experts
    typically consists of a set of expert networks together with a router, which directs
    input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.

    Args:
        experts (nn.Module): experts module.
        router (nn.Module): router module.
        shared_expert (Optional[nn.Module]): shared expert module. Default is None.
    """

    def __init__(
        self,
        *,
        experts: nn.Module,
        router: nn.Module,
        shared_expert: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.experts = experts
        self.router = router
        self.shared_expert = shared_expert

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
        """
        bz, slen, dim = x.shape
        # routed_input shape (num_experts*tokens_per_expert, dim) for EC
        routed_input, token_indices = self.router(x.reshape(bz * slen, dim))

        # routed_input shape (num_experts, tokens_per_expert, dim)
        routed_input = routed_input.reshape(self.router.num_experts, -1, dim)
        # routed_output shape (num_experts, tokens_per_expert, dim)
        routed_output = self.experts(routed_input)
        # routed_output shape (num_experts*tokens_per_expert, dim)
        routed_output = routed_output.reshape(-1, dim)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
                bz * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bz * slen, dim))

        # add experts output
        # doing it in place might be faster
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bz, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
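
For reference, a minimal end-to-end sketch wiring the three modules together, mirroring (but not copied from) the construction in TransformerBlock.__init__ above; all sizes are made up:

    # Hypothetical sizes; assumes the classes in this file are in scope.
    import torch
    from torch import nn

    dim, hidden_dim, num_experts = 64, 128, 8
    moe = MoE(
        experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
        router=ExpertChoiceTopKRouter(
            gate=nn.Linear(dim, num_experts, bias=False),
            dim=dim,
            num_experts=num_experts,
            capacity_factor=1.0,
        ),
        shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
    )
    moe.init_weights(init_std=0.02)

    x = torch.randn(2, 16, dim)  # (bz, slen, dim)
    out = moe(x)
    print(out.shape)  # torch.Size([2, 16, 64])

Tokens not selected by any expert pass through only the shared expert (or contribute zeros when no shared expert is configured), since scatter_add writes back just the routed rows.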
