@@ -91,6 +91,7 @@ def moe_gmm_local(
9191 parallelism : Literal ["tp" , "ep" ],
9292 sc_kernel_threshold : int ,
9393 sc_kernel_col_chunk_size : int ,
94+ sc_psum_num_chunks : int ,
9495) -> jax .Array :
9596 """Main MoE logic on a local shard can run in TP or EP mode.
9697
@@ -128,6 +129,9 @@ def moe_gmm_local(
128129 group_offset + local_group_size ,
129130 )[topk_argsort_revert_indices ]
130131
132+ reduction_axis = (ShardingAxisName .MLP_TENSOR
133+ if parallelism == "tp" else ShardingAxisName .EXPERT )
134+
131135 if gather_reduce_sc .is_supported_by_sc_gather_reduce (
132136 gmm1_res .shape [0 ], sc_kernel_threshold ):
133137 gmm2_res = gmm_wrapper (gmm1_res ,
@@ -145,13 +149,71 @@ def moe_gmm_local(
145149 inds = topk_argsort_revert_indices
146150 topk_weights = topk_weights .flatten ().reshape (- 1 , 128 )
147151
148- token_hidden = gather_reduce_sc .sc_gather_reduce (
152+ chunk_size = gmm2_res .shape [0 ] // sc_psum_num_chunks
153+ inds_reshaped = inds .reshape (sc_psum_num_chunks , chunk_size )
154+ topk_weights_reshaped = topk_weights .reshape (sc_psum_num_chunks ,
155+ chunk_size // 128 , 128 )
156+
157+ # Pre-allocate output buffer to save memory and avoid list accumulation
158+ # The shape is (inds.shape[0] // 8, hidden_size)
159+ token_hidden = jnp .zeros ((inds .shape [0 ] // 8 , gmm2_res .shape [- 1 ]),
160+ dtype = jnp .bfloat16 )
161+
162+ # Prologue: Execute the first kernel chunk
163+ chunk_out_prev = gather_reduce_sc .sc_gather_reduce (
149164 op = gmm2_res ,
150- idx = inds ,
165+ idx = inds_reshaped [ 0 ] ,
151166 reduce_group_size = topk ,
152- topk_weights = topk_weights ,
167+ topk_weights = topk_weights_reshaped [ 0 ] ,
153168 col_chunk_size = sc_kernel_col_chunk_size ,
154169 )
170+
171+ chunk_out_reduced = None
172+
173+ for i in range (1 , sc_psum_num_chunks ):
174+ weights_chunk = topk_weights_reshaped [i ]
175+
176+ # Optimization barrier to ensure SC_i and TC_{i-1} start in parallel
177+ if i == 1 :
178+ idx_chunk_barriered , chunk_out_prev_barriered = jax .lax .optimization_barrier (
179+ (inds_reshaped [i ], chunk_out_prev ))
180+ else :
181+ idx_chunk_barriered , chunk_out_prev_barriered , _ = jax .lax .optimization_barrier (
182+ (inds_reshaped [i ], chunk_out_prev , chunk_out_reduced ))
183+
184+ # Start SC kernel using the barriered index
185+ chunk_out = gather_reduce_sc .sc_gather_reduce (
186+ op = gmm2_res ,
187+ idx = idx_chunk_barriered ,
188+ reduce_group_size = topk ,
189+ topk_weights = weights_chunk ,
190+ col_chunk_size = sc_kernel_col_chunk_size ,
191+ )
192+
193+ # psum on the previous chunk output
194+ chunk_out_reduced = jax .lax .psum (chunk_out_prev_barriered ,
195+ axis_name = reduction_axis )
196+
197+ # In-place update of the pre-allocated buffer
198+ token_hidden = jax .lax .dynamic_update_slice (
199+ token_hidden , chunk_out_reduced ,
200+ ((i - 1 ) * (chunk_size // 8 ), 0 ))
201+
202+ chunk_out_prev = chunk_out
203+
204+ # Epilogue: Perform psum on the last kernel output
205+ if sc_psum_num_chunks > 1 :
206+ chunk_out_prev_barriered , _ = jax .lax .optimization_barrier (
207+ (chunk_out_prev , chunk_out_reduced ))
208+ else :
209+ chunk_out_prev_barriered = jax .lax .optimization_barrier (
210+ (chunk_out_prev , ))[0 ]
211+
212+ chunk_out_reduced_final = jax .lax .psum (chunk_out_prev_barriered ,
213+ axis_name = reduction_axis )
214+ token_hidden = jax .lax .dynamic_update_slice (
215+ token_hidden , chunk_out_reduced_final ,
216+ ((sc_psum_num_chunks - 1 ) * (chunk_size // 8 ), 0 ))
155217 else :
156218 gmm2_res = gmm_wrapper (gmm1_res ,
157219 w2 ,
@@ -173,10 +235,11 @@ def moe_gmm_local(
173235
174236 token_hidden = token_topk_hidden .sum (axis = - 2 )
175237
176- reduction_axis = (ShardingAxisName .MLP_TENSOR
177- if parallelism == "tp" else ShardingAxisName .EXPERT )
178- # Then global reduction on all ranks for all tokens and all experts
179- return jax .lax .psum (token_hidden , axis_name = reduction_axis ).astype (x .dtype )
238+ # Then global reduction on all ranks for all tokens and all experts
239+ token_hidden = jax .lax .psum (token_hidden ,
240+ axis_name = reduction_axis ).astype (x .dtype )
241+
242+ return token_hidden
180243
181244
182245def tensor_parallel_gmm (
@@ -196,6 +259,7 @@ def tensor_parallel_gmm(
196259 mesh : Mesh ,
197260 sc_kernel_threshold : int ,
198261 sc_kernel_col_chunk_size : int ,
262+ sc_psum_num_chunks : int ,
199263) -> jax .Array :
200264 data_p_spec = P (ShardingAxisName .MLP_DATA )
201265 group_offset = jnp .array ([0 ])
@@ -221,6 +285,7 @@ def tensor_parallel_gmm(
221285 parallelism = "tp" ,
222286 sc_kernel_threshold = sc_kernel_threshold ,
223287 sc_kernel_col_chunk_size = sc_kernel_col_chunk_size ,
288+ sc_psum_num_chunks = sc_psum_num_chunks ,
224289 ),
225290 mesh = mesh ,
226291 in_specs = (
@@ -270,6 +335,7 @@ def expert_parallel_gmm(
270335 mesh : Mesh ,
271336 sc_kernel_threshold : int ,
272337 sc_kernel_col_chunk_size : int ,
338+ sc_psum_num_chunks : int ,
273339) -> jax .Array :
274340 ep_size = get_mesh_shape_product (mesh , ShardingAxisName .EXPERT )
275341 ep_p_spec = P (ShardingAxisName .EXPERT )
@@ -291,6 +357,7 @@ def expert_parallel_gmm(
291357 parallelism = "ep" ,
292358 sc_kernel_threshold = sc_kernel_threshold ,
293359 sc_kernel_col_chunk_size = sc_kernel_col_chunk_size ,
360+ sc_psum_num_chunks = sc_psum_num_chunks ,
294361 ),
295362 mesh = mesh ,
296363 in_specs = (
@@ -332,6 +399,7 @@ def expert_parallel_gmm(
332399 "scoring_fn" ,
333400 "sc_kernel_threshold" ,
334401 "sc_kernel_col_chunk_size" ,
402+ "sc_psum_num_chunks" ,
335403))
336404def fused_moe_func (
337405 hidden_states : jax .Array ,
@@ -350,6 +418,7 @@ def fused_moe_func(
350418 scoring_fn : str ,
351419 sc_kernel_threshold : int ,
352420 sc_kernel_col_chunk_size : int ,
421+ sc_psum_num_chunks : int ,
353422) -> jax .Array :
354423 """Route tokens in hidden_states into each experts based on routing.
355424
@@ -441,6 +510,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
441510 mesh = mesh ,
442511 sc_kernel_threshold = sc_kernel_threshold ,
443512 sc_kernel_col_chunk_size = sc_kernel_col_chunk_size ,
513+ sc_psum_num_chunks = sc_psum_num_chunks ,
444514 )
445515 else :
446516 x = tensor_parallel_gmm (
@@ -459,6 +529,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
459529 mesh = mesh ,
460530 sc_kernel_threshold = sc_kernel_threshold ,
461531 sc_kernel_col_chunk_size = sc_kernel_col_chunk_size ,
532+ sc_psum_num_chunks = sc_psum_num_chunks ,
462533 )
463534
464535 return x [:num_tokens , :hidden_size ]
0 commit comments