Skip to content

Commit a60592f

Browse files
fix moe data parallel for v1 engine (#252)
* fix dp for v1 - remove DP padding support in v1 worker - add validation for DP implementation constraints in v1 worker - apply token mask to custom MOE kernel router logits - update default environment variables: - VLLM_RBLN_DP_IMPL: "dummy_prefill" -> "padded_decode" - VLLM_RBLN_USE_MOE_TOKENS_MASK: False -> True - fix DP metadata handling in forward context - add is_prefills field to RBLNFlashAttentionMetadata * fix test_rbln_envs.py - VLLM_RBLN_DP_IMPL should be padded_decode by default * fix DPMetadata for tokens mask - remove is_prefills field and related logic from DP metadata - fix get_tokens_mask() for non-DP case --------- Co-authored-by: rebel-jonghewk <142865404+rebel-jonghewk@users.noreply.github.com>
1 parent cae35d7 commit a60592f

6 files changed

Lines changed: 39 additions & 22 deletions

File tree

tests/torch_compile/common/test_rbln_envs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def test_rbln_envs():
4949
), f"Expected VLLM_RBLN_DISABLE_MM to be False, \
5050
got {rbln_envs.VLLM_RBLN_DISABLE_MM}"
5151

52-
assert (rbln_envs.VLLM_RBLN_DP_IMPL == "dummy_prefill"
53-
), f"Expected VLLM_RBLN_DP_IMPL to be dummy_prefill, \
52+
assert (rbln_envs.VLLM_RBLN_DP_IMPL == "padded_decode"
53+
), f"Expected VLLM_RBLN_DP_IMPL to be padded_decode, \
5454
got {rbln_envs.VLLM_RBLN_DP_IMPL}"
5555

