|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch import nn


class GroupedExperts(nn.Module):
    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
    is a variant of the Gated Linear Unit network. See more details in https://arxiv.org/pdf/2002.05202.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Hidden dimension of each expert. The final output dimension
            is ``dim_in``, since ``down_proj`` projects back to the input dimension.
        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
        swiglu (bool): Whether to use the gated linear unit (SwiGLU). Default is True.
        activation (Callable): Activation function to use. Only used when ``swiglu``
            is False. Default is F.silu.
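
    Example:
        A minimal usage sketch for Expert Choice (EC) routing; the sizes below are
        chosen purely for illustration::

            experts = GroupedExperts(dim_in=16, dim_out=32, num_experts=4)
            experts.init_weights(init_std=0.02)
            # one group of tokens per expert: (num_experts, tokens_per_expert, dim_in)
            x = torch.randn(4, 8, 16)
            out = experts(x)  # (4, 8, 16): down_proj maps back to dim_in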
    """

    def __init__(
        self,
        *,
        dim_in: int,
        dim_out: int,
        num_experts: int = 1,
        swiglu: bool = True,
        activation: Callable = F.silu,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            # SwiGLU gates with SiLU by definition, so ``activation`` is ignored here
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = activation

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape
                ``(num_experts, tokens_per_expert, dim_in)`` for Expert Choice (EC).

        Returns:
            torch.Tensor: Output tensor with shape
                ``(num_experts, tokens_per_expert, dim_in)`` for Expert Choice (EC).
        """
        # Expert Choice (EC) forward
        # x shape (num_experts, tokens_per_expert, dim_in)
        h = self.act_fn(torch.bmm(x, self.gate_proj))
        if self.up_proj is not None:
            h = h * torch.bmm(x, self.up_proj)
        # out shape (num_experts, tokens_per_expert, dim_in)
        out = torch.bmm(h, self.down_proj)
        return out

    def init_weights(self, init_std: float):
        # The input projection uses a fixed std of 0.02; the remaining
        # projections use the provided init_std
        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
        if self.up_proj is not None:
            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)


class RouterOutput(NamedTuple):
    """Router output for Expert Choice routing.

    routed_input (torch.Tensor): tokens grouped together by expert indices with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    token_indices (torch.Tensor): token indices for routed_input with shape
        ``(num_experts*tokens_per_expert, dim)`` for Expert Choice.
    """

    routed_input: torch.Tensor
    token_indices: torch.Tensor


class ExpertChoiceTopKRouter(nn.Module):
    """This class implements Expert Choice routing. Each expert selects its top-K
    tokens based on the router scores. See https://arxiv.org/abs/2202.09368 for more details.

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        capacity_factor (float): Capacity factor determines how many tokens each expert can choose.
            expert capacity = (number of tokens * capacity factor) / number of experts.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
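
    Example:
        A minimal sketch; the sizes below are illustrative only::

            router = ExpertChoiceTopKRouter(
                gate=nn.Linear(16, 4, bias=False),
                dim=16,
                num_experts=4,
                capacity_factor=2.0,
            )
            # 64 tokens of dim 16; expert capacity = 64 * 2.0 / 4 = 32 tokens per
            # expert (already a multiple of 8, so no rounding up is needed)
            x = torch.randn(64, 16)
            routed_input, token_indices = router(x)
            # routed_input: (4 * 32, 16), token_indices: (4 * 32, 16)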
    """

    def __init__(
        self,
        *,
        gate: nn.Module,
        dim: int,
        num_experts: int,
        capacity_factor: float,
        use_sigmoid: bool = True,
    ):
        super().__init__()
        self.gate = gate
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid

    def forward(self, x: torch.Tensor) -> RouterOutput:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            routed_input (torch.Tensor): input tokens grouped together by expert indices with shape
                ``(num_experts*tokens_per_expert, dim)``.
            token_indices (torch.Tensor): token indices for routed_input with shape
                ``(num_experts*tokens_per_expert, dim)``.
        """
        # scores shape (num_experts, bs*slen)
        scores = self.gate(x).transpose(0, 1)
        # By default, sigmoid and softmax are performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        # expert capacity = (number of tokens * capacity factor) / number of experts
        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
        # round up to the nearest multiple of 8 for hardware-friendly shapes,
        # e.g. 100 -> 104
        tokens_per_expert += -tokens_per_expert % 8
        # take the smaller of tokens_per_expert and the number of tokens
        tokens_per_expert = min(tokens_per_expert, x.shape[0])

        # top_scores shape (num_experts, tokens_per_expert)
        top_scores, selected_token_indices = torch.topk(
            scores, k=tokens_per_expert, dim=1
        )

        # token_indices shape (num_experts*tokens_per_expert, dim)
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, self.dim)
        # routed_input shape (num_experts*tokens_per_expert, dim)
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return RouterOutput(
            routed_input,
            token_indices,
        )

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


class MoE(nn.Module):
    """This class implements the Mixture of Experts (MoE) layer. A MoE layer typically
    consists of a set of expert networks together with a router, which directs input tokens
    to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.

    Args:
        experts (nn.Module): experts module.
        router (nn.Module): router module.
        shared_expert (Optional[nn.Module]): shared expert module. Default is None.
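
    Example:
        A minimal sketch wiring the pieces together; the modules and sizes below
        are illustrative only::

            dim, hidden_dim, num_experts = 16, 32, 4
            moe = MoE(
                experts=GroupedExperts(
                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
                ),
                router=ExpertChoiceTopKRouter(
                    gate=nn.Linear(dim, num_experts, bias=False),
                    dim=dim,
                    num_experts=num_experts,
                    capacity_factor=2.0,
                ),
            )
            moe.init_weights(init_std=0.02)
            x = torch.randn(2, 32, dim)  # (bz, slen, dim)
            out = moe(x)  # (2, 32, dim)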
    """

    def __init__(
        self,
        *,
        experts: nn.Module,
        router: nn.Module,
        shared_expert: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.experts = experts
        self.router = router
        self.shared_expert = shared_expert

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
        """
        bz, slen, dim = x.shape
        # routed_input shape (num_experts*tokens_per_expert, dim) for EC
        routed_input, token_indices = self.router(x.reshape(bz * slen, dim))

        # routed_input shape (num_experts, tokens_per_expert, dim)
        routed_input = routed_input.reshape(self.router.num_experts, -1, dim)
        # routed_output shape (num_experts, tokens_per_expert, dim)
        routed_output = self.experts(routed_input)
        # routed_output shape (num_experts*tokens_per_expert, dim)
        routed_output = routed_output.reshape(-1, dim)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
                bz * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bz * slen, dim))

        # add the routed experts' output back at the original token positions
        # NOTE: doing this in place (scatter_add_) might be faster
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bz, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)