
Commit 339c737

[prototype] Expert Parallel

ghstack-source-id: d03719eb6b659c319631bed9b276d6bac6e7df8d
Pull Request resolved: #714

1 parent 4d182a1 · commit 339c737

File tree

6 files changed: +833 -9 lines changed

torchtitan/config_manager.py (+7)

@@ -348,6 +348,13 @@ def __init__(self):
             default=1,
             help="Context parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_mode",
+            type=str,
+            default="none",
+            choices=["none", "tp", "tp2ep", "dp2ep"],
+            help="Expert Parallel mode",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
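
For context, the new flag can be exercised on its own. Below is a minimal, self-contained argparse sketch that mirrors the added option; the parse_args input is illustrative, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.expert_parallel_mode",
    type=str,
    default="none",
    choices=["none", "tp", "tp2ep", "dp2ep"],
    help="Expert Parallel mode",
)

# argparse keeps the dot in the destination name, so access it via getattr
args = parser.parse_args(["--experimental.expert_parallel_mode", "dp2ep"])
print(getattr(args, "experimental.expert_parallel_mode"))  # dp2ep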

torchtitan/models/llama/model.py (+64 -8)

@@ -35,6 +35,13 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "rmsnorm"
 
+    # MoE args
+    enable_moe: bool = True
+    num_experts: int = 8
+    capacity_factor: float = 1.0
+    use_shared_expert: bool = True
+    auto_scale_hidden_dim: bool = True
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """

@@ -284,12 +291,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
         self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
+        self.enable_moe = model_args.enable_moe
+
+        if not self.enable_moe:
+            self.feed_forward = FeedForward(
+                dim=model_args.dim,
+                hidden_dim=4 * model_args.dim,
+                multiple_of=model_args.multiple_of,
+                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+            )
+        else:
+            from torchtitan.models.llama.moe_layers import (
+                ExpertChoiceTopKRouter,
+                GroupedExperts,
+                MoE,
+            )
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = model_args.capacity_factor + int(
+                    model_args.use_shared_expert
+                )
+
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if model_args.ffn_dim_multiplier is not None:
+                hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            num_experts = model_args.num_experts
+            self.moe = MoE(
+                experts=GroupedExperts(
+                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
+                ),
+                router=ExpertChoiceTopKRouter(
+                    gate=nn.Linear(dim, num_experts, bias=False),
+                    dim=dim,
+                    num_experts=num_experts,
+                    capacity_factor=model_args.capacity_factor,
+                ),
+                shared_expert=(
+                    GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
+                    if model_args.use_shared_expert
+                    else None
+                ),
+            )
+
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
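
As a concrete check of the hidden-dim auto-scaling above, here is the arithmetic the new branch performs, with illustrative Llama-style values (dim=4096 and multiple_of=256 are assumptions for the example, not taken from this commit) and the ModelArgs defaults capacity_factor=1.0, use_shared_expert=True:

dim = 4096                 # assumed model dim (illustrative)
multiple_of = 256          # assumed FFN rounding multiple (illustrative)
capacity_factor = 1.0      # default from ModelArgs
use_shared_expert = True   # default from ModelArgs

hidden_dim = int(2 * (4 * dim) / 3)               # 10922
denom = capacity_factor + int(use_shared_expert)  # 2.0
hidden_dim = int(hidden_dim / denom)              # 5461
hidden_dim += -hidden_dim % multiple_of           # round up to a multiple of 256
print(hidden_dim)  # 5632

So with a shared expert and capacity factor 1.0, each expert's hidden dim is roughly half of the dense FFN's, keeping total FLOPs comparable.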

@@ -322,14 +372,20 @@ def forward(
 
         """
         h = x + self.attention(self.attention_norm(x), freqs_cis)
-        out = h + self.feed_forward(self.ffn_norm(h))
+        if not self.enable_moe:
+            out = h + self.feed_forward(self.ffn_norm(h))
+        else:
+            out = h + self.moe(self.ffn_norm(h))
         return out
 
     def init_weights(self):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
+        if not self.enable_moe:
+            self.feed_forward.init_weights(self.weight_init_std)
+        else:
+            self.moe.init_weights(self.weight_init_std)
 
 
 class Transformer(nn.Module):

torchtitan/models/llama/moe_layers.py (new file, +217)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch import nn

class GroupedExperts(nn.Module):
    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
    is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
        swiglu (bool): Whether to use the gated linear unit. Default is True.
        activation (Callable): Activation function to use. Default is F.silu.
    """

    def __init__(
        self,
        *,
        dim_in: int,
        dim_out: int,
        num_experts: int = 1,
        swiglu: bool = True,
        activation: Callable = F.silu,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = activation

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).

        Returns:
            torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
        """
        # Expert Choice (EC) forward
        # x shape (num_experts, tokens_per_expert, dim_in)
        h = self.act_fn(torch.bmm(x, self.gate_proj))
        if self.up_proj is not None:
            h = h * torch.bmm(x, self.up_proj)
        # out shape (num_experts, tokens_per_expert, dim_in)
        out = torch.bmm(h, self.down_proj)
        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
        if self.up_proj is not None:
            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)
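
As a quick sanity check of the batched matmul shapes in GroupedExperts.forward, here is a minimal sketch with illustrative sizes (all the dimensions below are made up for the example):

import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, dim_in, dim_out = 8, 16, 32, 64
x = torch.randn(num_experts, tokens_per_expert, dim_in)
gate_proj = torch.randn(num_experts, dim_in, dim_out)
up_proj = torch.randn(num_experts, dim_in, dim_out)
down_proj = torch.randn(num_experts, dim_out, dim_in)

h = F.silu(torch.bmm(x, gate_proj))  # (8, 16, 64)
h = h * torch.bmm(x, up_proj)        # SwiGLU gating, same shape
out = torch.bmm(h, down_proj)        # back to (8, 16, 32)
print(out.shape)                     # torch.Size([8, 16, 32])

Each expert's tokens are multiplied against that expert's own weight slice in one torch.bmm call, which is why the layer can serve all experts without a Python loop.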

class RouterOutput(NamedTuple):
    """Router output for Expert Choice routing.

    routed_input (torch.Tensor): tokens grouped together by expert indices with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    token_indices (torch.Tensor): token indices for routed_input with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    """

    routed_input: torch.Tensor
    token_indices: torch.Tensor

class ExpertChoiceTopKRouter(nn.Module):
    """This class implements expert choice routing. Each expert selects its top-k tokens based on
    the router scores. Refer to https://arxiv.org/abs/2202.09368 for more details.

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each MoE layer.
        capacity_factor (float): Capacity factor determines how many tokens each expert can choose:
            expert capacity = (number of tokens * capacity factor) / number of experts.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
    """

    def __init__(
        self,
        *,
        gate: nn.Module,
        dim: int,
        num_experts: int,
        capacity_factor: float,
        use_sigmoid: bool = True,
    ):
        super().__init__()
        self.gate = gate
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid

    def forward(self, x: torch.Tensor) -> RouterOutput:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            routed_input (torch.Tensor): input tokens grouped together by expert indices with shape
                ``(num_experts*tokens_per_expert, dim)``.
            token_indices (torch.Tensor): token indices for routed_input with shape
                ``(num_experts*tokens_per_expert, dim)``.
        """
        # scores shape (num_experts, bs*slen)
        scores = self.gate(x).transpose(0, 1)
        # By default, we compute sigmoid and softmax in float32 to avoid loss explosion.
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
        tokens_per_expert += -tokens_per_expert % 8
        # Take the smaller of tokens_per_expert and the number of tokens
        tokens_per_expert = min(tokens_per_expert, x.shape[0])

        # top_scores shape (num_experts, tokens_per_expert)
        top_scores, selected_token_indices = torch.topk(
            scores, k=tokens_per_expert, dim=1
        )

        # token_indices shape (num_experts*tokens_per_expert, dim)
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, self.dim)
        # routed_input shape (num_experts*tokens_per_expert, dim)
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return RouterOutput(
            routed_input,
            token_indices,
        )

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
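
To illustrate the capacity computation in forward, here is a worked example where the padding to a multiple of 8 actually kicks in (the token counts are illustrative):

num_tokens = 1000        # bs * slen (illustrative)
capacity_factor = 1.0
num_experts = 6

tokens_per_expert = int(num_tokens * capacity_factor / num_experts)  # 166
tokens_per_expert += -tokens_per_expert % 8                          # pad up to 168
tokens_per_expert = min(tokens_per_expert, num_tokens)               # 168
print(tokens_per_expert)  # 168

Padding to a multiple of 8 keeps the per-expert token blocks aligned for efficient GPU matmuls; the min() guards against the capacity exceeding the actual number of tokens.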

class MoE(nn.Module):
    """This class implements the Mixture of Experts (MoE) layer. Mixture of Experts
    typically consists of a set of expert networks, together with a router that directs
    input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.

    Args:
        experts (nn.Module): experts module.
        router (nn.Module): router module.
        shared_expert (Optional[nn.Module]): shared expert module. Default is None.
    """

    def __init__(
        self,
        *,
        experts: nn.Module,
        router: nn.Module,
        shared_expert: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.experts = experts
        self.router = router
        self.shared_expert = shared_expert

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
        """
        bz, slen, dim = x.shape
        # routed_input shape (num_experts*tokens_per_expert, dim) for EC
        routed_input, token_indices = self.router(x.reshape(bz * slen, dim))

        # routed_input shape (num_experts, tokens_per_expert, dim)
        routed_input = routed_input.reshape(self.router.num_experts, -1, dim)
        # routed_output shape (num_experts, tokens_per_expert, dim)
        routed_output = self.experts(routed_input)
        # routed_output shape (num_experts*tokens_per_expert, dim)
        routed_output = routed_output.reshape(-1, dim)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
                bz * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bz * slen, dim))

        # add experts output
        # doing this in place might be faster
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bz, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
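
Finally, a small end-to-end smoke test of the new layer, wiring the three classes together the same way TransformerBlock does above. This is a sketch with illustrative sizes, assuming the repo at this commit is on the path so the import below resolves:

import torch
from torch import nn
from torchtitan.models.llama.moe_layers import (
    ExpertChoiceTopKRouter,
    GroupedExperts,
    MoE,
)

dim, hidden_dim, num_experts = 64, 128, 4
moe = MoE(
    experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
    router=ExpertChoiceTopKRouter(
        gate=nn.Linear(dim, num_experts, bias=False),
        dim=dim,
        num_experts=num_experts,
        capacity_factor=1.0,
    ),
    shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
)
moe.init_weights(init_std=0.02)  # fills the torch.empty parameter tensors

x = torch.randn(2, 16, dim)  # (bz, slen, dim); 32 tokens total
out = moe(x)                 # router keeps int(32 * 1.0 / 4) = 8 tokens per expert
print(out.shape)             # torch.Size([2, 16, 64])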
