# [RFC] MOE design in Torchtune (#1902)

## Background
This RFC proposes adding MOE support to torchtune. We want the design to be general so that components can be easily swapped when implementing different MOE models. An MOE layer directly replaces the dense FFN layer in the transformer decoder layer and has two main components: a router and experts.

## Expert
An expert is essentially an FFN layer, similar to the original dense FFN layer in the transformer decoder layer. There are two kinds of experts: routed experts and shared experts. Each routed expert specializes in learning certain patterns/aspects, and only a subset of the routed experts is activated for any given token. Shared experts, on the other hand, are always activated and aim to capture and consolidate common knowledge across varying contexts.
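
To make the sparsity concrete, here is a small, purely illustrative calculation (hypothetical sizes, not tied to any specific model) comparing total vs. activated expert parameters in a single MOE layer with SwiGLU experts:

```python
# Hypothetical MOE layer configuration (illustrative only).
hidden_dim = 4096        # model dimension
ffn_dim = 14336          # per-expert FFN dimension
num_experts = 8          # routed experts
experts_per_token = 2    # top-k routed experts activated per token
num_shared_experts = 1   # always-on shared experts

# SwiGLU expert: gate_proj + up_proj + down_proj
params_per_expert = 3 * hidden_dim * ffn_dim

total_expert_params = (num_experts + num_shared_experts) * params_per_expert
active_expert_params = (experts_per_token + num_shared_experts) * params_per_expert

print(f"total expert params per layer:  {total_expert_params / 1e9:.2f}B")
print(f"active expert params per token: {active_expert_params / 1e9:.2f}B")
```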

**Here's the proposed Experts design in torchtune:**
```python
import torch
import torch.nn.functional as F
from torch import nn


class Experts(nn.Module):
    def __init__(self, dim_in, dim_out, num_experts=1, swiglu=True, nonlinearity=None):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = nonlinearity

    def forward(self, x, num_local_tokens_per_expert=None):
        '''
        inputs:
            x: input tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
            num_local_tokens_per_expert: number of tokens for each expert, only used for TC forward
        outputs:
            out: output tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
        '''
        # TC forward
        if num_local_tokens_per_expert is not None:
            # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance
            # x shape [bs*slen*experts_per_token, hidden_dim]
            # x_expert_splits: num_experts chunks of shape [tokens_per_expert(varying), hidden_dim]
            # (torch.split expects Python ints for the split sizes)
            x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert.tolist(), dim=0)
            out_expert_splits = []
            for expert_index, x_expert_split in enumerate(x_expert_splits):
                gate_proj = self.gate_proj[expert_index]
                down_proj = self.down_proj[expert_index]
                up_proj = None
                if self.up_proj is not None:
                    up_proj = self.up_proj[expert_index]

                h = self.act_fn(torch.matmul(x_expert_split, gate_proj))
                if up_proj is not None:
                    h = h * torch.matmul(x_expert_split, up_proj)
                # [tokens_per_expert, hidden_dim]
                h = torch.matmul(h, down_proj)

                out_expert_splits.append(h)
            # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim]
            out = torch.cat(out_expert_splits, dim=0)
        # EC forward
        else:
            # x shape [num_experts, tokens_per_expert, hidden_dim]
            x = x.view(self.num_experts, -1, self.dim_in)
            h = self.act_fn(torch.bmm(x, self.gate_proj))
            if self.up_proj is not None:
                h = h * torch.bmm(x, self.up_proj)
            out = torch.bmm(h, self.down_proj).view(-1, self.dim_in)
        return out


# Expert builder for routed experts
def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity)


# Single expert / shared expert
def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity)
```

> **Reviewer:** What are the implications of defining your experts all as a single parameter? Noob thoughts here; I think it makes sense that it would be faster for training since all experts are updated each step, but at inference time wouldn't you want separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts.
>
> **Reply:** (Not an expert, but) I think the reason for combining all the experts is to use `torch.bmm()` and grouped GEMM instead of `torch.matmul()` to optimize performance.
>
> **Reply:** Good question. For inference we always do TC, even if training uses EC, and in the TC forward we are essentially looping over the experts one at a time.
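
To make the two forward paths and their tensor shapes concrete, here is a minimal usage sketch of the `Experts` module above (hypothetical sizes; the parameters are created with `torch.empty`, so we initialize them first):

```python
import torch

# Hypothetical sizes for illustration.
hidden_dim, model_dim, num_experts = 64, 128, 4

experts = moe_experts(hidden_dim=hidden_dim, model_dim=model_dim, num_experts=num_experts)
for p in experts.parameters():
    torch.nn.init.normal_(p, std=0.02)

# EC-style forward: a fixed number of tokens per expert, no split sizes needed.
tokens_per_expert = 8
x_ec = torch.randn(num_experts * tokens_per_expert, hidden_dim)
out_ec = experts(x_ec)
print(out_ec.shape)  # torch.Size([32, 64])

# TC-style forward: a (possibly uneven) number of tokens per expert.
num_local_tokens_per_expert = torch.tensor([10, 6, 8, 8])
x_tc = torch.randn(int(num_local_tokens_per_expert.sum()), hidden_dim)
out_tc = experts(x_tc, num_local_tokens_per_expert=num_local_tokens_per_expert)
print(out_tc.shape)  # torch.Size([32, 64])
```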

## Router
The router is a gating network that computes router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing.

Mixtral uses *token choice* topK routing, which means each token selects its topK experts. The router is implemented as a learnable gate function whose outputs go through softmax and topK. The router then defines how tokens select experts based on the router scores.

**Here's the proposed Token Choice Routing design in torchtune:**
```python
class TokenChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, experts_per_token):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: input tokens
                shape [bs*slen, hidden_dim]
        outputs:
            routed_input: tokens gathered by selected experts
                shape [bs*slen*experts_per_token, hidden_dim]
            token_indices: token indices sorted by selected expert indices
            num_local_tokens_per_expert: number of tokens assigned to each expert
                shape [num_experts,]
        '''
        # scores shape [bs*slen, num_experts]
        scores = self.gate(x)
        if use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # TODO: implement load balancing auxiliary loss for token choice routing
        # https://github.com/NVIDIA/Megatron-LM/blob/f1f039224584f0bc6ba89c21ef4f491d7136e3ce/megatron/core/transformer/moe/router.py#L162

        # router scores/indices shape [bs*slen, experts_per_token]
        top_scores, selected_experts_indices = torch.topk(scores, k=self.experts_per_token, dim=1)
        top_scores = (top_scores / top_scores.sum(dim=-1, keepdim=True)).to(x.dtype)

        # shape [num_experts,]: how many tokens are routed to each expert (histc operates on float values)
        num_local_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1).float(), bins=self.num_experts, min=0, max=self.num_experts
        ).long()
        # shape [bs*slen*experts_per_token,]
        token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True)
        # top_scores shape [bs*slen*experts_per_token,]
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]

        # map the flattened (token, expert-slot) positions back to token ids,
        # then expand to shape [bs*slen*experts_per_token, hidden_dim] for gather
        token_indices = (token_indices_experts_sorted // self.experts_per_token).reshape(-1, 1).expand(-1, x.shape[-1])
        # routed_input shape [bs*slen*experts_per_token, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)

        return routed_input, token_indices, num_local_tokens_per_expert
```
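
As a quick sanity check of the shapes, here is a hypothetical forward pass through the router above (small illustrative sizes):

```python
import torch

bs, slen, hidden_dim = 2, 8, 64
num_experts, experts_per_token = 4, 2

router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
x = torch.randn(bs * slen, hidden_dim)

routed_input, token_indices, num_local_tokens_per_expert = router(x)
print(routed_input.shape)           # torch.Size([32, 64]): bs*slen*experts_per_token rows
print(token_indices.shape)          # torch.Size([32, 64]): one token id per routed row, expanded to hidden_dim
print(num_local_tokens_per_expert)  # e.g. tensor([9, 7, 8, 8]); sums to bs*slen*experts_per_token
```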

However, token choice routing has several pitfalls according to the expert choice [paper](https://arxiv.org/pdf/2002.05202):
1. Poor load balance. Experts can become under- or over-specialized, and load imbalance hurts step latency / inference time (a common mitigation is a load-balancing auxiliary loss; see the sketch after this list).
2. Under-specialization of experts. Ideally the gating network learns token-to-expert affinity such that similar or relevant tokens are routed to the same expert. However, a sub-optimal strategy can produce redundant experts and/or experts that are not sufficiently specialized.
3. Same compute for each token. Token choice allocates a fixed number of experts to each token regardless of the importance of different tokens. Ideally an MOE model should flexibly allocate compute based on the complexity of the input.
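
The TODO in the token choice router above points at Megatron-LM's load-balancing loss. As a rough, non-authoritative sketch of the idea (Switch-Transformer-style; the helper name is made up and this is not part of the proposed design), the auxiliary loss pushes the fraction of tokens routed to each expert toward the mean router probability, so both distributions are encouraged to be uniform:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(scores: torch.Tensor, selected_experts_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Illustrative Switch-Transformer-style auxiliary loss.

    scores: softmaxed router scores, shape [num_tokens, num_experts]
    selected_experts_indices: top-k expert ids per token, shape [num_tokens, experts_per_token]
    """
    # f_i: fraction of routed (token, slot) assignments that went to expert i
    one_hot = F.one_hot(selected_experts_indices.view(-1), num_classes=num_experts).float()
    tokens_per_expert_frac = one_hot.mean(dim=0)
    # P_i: mean router probability assigned to expert i
    mean_probs = scores.mean(dim=0)
    # Minimized when both distributions are uniform (1 / num_experts each)
    return num_experts * torch.sum(tokens_per_expert_frac * mean_probs)
```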

Compared to **token choice**, **expert choice** topK routing lets each expert select its top-k tokens. The `ExpertChoiceTopKRouter` class routes input tokens to different experts based on the router scores.

**Here's the proposed Expert Choice Routing design in torchtune:**
```python
class ExpertChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, tokens_per_expert):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.tokens_per_expert = tokens_per_expert

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: shape [bs*slen, hidden_dim]
        outputs:
            routed_input: selected tokens
                shape [num_experts*tokens_per_expert, hidden_dim]
            token_indices: selected token indices
            num_local_tokens_per_expert: None
        '''
        # scores shape [num_experts, bs*slen]
        scores = self.gate(x).transpose(0, 1)
        if use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        # router scores/indices shape [num_experts, tokens_per_expert]
        top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1)

        # gather the selected tokens, which are then fed to the experts forward
        # token_indices shape [num_experts*tokens_per_expert, hidden_dim]
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, x.shape[-1])
        # routed_input shape [num_experts*tokens_per_expert, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return routed_input, token_indices, None
```

> **Reviewer:** Discussed offline a bit, but we should make sure there isn't any information leakage happening here for decoder models. Specifically, the `torch.topk` line seems a bit fishy to me: it looks like each expert is looking across all tokens and choosing its top `tokens_per_expert` based on the entire sequence, right?
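
And a hypothetical shape check for the expert choice router (illustrative sizes only):

```python
import torch

bs, slen, hidden_dim = 2, 8, 64
num_experts, tokens_per_expert = 4, 6

router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
x = torch.randn(bs * slen, hidden_dim)

routed_input, token_indices, _ = router(x)
print(routed_input.shape)   # torch.Size([24, 64]): num_experts * tokens_per_expert rows
print(token_indices.shape)  # torch.Size([24, 64])
# Note: with expert choice, some tokens may be selected by several experts and others by none.
```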

## MoE Layer
An MOE layer consists of experts and a router.

**Here's the proposed MoeLayer design in torchtune:**
```python
class MoeLayer(nn.Module):
    def __init__(
        self,
        hidden_dim,
        model_dim,
        num_experts,
        experts_per_token,
        tokens_per_expert,
        router="token_choice",
        use_shared_expert=True,
    ):
        super().__init__()
        self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
        self.use_shared_expert = use_shared_expert
        self.shared_expert = moe_expert(hidden_dim, model_dim) if use_shared_expert else None
        if router == "token_choice":
            self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
        elif router == "expert_choice":
            self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
        else:
            raise NotImplementedError("This router is not supported yet!")

    def forward(self, x, inference=False):
        routed_input, token_indices, num_local_tokens_per_expert = self.router(x)
        # routed output shape [num_experts*tokens_per_expert, hidden_dim] for EC,
        # [bs*slen*experts_per_token, hidden_dim] for TC
        routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert)

        # shared expert
        if self.use_shared_expert:
            out = self.shared_expert(x)
        else:
            out = torch.zeros_like(x)

        # add the routed experts' outputs back to their token positions
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        return out
```

> **Reviewer:** This is a minor point, but I wouldn't configure the router type using strings. Just pass an `nn.Module`, as that's more generic (then anyone with their own custom router matching the same signature can use it out of the box).
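
A minimal end-to-end sketch of the layer (hypothetical sizes; assumes the modules above, and initializes the expert weights since they are created with `torch.empty`):

```python
import torch

bs, slen, hidden_dim, model_dim = 2, 8, 64, 128
num_experts, experts_per_token, tokens_per_expert = 4, 2, 6

layer = MoeLayer(
    hidden_dim=hidden_dim,
    model_dim=model_dim,
    num_experts=num_experts,
    experts_per_token=experts_per_token,
    tokens_per_expert=tokens_per_expert,
    router="token_choice",
)
for p in layer.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(bs * slen, hidden_dim)
out = layer(x)
print(out.shape)  # torch.Size([16, 64]): same shape as the input, like a dense FFN
```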

## Model builder
Besides the components above (experts, routers, and MOE layers), we need a model builder that pulls all the pieces together to form the transformer decoder layer and then the transformer decoder:

**Here's the proposed MOE model builder design in torchtune:**
```python
def moe(...) -> TransformerDecoder:
    # Build the decoder associated with the MOE model. This includes:
    # - Token embeddings
    # - num_layers number of TransformerDecoderLayer blocks
    # - RMS Norm layer applied to the output of the transformer
    # - Final projection into the token space
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    self_attn = MultiHeadAttention()
    moe_layer = MoeLayer(router="token_choice")  # or MoeLayer(router="expert_choice")

    layer = TransformerSelfAttentionLayer(
        attn=self_attn,
        mlp=moe_layer,
        sa_norm=RMSNorm(dim=embed_dim),
        mlp_norm=RMSNorm(dim=embed_dim),
    )
    output_proj = nn.Linear(embed_dim, vocab_size)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(dim=embed_dim),
        output=output_proj,
    )
```
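
For completeness, here is a purely hypothetical named model builder layered on top of `moe(...)`; the keyword names and values below are illustrative assumptions, since the RFC leaves the builder signature elided:

```python
# Hypothetical named model builder, mirroring torchtune's usual builder pattern.
# All keyword names and hyperparameter values are illustrative assumptions, not part of the RFC.
def moe_8x_small() -> TransformerDecoder:
    return moe(
        vocab_size=32_000,
        embed_dim=4096,
        hidden_dim=14_336,       # per-expert FFN dimension
        num_layers=32,
        num_heads=32,
        head_dim=128,
        max_seq_len=32_768,
        num_experts=8,
        experts_per_token=2,
        router="token_choice",
    )
```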

**File changes for new modules/functions**
```
torchtune/
    modules/
        moe/
            moe_layers.py
                TokenChoiceTopKRouter()
                ExpertChoiceTopKRouter()
                MoeLayer()
            experts.py
                Experts()
    models/
        moe/
            _component_builders.py
                moe()
                moe_expert()
                moe_experts()
```

> **Reviewer (on the `Experts` naming):** I think this should be something like `MoELinear`.
>
> **Reply:** +1 here.
>
> **Reply:** Curious, what's the reason? I was thinking that having `Experts`, routers, and `MoeLayer` is clearer in terms of what each class is doing. Otherwise there's no `Experts` class, and `MoELinear` is essentially `Experts`? I don't have any strong preference, but I just want to know why you think `MoELinear` is better.
>
> **Reply:** It matches our own convention of defining `FeedForward` (singular) rather than a group of things like `Experts`.
>
> **Reply:** But here we are defining a group of experts/feedforwards. Wouldn't `MoELinear` be confusing? It's more like `GroupedMoELinear`?