Skip to content

Commit 576bf2f

Browse files
committed
overlap psum with sc kernel
1 parent efb489e commit 576bf2f

File tree

4 files changed

+85
-9
lines changed

4 files changed

+85
-9
lines changed

tpu_inference/env_override.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# This prevents errors when trying to create CUDA streams on TPU hardware
88
# The issue was introduced by vllm-project/vllm#26440
99
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
10-
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_use_tc_device_shape_on_sc=true"
10+
os.environ[
11+
"LIBTPU_INIT_ARGS"] = "--xla_tpu_use_tc_device_shape_on_sc=true --xla_tpu_scheduler_percent_shared_memory_limit=1000"
1112

1213
# Monkeypatch vLLM to avoid ImportError: cannot import name 'SamplingParams' from 'vllm'
1314
# in vllm/v1/... submodules due to circular imports or lazy loading failures.
@@ -22,4 +23,4 @@
2223
from vllm.sampling_params import RequestOutputKind
2324
vllm.RequestOutputKind = RequestOutputKind
2425
except ImportError:
25-
pass
26+
pass

tpu_inference/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
USE_BATCHED_RPA_KERNEL: bool = False
3838
SC_KERNEL_THRESHOLD: int = 8192
3939
SC_KERNEL_COL_CHUNK_SIZE: int = 3072
40+
SC_PSUM_NUM_CHUNKS: int = 4
4041

4142

4243
def env_with_choices(
@@ -210,6 +211,8 @@ def _get_bool_env() -> bool:
210211
lambda: int(os.getenv("SC_KERNEL_THRESHOLD") or "8192"),
211212
"SC_KERNEL_COL_CHUNK_SIZE":
212213
lambda: int(os.getenv("SC_KERNEL_COL_CHUNK_SIZE") or "3072"),
214+
"SC_PSUM_NUM_CHUNKS":
215+
lambda: int(os.getenv("SC_PSUM_NUM_CHUNKS") or "4"),
213216
}
214217

215218

