
Commit 1b8c35e

add specialized MoE decode optimization for DP
- implement specialized decode path that uses optimized padding when all requests are in decode stage
- add VLLM_RBLN_SPECIALIZE_MOE_DECODE environment variable to enable specialized handling for decode-only batches in MoE models
- refactor RBLNDPMetadata.max_pads_across_dp from int to torch.Tensor to differentiate specialized decode and normal decode
- add num_padded_tokens parameter to RBLNDPMetadata.make() and _set_forward_context()
- add specialized decode path to batch bucketing
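In rough terms, a rank may take the specialized path only when every DP rank is decoding; otherwise it pads to the full batch budget as before. Below is a minimal sketch of that decision, assuming a hypothetical choose_padding() helper on the batch-bucketing side (that file is not among the hunks shown here) and the real num_tokens_across_dp_with_max_decode_tokens added in forward_context.py:

# Sketch only: choose_padding() is an invented name illustrating the intent;
# the actual bucketing code is in a file not shown in this diff.
import os

from vllm_rbln.forward_context import RBLNDPMetadata


def choose_padding(num_tokens: int, is_prefill: bool, dp_size: int,
                   dp_rank: int, max_num_batched_tokens: int):
    specialize = os.environ.get("VLLM_RBLN_SPECIALIZE_MOE_DECODE",
                                "True").lower() in ("true", "1")
    tokens_across_dp, max_decode = \
        RBLNDPMetadata.num_tokens_across_dp_with_max_decode_tokens(
            num_tokens, dp_size, dp_rank, is_prefill)
    if specialize and max_decode is not None:
        # decode-only step on every rank: pad only to the largest decode batch
        num_padded_tokens = max_decode
    else:
        # some rank is prefilling (or specialization is off): use the full budget
        num_padded_tokens = max_num_batched_tokens
    return tokens_across_dp, num_padded_tokens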
1 parent 608cc55 commit 1b8c35e

4 files changed

Lines changed: 256 additions & 138 deletions


vllm_rbln/forward_context.py

Lines changed: 58 additions & 14 deletions
@@ -20,7 +20,7 @@
 import torch
 import torch.distributed as dist
 import vllm.forward_context as vfc
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   batchsize_logging_interval,
                                   create_forward_context,
@@ -35,7 +35,7 @@
 
 @dataclass
 class RBLNDPMetadata(DPMetadata):
-    max_pads_across_dp: int = 0
+    max_pads_across_dp: torch.Tensor | None = None
 
     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
@@ -53,26 +53,66 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
         return num_tokens_tensor
 
+    @staticmethod
+    def num_tokens_across_dp_with_max_decode_tokens(
+            num_tokens: int, dp_size: int, dp_rank: int,
+            is_prefill: bool) -> tuple[torch.Tensor, int | None]:
+        pad_flag = 1 << 16
+        pad_mask = pad_flag - 1
+        assert num_tokens < pad_flag, \
+            "num_tokens should be less than pad_flag"
+
+        if is_prefill:
+            num_tokens |= pad_flag
+
+        tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_across_dp = torch.max(tokens_across_dp_cpu).item()
+
+        if is_prefill or max_across_dp > pad_flag:
+            mask_tensor = torch.tensor([pad_mask] * dp_size,
+                                       device="cpu",
+                                       dtype=torch.int32)
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu & mask_tensor
+            max_across_dp = None
+        else:
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu
+
+        return num_tokens_across_dp_cpu, max_across_dp
+
     @staticmethod
     def make(
-        vllm_config: VllmConfig,
+        parallel_config: ParallelConfig,
         num_tokens: int,
+        num_tokens_across_dp: torch.Tensor | None = None,
+        num_padded_tokens: int | None = None,
     ) -> "RBLNDPMetadata":
-        parallel_config = vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size
-        dp_rank = parallel_config.data_parallel_rank
-
-        scheduler_config = vllm_config.scheduler_config
-        max_pad = scheduler_config.max_num_batched_tokens
-        batchsize = num_tokens
 
-        num_tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
-            batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+        if dp_size > 1:
+            assert num_tokens_across_dp is not None, \
+                "num_tokens_across_dp should be applied for DP case"
+            assert num_padded_tokens is not None, \
+                "num_padded_tokens should be applied for DP case"
+            num_tokens_across_dp_cpu = num_tokens_across_dp
+            max_pad = num_padded_tokens
+
+            max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+            max_pads_across_dp = torch.empty(max_pad, device="cpu")
+        else:
+            assert num_tokens_across_dp is None, \
+                "num_tokens_across_dp should not be applied for non-DP case"
+            assert num_padded_tokens is None, \
+                "num_padded_tokens should not be applied for non-DP case"
+            num_tokens_across_dp_cpu = torch.tensor([num_tokens],
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            max_tokens_across_dp_cpu = num_tokens
+            max_pads_across_dp = None
 
         return RBLNDPMetadata(max_tokens_across_dp_cpu,
                               num_tokens_across_dp_cpu,
-                              max_pads_across_dp=max_pad)
+                              max_pads_across_dp=max_pads_across_dp)
 
 
 @contextmanager
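num_tokens_across_dp_with_max_decode_tokens packs an "is prefill" flag into bit 16 of each rank's token count, so a single collective yields both the per-rank counts and whether any rank is prefilling; only when no rank set the flag does it return a usable max decode size. A small single-process walkthrough of the same arithmetic (no process group involved; the gathered tensor is written out by hand):

import torch

pad_flag = 1 << 16        # bit 16 marks "this rank is prefilling"
pad_mask = pad_flag - 1   # low 16 bits carry the raw token count

# Pretend rank 0 is decoding 7 tokens and rank 1 is prefilling 130 tokens.
tokens_across_dp = torch.tensor([7, 130 | pad_flag], dtype=torch.int32)
max_across_dp = torch.max(tokens_across_dp).item()   # 65666, i.e. > pad_flag

if max_across_dp > pad_flag:          # at least one rank set the prefill bit
    tokens_across_dp &= pad_mask      # back to the raw counts: tensor([7, 130])
    max_across_dp = None              # no decode-only specialization this step
# Had both ranks been decoding, max_across_dp would simply have been the largest
# decode batch across the DP group (130 in this example).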
@@ -85,6 +125,7 @@ def _set_forward_context(
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
+    num_padded_tokens: int | None = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -99,7 +140,10 @@ def _set_forward_context(
     use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
     if (enable_dp or use_moe_tokens_mask) and (attn_metadata is not None
                                                or num_tokens is not None):
-        dp_metadata = RBLNDPMetadata.make(vllm_config, num_tokens or 0)
+        dp_metadata = RBLNDPMetadata.make(vllm_config.parallel_config,
+                                          num_tokens or 0,
+                                          num_tokens_across_dp,
+                                          num_padded_tokens)
 
     forward_context = create_forward_context(
         attn_metadata,
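After the refactor, RBLNDPMetadata.make no longer reads max_num_batched_tokens from the scheduler config: in the DP case the caller hands it the gathered counts plus num_padded_tokens, and the padded size travels as the length of an empty CPU tensor so downstream code can read it back with .shape[0]. A rough usage sketch; the ParallelConfig objects and the value 256 are placeholders, not values from this commit:

import torch

# DP case (data_parallel_size > 1): both optional arguments are required.
dp_meta = RBLNDPMetadata.make(
    dp_parallel_config,                     # placeholder ParallelConfig, dp_size == 2
    num_tokens=7,
    num_tokens_across_dp=torch.tensor([7, 130], dtype=torch.int32),
    num_padded_tokens=256)
assert dp_meta.max_pads_across_dp.shape[0] == 256   # pad size encoded as tensor length

# Non-DP case: the optional arguments must stay None and no pad tensor is built.
single_meta = RBLNDPMetadata.make(single_parallel_config, num_tokens=7)
assert single_meta.max_pads_across_dp is None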

vllm_rbln/model_executor/layers/fused_moe/layer.py

Lines changed: 3 additions & 3 deletions
@@ -245,7 +245,7 @@ def get_tokens_mask(num_tokens: int, left=1.0, right=float('-inf')):
     if num_tokens_across_dp.size(0) == 1:
         max_pad = num_tokens
     else:
-        max_pad = get_forward_context().dp_metadata.max_pads_across_dp
+        max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
     pos = torch.arange(max_pad, dtype=torch.int32).unsqueeze(0)  # [1, max_pad]
     tokens_mask = torch.where(pos < num_tokens_across_dp, left,
                               right)  # [dp_size, max_pad]
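Because max_pads_across_dp is now a tensor whose length encodes the padded size, get_tokens_mask recovers the integer via .shape[0] and builds the same mask as before. A worked example of the resulting mask, assuming two DP ranks with 3 and 5 real tokens padded to max_pad = 8 (the [dp_size, 1] column shape for the counts is assumed here to make the broadcasting explicit):

import torch

num_tokens_across_dp = torch.tensor([[3], [5]], dtype=torch.int32)  # one row per rank
max_pad = 8                            # max_pads_across_dp.shape[0] in the real code
pos = torch.arange(max_pad, dtype=torch.int32).unsqueeze(0)         # shape [1, 8]
tokens_mask = torch.where(pos < num_tokens_across_dp, 1.0, float('-inf'))
# tokens_mask[0]: 1, 1, 1, -inf, -inf, -inf, -inf, -inf   (rank 0 has 3 real tokens)
# tokens_mask[1]: 1, 1, 1, 1, 1, -inf, -inf, -inf          (rank 1 has 5 real tokens)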
@@ -464,7 +464,7 @@ def fused_moe_forward_rbln(self, hidden_states: torch.Tensor,
     hidden_shape_dp = (-1, 1, org_hidden_shape[-1])
     final_hidden_states = all_hidden_states.reshape(hidden_shape_dp)
 
-    max_pad = get_forward_context().dp_metadata.max_pads_across_dp
+    max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
     num_tokens = org_hidden_shape[:-1].numel()  # noqa: F841
     start = self.dp_rank * max_pad
     end = start + num_tokens
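In fused_moe_forward_rbln every rank owns a fixed max_pad-wide slot in the gathered tensor, which is why the pad size must agree across ranks: the slice offset is simply dp_rank * max_pad. A tiny arithmetic check with illustrative numbers:

max_pad, dp_rank, num_tokens = 8, 1, 5
start = dp_rank * max_pad     # 8: rank 1's slot begins after rank 0's 8 padded rows
end = start + num_tokens      # 13: presumably only the 5 real tokens are read back out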
@@ -483,7 +483,7 @@ def fused_moe_naive_multicast_rbln(self, x: torch.Tensor):
     # assert len(x.shape) == 3
 
     x = x.reshape(1, -1, x.size(-1))
-    max_pad = get_forward_context().dp_metadata.max_pads_across_dp
+    max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0]
     num_tokens = x.size(1)
     num_repeat = max_pad // num_tokens
     # TODO: evaluate various padding approaches

vllm_rbln/rbln_envs.py

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,7 @@
     VLLM_RBLN_SAMPLER: bool = True
     VLLM_RBLN_ENABLE_WARM_UP: bool = True
     VLLM_RBLN_USE_VLLM_MODEL: bool = False
+    VLLM_RBLN_SPECIALIZE_MOE_DECODE: bool = True
     VLLM_RBLN_FLASH_CAUSAL_ATTN: bool = True
     VLLM_RBLN_BATCH_ATTN_OPT: bool = False
     VLLM_RBLN_DISABLE_MM: bool = False
@@ -102,6 +103,9 @@ def get_dp_impl():
     "VLLM_RBLN_USE_MOE_TOKENS_MASK":
     (lambda: os.environ.get("VLLM_RBLN_USE_MOE_TOKENS_MASK", "True").lower() in
      ("true", "1")),
+    # If true, it specializes the cases where all instances are at decode stage
+    "VLLM_RBLN_SPECIALIZE_MOE_DECODE": (lambda: os.environ.get(
+        "VLLM_RBLN_SPECIALIZE_MOE_DECODE", "True").lower() in ("true", "1")),
     # enforce model data type into fp32 not model_config.dtype
     "VLLM_RBLN_ENFORCE_MODEL_FP32":
     (lambda: os.environ.get("VLLM_RBLN_ENFORCE_MODEL_FP32", "False").lower() in
