
Commit a99eff3

Authored by rebel-ykchoi, rebel-jangys, rebel-yhboo, and rebel-myeongbo
fix: change MoE combine (#438)
Co-authored-by: Jangys <jangys@rebellions.ai>
Co-authored-by: yhboo <yhboo@rebellions.ai>
Co-authored-by: Myeongbo Shim <myeongbo.shim@rebellions.ai>
Co-authored-by: Jang Yeongsang <122958878+rebel-jangys@users.noreply.github.com>
1 parent 853cf33 commit a99eff3

3 files changed, 31 additions & 11 deletions


vllm_rbln/model_executor/layers/fused_moe/layer.py

Lines changed: 20 additions & 8 deletions
@@ -431,15 +431,27 @@ def fused_moe_forward_rbln(

     if self.dp_size > 1:
         # output all_reduce == dp all_reduce + tp all_reduce
-        all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-        hidden_shape_dp = (-1, 1, org_hidden_shape[-1])
-        final_hidden_states = all_hidden_states.reshape(hidden_shape_dp)
+        if envs.VLLM_RBLN_MOE_REDUCE_SCATTER:
+            hidden_shape_dp = (-1, 1, org_hidden_shape[-1])
+            all_hidden_states = final_hidden_states.reshape(hidden_shape_dp)
+            assert all_hidden_states.shape[0] % self.dp_size == 0

-        max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
-        num_tokens = org_hidden_shape[:-1].numel()  # noqa: F841
-        start = self.dp_rank * max_pad
-        end = start + num_tokens
-        final_hidden_states = final_hidden_states[start:end]
+            hidden_states = get_dp_group().reduce_scatter(all_hidden_states, dim=0)
+            max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
+            assert hidden_states.shape[0] == max_pad
+
+            num_tokens = org_hidden_shape[:-1].numel()  # noqa: F841
+            final_hidden_states = hidden_states[:num_tokens]
+        else:
+            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
+            hidden_shape_dp = (-1, 1, org_hidden_shape[-1])
+            final_hidden_states = all_hidden_states.reshape(hidden_shape_dp)
+
+            max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
+            num_tokens = org_hidden_shape[:-1].numel()  # noqa: F841
+            start = self.dp_rank * max_pad
+            end = start + num_tokens
+            final_hidden_states = final_hidden_states[start:end]

     final_hidden_states = final_hidden_states.reshape(org_hidden_shape)
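
For reference, a minimal sketch of the reduce_scatter combine idea using plain torch.distributed. The function name, the [dp_size * max_pad, 1, hidden] shape, and the direct dist call are illustrative assumptions; the actual path above goes through get_dp_group().reduce_scatter.

# Illustrative sketch only (not the vllm_rbln implementation): a reduce_scatter
# over the DP group replaces all_reduce followed by slicing out this rank's
# chunk. Assumes torch.distributed is initialized and every rank holds a
# padded tensor of shape [dp_size * max_pad, 1, hidden].
import torch
import torch.distributed as dist

def combine_with_reduce_scatter(all_hidden_states: torch.Tensor,
                                max_pad: int,
                                num_tokens: int,
                                group=None) -> torch.Tensor:
    # Each rank receives the summed values for its own max_pad-long slice,
    # instead of materializing the full all-reduced tensor on every rank.
    out = torch.empty((max_pad, *all_hidden_states.shape[1:]),
                      dtype=all_hidden_states.dtype,
                      device=all_hidden_states.device)
    dist.reduce_scatter_tensor(out, all_hidden_states, group=group)
    # Drop the padding rows that were added to equalize per-rank token counts.
    return out[:num_tokens]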

vllm_rbln/model_executor/layers/quantization/mxfp4.py

Lines changed: 3 additions & 3 deletions
@@ -410,11 +410,10 @@ def apply(
             expert_map_list = layer.expert_map.tolist()
             expert_map_const = torch.tensor(expert_map_list, dtype=torch.int32)

-            use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
             tokens_mask = None
+            use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
             if use_moe_tokens_mask:
-                tokens_mask = get_tokens_mask(num_tokens, 0.0, float("-inf"))
-                router_logits = router_logits + tokens_mask
+                tokens_mask = get_tokens_mask(num_tokens)

             final_hidden_states = torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
                 hidden_states,
@@ -433,6 +432,7 @@ def apply(
                 layer.top_k,
                 layer.renormalize,
                 expert_map_const,
+                tokens_mask,
             )
         else:
             raise NotImplementedError(layer.activation)
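
The behavioral change here: the tokens mask is no longer folded into router_logits as a 0 / -inf bias; it is built once (now from num_tokens alone) and passed to the fused custom op as an extra argument. A rough illustration of the two styles follows; the mask shape and semantics are assumptions, and the real get_tokens_mask in vllm_rbln may differ.

# Rough illustration only; not the vllm_rbln get_tokens_mask implementation.
import torch

num_tokens, padded_len, num_experts = 3, 8, 4
router_logits = torch.randn(padded_len, num_experts)

# Old style: additive mask folded into the router logits (0.0 for valid
# tokens, -inf for padding) so padded tokens never win top-k routing.
additive_mask = torch.full((padded_len, 1), float("-inf"))
additive_mask[:num_tokens] = 0.0
masked_logits = router_logits + additive_mask

# New style: a standalone validity mask handed to the fused MoE op,
# leaving router_logits untouched.
tokens_mask = torch.zeros(padded_len)
tokens_mask[:num_tokens] = 1.0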

vllm_rbln/rbln_envs.py

Lines changed: 8 additions & 0 deletions
@@ -46,6 +46,7 @@
 VLLM_RBLN_DECODE_BATCH_BUCKET_MANUAL_BUCKETS: list[int] = []
 VLLM_RBLN_USE_CUSTOM_KERNEL: bool = False
 VLLM_RBLN_AUTO_PORT: bool = True
+VLLM_RBLN_MOE_REDUCE_SCATTER: bool = False


 def get_dp_impl() -> str:
@@ -254,6 +255,13 @@ def get_decode_batch_bucket_manual_buckets() -> list[int]:
             os.environ.get("RBLN_USE_CUSTOM_KERNEL", "False").lower() in ("true", "1")
         )
     ),
+    # Use reduce_scatter instead of all_reduce in MoE combine phase
+    "VLLM_RBLN_MOE_REDUCE_SCATTER": (
+        lambda: (
+            os.environ.get("VLLM_RBLN_MOE_REDUCE_SCATTER", "False").lower()
+            in ("true", "1")
+        )
+    ),
     "VLLM_RBLN_PROFILER": (
         lambda: os.environ.get("RBLN_PROFILER", "False").lower() in ("true", "1")
     ),