tpu_inference/layers/common/fused_moe_gmm.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def moe_gmm_local(
9191
parallelism: Literal["tp", "ep"],
9292
sc_kernel_threshold: int,
9393
sc_kernel_col_chunk_size: int,
94+
sc_psum_num_chunks: int,
9495
) -> jax.Array:
9596
"""Main MoE logic on a local shard can run in TP or EP mode.
9697
@@ -128,6 +129,9 @@ def moe_gmm_local(
128129
group_offset + local_group_size,
129130
)[topk_argsort_revert_indices]
130131

132+
reduction_axis = (ShardingAxisName.MLP_TENSOR
133+
if parallelism == "tp" else ShardingAxisName.EXPERT)
134+
131135
if gather_reduce_sc.is_supported_by_sc_gather_reduce(
132136
gmm1_res.shape[0], sc_kernel_threshold):
133137
gmm2_res = gmm_wrapper(gmm1_res,
@@ -145,13 +149,71 @@ def moe_gmm_local(
145149
inds = topk_argsort_revert_indices
146150
topk_weights = topk_weights.flatten().reshape(-1, 128)
147151

148-
token_hidden = gather_reduce_sc.sc_gather_reduce(
152+
chunk_size = gmm2_res.shape[0] // sc_psum_num_chunks
153+
inds_reshaped = inds.reshape(sc_psum_num_chunks, chunk_size)
154+
topk_weights_reshaped = topk_weights.reshape(sc_psum_num_chunks,
155+
chunk_size // 128, 128)
156+
157+
# Pre-allocate output buffer to save memory and avoid list accumulation
158+
# The shape is (inds.shape[0] // 8, hidden_size)
159+
token_hidden = jnp.zeros((inds.shape[0] // 8, gmm2_res.shape[-1]),
160+
dtype=jnp.bfloat16)
161+
162+
# Prologue: Execute the first kernel chunk
163+
chunk_out_prev = gather_reduce_sc.sc_gather_reduce(
149164
op=gmm2_res,
150-
idx=inds,
165+
idx=inds_reshaped[0],
151166
reduce_group_size=topk,
152-
topk_weights=topk_weights,
167+
topk_weights=topk_weights_reshaped[0],
153168
col_chunk_size=sc_kernel_col_chunk_size,
154169
)
170+
171+
chunk_out_reduced = None
172+
173+
for i in range(1, sc_psum_num_chunks):
174+
weights_chunk = topk_weights_reshaped[i]
175+
176+
# Optimization barrier to ensure SC_i and TC_{i-1} start in parallel
177+
if i == 1:
178+
idx_chunk_barriered, chunk_out_prev_barriered = jax.lax.optimization_barrier(
179+
(inds_reshaped[i], chunk_out_prev))
180+
else:
181+
idx_chunk_barriered, chunk_out_prev_barriered, _ = jax.lax.optimization_barrier(
182+
(inds_reshaped[i], chunk_out_prev, chunk_out_reduced))
183+
184+
# Start SC kernel using the barriered index
185+
chunk_out = gather_reduce_sc.sc_gather_reduce(
186+
op=gmm2_res,
187+
idx=idx_chunk_barriered,
188+
reduce_group_size=topk,
189+
topk_weights=weights_chunk,
190+
col_chunk_size=sc_kernel_col_chunk_size,
191+
)
192+
193+
# psum on the previous chunk output
194+
chunk_out_reduced = jax.lax.psum(chunk_out_prev_barriered,
195+
axis_name=reduction_axis)
196+
197+
# In-place update of the pre-allocated buffer
198+
token_hidden = jax.lax.dynamic_update_slice(
199+
token_hidden, chunk_out_reduced,
200+
((i - 1) * (chunk_size // 8), 0))
201+
202+
chunk_out_prev = chunk_out
203+
204+
# Epilogue: Perform psum on the last kernel output
205+
if sc_psum_num_chunks > 1:
206+
chunk_out_prev_barriered, _ = jax.lax.optimization_barrier(
207+
(chunk_out_prev, chunk_out_reduced))
208+
else:
209+
chunk_out_prev_barriered = jax.lax.optimization_barrier(
210+
(chunk_out_prev, ))[0]
211+
212+
chunk_out_reduced_final = jax.lax.psum(chunk_out_prev_barriered,
213+
axis_name=reduction_axis)
214+
token_hidden = jax.lax.dynamic_update_slice(
215+
token_hidden, chunk_out_reduced_final,
216+
((sc_psum_num_chunks - 1) * (chunk_size // 8), 0))
155217
else:
156218
gmm2_res = gmm_wrapper(gmm1_res,
157219
w2,
@@ -173,10 +235,11 @@ def moe_gmm_local(
173235

174236
token_hidden = token_topk_hidden.sum(axis=-2)
175237

176-
reduction_axis = (ShardingAxisName.MLP_TENSOR
177-
if parallelism == "tp" else ShardingAxisName.EXPERT)
178-
# Then global reduction on all ranks for all tokens and all experts
179-
return jax.lax.psum(token_hidden, axis_name=reduction_axis).astype(x.dtype)
238+
# Then global reduction on all ranks for all tokens and all experts
239+
token_hidden = jax.lax.psum(token_hidden,
240+
axis_name=reduction_axis).astype(x.dtype)
241+
242+
return token_hidden
180243

181244

182245
def tensor_parallel_gmm(
@@ -196,6 +259,7 @@ def tensor_parallel_gmm(
196259
mesh: Mesh,
197260
sc_kernel_threshold: int,
198261
sc_kernel_col_chunk_size: int,
262+
sc_psum_num_chunks: int,
199263
) -> jax.Array:
200264
data_p_spec = P(ShardingAxisName.MLP_DATA)
201265
group_offset = jnp.array([0])
@@ -221,6 +285,7 @@ def tensor_parallel_gmm(
221285
parallelism="tp",
222286
sc_kernel_threshold=sc_kernel_threshold,
223287
sc_kernel_col_chunk_size=sc_kernel_col_chunk_size,
288+
sc_psum_num_chunks=sc_psum_num_chunks,
224289
),
225290
mesh=mesh,
226291
in_specs=(
@@ -270,6 +335,7 @@ def expert_parallel_gmm(
270335
mesh: Mesh,
271336
sc_kernel_threshold: int,
272337
sc_kernel_col_chunk_size: int,
338+
sc_psum_num_chunks: int,
273339
) -> jax.Array:
274340
ep_size = get_mesh_shape_product(mesh, ShardingAxisName.EXPERT)
275341
ep_p_spec = P(ShardingAxisName.EXPERT)
@@ -291,6 +357,7 @@ def expert_parallel_gmm(
291357
parallelism="ep",
292358
sc_kernel_threshold=sc_kernel_threshold,
293359
sc_kernel_col_chunk_size=sc_kernel_col_chunk_size,
360+
sc_psum_num_chunks=sc_psum_num_chunks,
294361
),
295362
mesh=mesh,
296363
in_specs=(
@@ -332,6 +399,7 @@ def expert_parallel_gmm(
332399
"scoring_fn",
333400
"sc_kernel_threshold",
334401
"sc_kernel_col_chunk_size",
402+
"sc_psum_num_chunks",
335403
))
336404
def fused_moe_func(
337405
hidden_states: jax.Array,
@@ -350,6 +418,7 @@ def fused_moe_func(
350418
scoring_fn: str,
351419
sc_kernel_threshold: int,
352420
sc_kernel_col_chunk_size: int,
421+
sc_psum_num_chunks: int,
353422
) -> jax.Array:
354423
"""Route tokens in hidden_states into each experts based on routing.
355424
@@ -441,6 +510,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
441510
mesh=mesh,
442511
sc_kernel_threshold=sc_kernel_threshold,
443512
sc_kernel_col_chunk_size=sc_kernel_col_chunk_size,
513+
sc_psum_num_chunks=sc_psum_num_chunks,
444514
)
445515
else:
446516
x = tensor_parallel_gmm(
@@ -459,6 +529,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
459529
mesh=mesh,
460530
sc_kernel_threshold=sc_kernel_threshold,
461531
sc_kernel_col_chunk_size=sc_kernel_col_chunk_size,
532+
sc_psum_num_chunks=sc_psum_num_chunks,
462533
)
463534

464535
return x[:num_tokens, :hidden_size]

tpu_inference/layers/common/moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def moe_apply(
139139
scoring_fn=layer.scoring_func,
140140
sc_kernel_threshold=envs.SC_KERNEL_THRESHOLD,
141141
sc_kernel_col_chunk_size=envs.SC_KERNEL_COL_CHUNK_SIZE,
142+
sc_psum_num_chunks=envs.SC_PSUM_NUM_CHUNKS,
142143
)
143144
case MoEBackend.DENSE_MAT:
144145
# NOTE: circular import avoidance

0 commit comments

Comments
 (0)