-
Notifications
You must be signed in to change notification settings - Fork 123
Expand file tree
/
Copy pathfused_moe_gmm.py
More file actions
429 lines (382 loc) · 13.7 KB
/
fused_moe_gmm.py
File metadata and controls
429 lines (382 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Literal
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from tpu_inference.kernels.megablox.gmm import gmm
from tpu_inference.kernels.megablox.gmm_v2 import (gmm_v2,
is_supported_by_gmm_v2)
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.utils import get_mesh_shape_product
def apply_scoring_fn(scoring_fn: str, x: jax.Array) -> jax.Array:
    """Apply the named router scoring function to ``x``.

    Args:
        scoring_fn: ``"softmax"`` (taken over the last axis) or
            ``"sigmoid"`` (element-wise).
        x: scores to transform.

    Returns:
        The transformed scores.

    Raises:
        NotImplementedError: if ``scoring_fn`` is any other value.
    """
    if scoring_fn == "softmax":
        return jax.nn.softmax(x, axis=-1)
    if scoring_fn == "sigmoid":
        return jax.nn.sigmoid(x)
    raise NotImplementedError(
        f"FusedMoE does not support {scoring_fn} scoring function")
def apply_act_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
    """Combine the two halves of the first MoE projection with a gated activation.

    Args:
        activation: ``"silu"`` or ``"swigluoai"``.
        x1: the half the non-linearity is applied to.
        x2: the half that multiplies the activated ``x1``.

    Returns:
        ``silu(x1) * x2`` for ``"silu"``, or ``_swigluoai(x1, x2)`` for
        ``"swigluoai"``.

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation == "silu":
        return jax.nn.silu(x1) * x2
    if activation == "swigluoai":
        return _swigluoai(x1, x2)
    raise NotImplementedError(
        f"FusedMoE does not support {activation} activation function")
def _swigluoai(x1: jax.Array,
               x2: jax.Array,
               alpha: float = 1.702,
               limit: float = 7.0) -> jax.Array:
    """Clamped SwiGLU variant used for the ``"swigluoai"`` activation.

    Computes ``x1c * sigmoid(alpha * x1c) * (x2c + 1)`` where ``x1c`` is
    ``x1`` clamped from above at ``limit`` and ``x2c`` is ``x2`` clamped to
    ``[-limit, limit]``.

    Args:
        x1: input the sigmoid gate is applied to.
        x2: input that scales the gated value (after adding 1).
        alpha: multiplier applied to ``x1`` inside the sigmoid.
        limit: clamp bound for both inputs.

    Returns:
        The gated activation, broadcast over ``x1`` and ``x2``.
    """
    # Use the ``min``/``max`` keywords: ``a_min``/``a_max`` are deprecated
    # spellings in ``jax.numpy.clip``.
    x1 = jnp.clip(x1, max=limit)
    x2 = jnp.clip(x2, min=-limit, max=limit)
    gated_activation = x1 * jax.nn.sigmoid(alpha * x1)
    return gated_activation * (x2 + 1)
def gmm_wrapper(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset,
                last_gmm):
    """Run a grouped matmul through ``gmm_v2`` when supported, else ``gmm``.

    Args:
        lhs: left-hand operand, forwarded to the kernel.
        rhs: right-hand operand, forwarded to the kernel.
        rhs_scale: optional scale for ``rhs``; also the value
            ``is_supported_by_gmm_v2`` inspects to choose the kernel.
        rhs_bias: optional bias, forwarded to the kernel.
        group_sizes: per-group sizes, forwarded to the kernel.
        group_offset: array whose first element is passed as the kernel's
            ``group_offset``.
        last_gmm: forwarded as ``zero_initialize`` on the ``gmm_v2`` path;
            the ``gmm`` path does not use it.

    Returns:
        The output of whichever kernel was selected.
    """
    use_v2 = is_supported_by_gmm_v2(rhs_scale)

    # Arguments both kernels take under the same keyword names.
    shared_kwargs = dict(
        lhs=lhs,
        rhs=rhs,
        rhs_scale=rhs_scale,
        rhs_bias=rhs_bias,
        group_sizes=group_sizes,
        group_offset=group_offset[0],
    )

    if use_v2:
        # For the last gmm, unvisited rows must be zeroed out: leaving them
        # uninitialized would cause numeric errors during the final reduce.
        return gmm_v2(**shared_kwargs, zero_initialize=last_gmm)

    return gmm(
        **shared_kwargs,
        preferred_element_type=lhs.dtype,
        tiling=None,
    )
def moe_gmm_local(
    x: jax.Array,
    w1: jax.Array,
    w1_scale: jax.Array | None,
    w1_bias: jax.Array | None,
    w2: jax.Array,
    w2_scale: jax.Array | None,
    w2_bias: jax.Array | None,
    group_sizes: jax.Array,
    group_offset: jax.Array,
    topk_argsort_revert_indices: jax.Array,
    topk_weights: jax.Array,
    *,
    activation: str,
    topk: int,
    parallelism: Literal["tp", "ep"],
) -> jax.Array:
    """Per-shard MoE computation, usable in tensor- or expert-parallel mode.

    Runs the first grouped matmul, applies the gated activation, runs the
    second grouped matmul, un-sorts the rows back to token order, combines
    each token's ``topk`` expert outputs using ``topk_weights`` and finally
    sums the result across the parallel axis with ``psum``.

    Args:
        x: expert-sorted token rows fed to the first grouped matmul.
        w1, w1_scale, w1_bias: weight / optional scale / optional bias of the
            first grouped matmul (fused gate and up projection).
        w2, w2_scale, w2_bias: weight / optional scale / optional bias of the
            second grouped matmul.
        group_sizes: group sizes forwarded to both grouped matmuls.
        group_offset: array whose first element is the group offset used by
            both grouped matmuls.
        topk_argsort_revert_indices: indices that map the sorted rows back to
            their original (token, choice) order.
        topk_weights: per-token weights of the selected experts.
        activation: activation name understood by ``apply_act_fn``.
        topk: number of experts selected per token.
        parallelism: ``"tp"`` or ``"ep"``; selects the bias handling and the
            axis reduced over.

    Returns:
        Per-token hidden states summed over the parallel axis.
    """
    assert parallelism in ("tp", "ep")
    is_tp = parallelism == "tp"

    # The first GMM computes the gate and up projections together; the
    # result is then split in two along the last axis and the activation is
    # applied to the gate half.
    gate_up = gmm_wrapper(x, w1, w1_scale, w1_bias, group_sizes, group_offset,
                          False)
    gate_half, up_half = jnp.split(gate_up, 2, axis=-1)
    activated = apply_act_fn(activation, gate_half, up_half)

    # In TP mode w2_bias is not sharded, so it must be applied exactly once
    # rather than on every shard: zero it everywhere except shard 0. In EP
    # mode this is unnecessary because the bias is sharded on the leading
    # expert axis.
    if is_tp and w2_bias is not None:
        shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
        w2_bias = jnp.where(shard_id == 0, w2_bias, 0)

    expert_out = gmm_wrapper(activated, w2, w2_scale, w2_bias, group_sizes,
                             group_offset, True)

    # Local reduction first: restore token order, then take the weighted sum
    # over the topk experts handled by this rank.
    out_width = expert_out.shape[-1]
    per_choice = expert_out[topk_argsort_revert_indices].reshape(
        (-1, topk, out_width))
    weighted = per_choice * jnp.expand_dims(topk_weights, axis=-1)
    token_hidden = weighted.sum(axis=-2)

    # Then the global reduction across all ranks.
    if is_tp:
        reduction_axis = ShardingAxisName.MLP_TENSOR
    else:
        reduction_axis = ShardingAxisName.EXPERT
    return jax.lax.psum(token_hidden, axis_name=reduction_axis)
def tensor_parallel_gmm(
    x: jax.Array,
    w1: jax.Array,
    w1_scale: jax.Array | None,
    w1_bias: jax.Array | None,
    w2: jax.Array,
    w2_scale: jax.Array | None,
    w2_bias: jax.Array | None,
    group_sizes: jax.Array,
    topk_argsort_revert_indices: jax.Array,
    topk_weights: jax.Array,
    *,
    activation: str,
    topk: int,
    mesh: Mesh,
) -> jax.Array:
    """Run ``moe_gmm_local`` under ``shard_map`` in tensor-parallel mode.

    ``w1`` is partitioned on its last axis and ``w2`` on its middle axis
    along ``ShardingAxisName.MLP_TENSOR``; token-side inputs and the output
    are partitioned along ``ShardingAxisName.MLP_DATA``. The group offset is
    always ``[0]``.

    Args:
        x: expert-sorted token rows.
        w1, w1_scale, w1_bias: first projection weight / optional scale /
            optional bias.
        w2, w2_scale, w2_bias: second projection weight / optional scale /
            optional bias.
        group_sizes: group sizes forwarded to the local computation.
        topk_argsort_revert_indices: indices restoring the original row order.
        topk_weights: per-token weights of the selected experts.
        activation: activation name forwarded to ``moe_gmm_local``.
        topk: number of experts selected per token.
        mesh: device mesh passed to ``shard_map``.

    Returns:
        The result of ``moe_gmm_local``, partitioned along
        ``ShardingAxisName.MLP_DATA``.
    """
    tp_axis = ShardingAxisName.MLP_TENSOR
    token_spec = P(ShardingAxisName.MLP_DATA)
    group_offset = jnp.array([0])

    # Optional operands get a ``None`` spec when they are absent.
    w1_scale_spec = None
    if w1_scale is not None:
        w1_scale_spec = P(None, None, None, tp_axis)

    w1_bias_spec = None
    if w1_bias is not None:
        w1_bias_spec = P(None, None, tp_axis)

    # w2_scale is partitioned only when its axis 1 holds more than one block.
    # NOTE(review): a present w2_scale with exactly one block also ends up
    # with a ``None`` spec here - confirm that is what shard_map expects.
    w2_scale_spec = None
    if w2_scale is not None and w2_scale.shape[1] != 1:
        w2_scale_spec = P(None, tp_axis, None, None)

    # w2_bias is replicated; moe_gmm_local keeps it on shard 0 only.
    w2_bias_spec = None
    if w2_bias is not None:
        w2_bias_spec = P(None, None, None)

    local_fn = functools.partial(
        moe_gmm_local,
        activation=activation,
        topk=topk,
        parallelism="tp",
    )
    in_specs = (
        token_spec,  # x
        P(None, None, tp_axis),  # w1
        w1_scale_spec,
        w1_bias_spec,
        P(None, tp_axis, None),  # w2
        w2_scale_spec,
        w2_bias_spec,
        token_spec,  # group_sizes
        token_spec,  # group_offset
        token_spec,  # topk_argsort_revert_indices
        token_spec,  # topk_weights
    )
    sharded_fn = jax.shard_map(
        local_fn,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=token_spec,
        check_vma=False,
    )
    return sharded_fn(
        x,
        w1,
        w1_scale,
        w1_bias,
        w2,
        w2_scale,
        w2_bias,
        group_sizes,
        group_offset,
        topk_argsort_revert_indices,
        topk_weights,
    )
def expert_parallel_gmm(
    x: jax.Array,
    w1: jax.Array,
    w1_scale: jax.Array | None,
    w1_bias: jax.Array | None,
    w2: jax.Array,
    w2_scale: jax.Array | None,
    w2_bias: jax.Array | None,
    group_sizes: jax.Array,
    topk_argsort_revert_indices: jax.Array,
    topk_weights: jax.Array,
    *,
    activation: str,
    topk: int,
    mesh: Mesh,
) -> jax.Array:
    """Run ``moe_gmm_local`` under ``shard_map`` in expert-parallel mode.

    Weights, scales and biases are partitioned on their leading axis along
    ``ShardingAxisName.EXPERT``; token-side inputs and the output are
    partitioned along ``ShardingAxisName.MLP_DATA``. Each shard receives the
    index of its first expert as its group offset.

    Args:
        x: expert-sorted token rows.
        w1, w1_scale, w1_bias: first projection weight / optional scale /
            optional bias; ``w1.shape[0]`` is taken as the expert count.
        w2, w2_scale, w2_bias: second projection weight / optional scale /
            optional bias.
        group_sizes: group sizes forwarded to the local computation.
        topk_argsort_revert_indices: indices restoring the original row order.
        topk_weights: per-token weights of the selected experts.
        activation: activation name forwarded to ``moe_gmm_local``.
        topk: number of experts selected per token.
        mesh: device mesh passed to ``shard_map``.

    Returns:
        The result of ``moe_gmm_local``, partitioned along
        ``ShardingAxisName.MLP_DATA``.
    """
    expert_spec = P(ShardingAxisName.EXPERT)
    token_spec = P(ShardingAxisName.MLP_DATA)

    # One offset per expert shard: 0, E/ep, 2*E/ep, ...
    ep_size = get_mesh_shape_product(mesh, ShardingAxisName.EXPERT)
    num_experts = w1.shape[0]
    experts_per_shard = num_experts // ep_size
    group_offset = jnp.arange(0, num_experts, experts_per_shard)

    def _expert_spec_if_present(operand):
        # Absent optional operands get a ``None`` spec.
        return None if operand is None else expert_spec

    local_fn = functools.partial(
        moe_gmm_local,
        activation=activation,
        topk=topk,
        parallelism="ep",
    )
    in_specs = (
        token_spec,  # x
        expert_spec,  # w1
        _expert_spec_if_present(w1_scale),
        _expert_spec_if_present(w1_bias),
        expert_spec,  # w2
        _expert_spec_if_present(w2_scale),
        _expert_spec_if_present(w2_bias),
        token_spec,  # group_sizes
        expert_spec,  # group_offset
        token_spec,  # topk_argsort_revert_indices
        token_spec,  # topk_weights
    )
    sharded_fn = jax.shard_map(
        local_fn,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=token_spec,
        check_vma=False,
    )
    return sharded_fn(
        x,
        w1,
        w1_scale,
        w1_bias,
        w2,
        w2_scale,
        w2_bias,
        group_sizes,
        group_offset,
        topk_argsort_revert_indices,
        topk_weights,
    )
@jax.jit(static_argnames=(
    "topk",
    "renormalize",
    "mesh",
    "use_ep",
    "activation",
    "scoring_fn",
))
def fused_moe_func(
    hidden_states: jax.Array,
    w1: jax.Array,
    w2: jax.Array,
    w1_scale: jax.Array | None,
    w2_scale: jax.Array | None,
    w1_bias: jax.Array | None,
    w2_bias: jax.Array | None,
    gating_output: jax.Array,
    topk: int,
    renormalize: bool,
    mesh: Mesh,
    use_ep: bool,
    activation: str,
    scoring_fn: str,
    topk_weights: jax.Array | None = None,
    topk_indices: jax.Array | None = None,
) -> jax.Array:
    """Route every token to its top-k experts and combine the expert outputs.

    Args:
        hidden_states: [num_tokens, hidden_size]
        w1: first moe weights [num_experts, hidden_size, intermediate_size * 2]
        w2: second moe weights [num_experts, intermediate_size, hidden_size]
        w1_scale: w1 scale [num_experts, num_blocks, 1, intermediate_size * 2]
        w2_scale: w2 scale [num_experts, num_blocks, 1, hidden_size]
        w1_bias: optional bias of w1 [num_experts, 1, intermediate_size * 2]
        w2_bias: optional bias of w2 [num_experts, 1, hidden_size]
        gating_output: routing information of tokens [num_tokens, num_experts].
            Only read when ``topk_weights`` or ``topk_indices`` is missing.
        topk: number of experts to choose per token.
        renormalize: divide the selected top-k weights by their per-token sum.
            Only applied when the routing is computed from ``gating_output``.
        mesh: mesh to perform moe.
        use_ep: use expert parallelism (otherwise tensor parallelism).
        activation: activation function to perform on the output of w1.
        scoring_fn: scoring function to apply on gating_output.
        topk_weights: pre-calculated expert weights.
        topk_indices: pre-calculated expert indices.

    Returns:
        Output of moe operation [num_tokens, hidden_size]
    """
    num_tokens, hidden_size = hidden_states.shape
    global_num_experts, padded_hidden_size, _ = w1.shape
    dtype = hidden_states.dtype

    assert (num_tokens * topk) % 16 == 0, (
        "The kernel requires num_tokens * topk to be a multiple of "
        f"16 but got {num_tokens}*{topk}={num_tokens*topk}")

    def _shard_rows_on_data_axis(arr):
        # Constrain axis 0 to the MLP data axis and leave axis 1 unsharded.
        return jax.lax.with_sharding_constraint(
            arr, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))

    if topk_weights is None or topk_indices is None:
        # No complete pre-computed routing: derive it from gating_output.
        assert gating_output is not None
        assert gating_output.shape == (num_tokens, global_num_experts)
        scores = apply_scoring_fn(scoring_fn, gating_output)
        # All-gather topk weights for attention dp
        scores = _shard_rows_on_data_axis(scores)
        topk_weights, topk_indices = jax.lax.top_k(scores, k=topk)
        if renormalize:
            weight_sum = topk_weights.sum(axis=-1, keepdims=True)
            topk_weights = topk_weights / weight_sum
    else:
        topk_weights = _shard_rows_on_data_axis(topk_weights)
        topk_indices = _shard_rows_on_data_axis(topk_indices)

    topk_weights = topk_weights.astype(dtype)

    def _group_tokens_by_expert(local_hidden, local_topk_indices):
        # Sort the (token, choice) pairs of this shard by expert id so each
        # expert's rows are contiguous.
        local_num_tokens = local_hidden.shape[0]
        flat_expert_ids = local_topk_indices.flatten()
        sort_order = jnp.argsort(flat_expert_ids)
        owning_token = jnp.arange(local_num_tokens,
                                  dtype=jnp.int32).repeat(topk)
        sorted_rows = local_hidden[owning_token[sort_order]]
        # Below one_hot is equivalent to jnp.bincount(flat_expert_ids,
        # length=global_num_experts) but is more performant.
        expert_counts = jax.nn.one_hot(flat_expert_ids,
                                       global_num_experts,
                                       dtype=jnp.int32).sum(axis=0)
        # Inverse permutation, used later to restore the original order.
        inverse_order = jnp.argsort(sort_order)
        return sorted_rows, expert_counts, inverse_order

    rows_spec = P(ShardingAxisName.MLP_DATA, None)
    vector_spec = P(ShardingAxisName.MLP_DATA)
    x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
        _group_tokens_by_expert,
        mesh=mesh,
        in_specs=(rows_spec, rows_spec),
        out_specs=(rows_spec, vector_spec, vector_spec),
    )(hidden_states, topk_indices)

    # Pad the feature axis up to the input width of w1.
    feature_padding = padded_hidden_size - hidden_size
    x = jnp.pad(x, ((0, 0), (0, feature_padding)))

    # Both parallel entry points share one signature.
    parallel_gmm = expert_parallel_gmm if use_ep else tensor_parallel_gmm
    x = parallel_gmm(
        x,
        w1,
        w1_scale,
        w1_bias,
        w2,
        w2_scale,
        w2_bias,
        group_sizes,
        topk_argsort_revert_indices,
        topk_weights,
        activation=activation,
        topk=topk,
        mesh=mesh,
    )

    # Trim to the original token count and hidden size.
    return x[:num_tokens, :hidden_size]