5656
assert (not rbln_envs.VLLM_RBLN_ENFORCE_MODEL_FP32

vllm_rbln/forward_context.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,6 @@ def make(
5454
# for v0 attention backends
5555
batchsize = attn_metadata.num_prefill_tokens + \
5656
attn_metadata.num_decode_tokens
57-
58-
disable_dp = dp_size == 1
59-
use_dummy_prefill = envs.VLLM_RBLN_DP_IMPL == "dummy_prefill"
60-
if (disable_dp or use_dummy_prefill) and \
61-
attn_metadata.num_decode_tokens > 0:
62-
max_pad = scheduler_config.max_num_seqs
6357
else:
6458
# for v1 attention backends or no attn_metadata
6559
batchsize = num_tokens

vllm_rbln/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,17 @@ def unquantized_fused_moe_method_rbln(
238238
return final_hidden_states.reshape(orig_shape)
239239

240240

241-
def _get_tokens_mask():
242-
num_tokens = \
241+
def get_tokens_mask(num_tokens: int, left=1.0, right=float('-inf')):
242+
num_tokens_across_dp = \
243243
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
244-
num_tokens = num_tokens.unsqueeze(1)
245-
max_pad = get_forward_context().dp_metadata.max_pads_across_dp
244+
num_tokens_across_dp = num_tokens_across_dp.unsqueeze(1)
245+
if num_tokens_across_dp.size(0) == 1:
246+
max_pad = num_tokens
247+
else:
248+
max_pad = get_forward_context().dp_metadata.max_pads_across_dp
246249
pos = torch.arange(max_pad, dtype=torch.int32).unsqueeze(0) # [1, max_pad]
247-
tokens_mask = torch.where(pos < num_tokens, 1.0, 0.0) # [dp_size, max_pad]
250+
tokens_mask = torch.where(pos < num_tokens_across_dp, left,
251+
right) # [dp_size, max_pad]
248252
tokens_mask = tokens_mask.reshape(-1, 1) #[dp_size * max_pad, 1]
249253
return tokens_mask
250254

@@ -268,7 +272,7 @@ def get_masked_routing_weights(router_logits, top_k, renormalize, expert_map):
268272

269273
use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
270274
if use_moe_tokens_mask:
271-
tokens_mask = _get_tokens_mask()
275+
tokens_mask = get_tokens_mask(router_logits.shape[0], 1.0, 0.0)
272276
selected_weights = selected_weights * tokens_mask
273277

274278
n_expert = router_logits.shape[1]
@@ -393,6 +397,11 @@ def unquantized_fused_optimize_moe_method_custom(
393397
expert_map_list = expert_map.tolist()
394398
expert_map_const = torch.tensor(expert_map_list, dtype=torch.int32)
395399

400+
use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
401+
if use_moe_tokens_mask:
402+
tokens_mask = get_tokens_mask(num_tokens)
403+
router_logits = router_logits * tokens_mask
404+
396405
# optimum-rbln/src/optimum/rbln/transformers/models/qwen3_moe/
397406
# qwen3_moe_architecture.py
398407
final_hidden_states = torch.ops.rbln_custom_ops.custom_moe_glu(

vllm_rbln/rbln_envs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
VLLM_RBLN_USE_VLLM_MODEL: bool = False
2727
VLLM_RBLN_FLASH_CAUSAL_ATTN: bool = True
2828
VLLM_RBLN_DISABLE_MM: bool = False
29-
VLLM_RBLN_DP_IMPL: str = "dummy_prefill"
30-
VLLM_RBLN_USE_MOE_TOKENS_MASK: bool = False
29+
VLLM_RBLN_DP_IMPL: str = "padded_decode"
30+
VLLM_RBLN_USE_MOE_TOKENS_MASK: bool = True
3131
VLLM_RBLN_ENFORCE_MODEL_FP32: bool = False
3232
VLLM_RBLN_MOE_CUSTOM_KERNEL: bool = True
3333
VLLM_RBLN_MOE_USE_OPT_KERNEL: bool = False
@@ -41,8 +41,9 @@
4141
def get_dp_impl():
4242
dp_impl = os.environ.get("VLLM_RBLN_DP_IMPL")
4343
if dp_impl is None:
44-
return "dummy_prefill"
45-
# default is dummy_prefill
44+
return "padded_decode"
45+
# default is padded_decode
46+
# dummy_prefill will be deprecated in the future
4647
choices = set(["padded_decode", "dummy_prefill"])
4748
current_impl = dp_impl.lower()
4849
if current_impl not in choices:
@@ -90,8 +91,9 @@ def get_dp_impl():
9091
"VLLM_RBLN_DP_IMPL":
9192
get_dp_impl,
9293
# If true, it uses the tokens mask applied to moe expert kernel
93-
"VLLM_RBLN_USE_MOE_TOKENS_MASK": (lambda: os.environ.get(
94-
"VLLM_RBLN_USE_MOE_TOKENS_MASK", "False").lower() in ("true", "1")),
94+
"VLLM_RBLN_USE_MOE_TOKENS_MASK":
95+
(lambda: os.environ.get("VLLM_RBLN_USE_MOE_TOKENS_MASK", "True").lower() in
96+
("true", "1")),
9597
# enforce model data type into fp32 not model_config.dtype
9698
"VLLM_RBLN_ENFORCE_MODEL_FP32":
9799
(lambda: os.environ.get("VLLM_RBLN_ENFORCE_MODEL_FP32", "False").lower() in

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,8 +1255,8 @@ def _preprocess(
12551255
num_input_tokens = num_scheduled_tokens
12561256

12571257
# Padding for DP
1258-
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
1259-
num_input_tokens += num_pad
1258+
# NOTE(RBLN): RBLN does not support DP padding
1259+
num_tokens_across_dp = None
12601260

12611261
# _prepare_inputs may reorder the batch, so we must gather multi
12621262
# modal outputs after that to ensure the correct order

vllm_rbln/v1/worker/rbln_worker.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,18 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
241241
self.model_runner.initialize_kv_cache(kv_cache_config)
242242

243243
def compile_or_warm_up_model(self) -> None:
244+
if self.parallel_config.data_parallel_size > 1:
245+
if envs.VLLM_RBLN_DP_IMPL == "padded_decode":
246+
max_num_batched_tokens = \
247+
self.scheduler_config.max_num_batched_tokens
248+
max_num_seqs = self.scheduler_config.max_num_seqs
249+
# TODO: consider relaxing this constraint
250+
assert max_num_batched_tokens % max_num_seqs == 0, \
251+
"max_num_batched_tokens must be divisible by max_num_seqs"
252+
elif envs.VLLM_RBLN_DP_IMPL == "dummy_prefill":
253+
raise ValueError("dummy_prefill is not supported in v1 worker " \
254+
"and will be deprecated in the future")
255+
244256
if (self.model_config.enforce_eager or not envs.VLLM_RBLN_COMPILE_MODEL
245257
or not envs.VLLM_RBLN_ENABLE_WARM_UP):
246258
logger.warning("skipping compile_or_warm_up_model")

0 commit comments

Comments (0)