Merged

28 commits
132e388  fix: bump up v0 moe dp implementation to v1 (rebel-ykchoi, Jan 6, 2026)
b7ef9be  fix: mxfp4 kernel for model parallel (rebel-wonsubkim, Jan 6, 2026)
a432c0c  modify expert_map position (rebel-wonsubkim, Jan 6, 2026)
8c4f96e  fix gpt_oss tensor parallel all_reduce (rebel-wonsubkim, Jan 7, 2026)
5a25903  disable shared fused moe overlap for RBLN (rebel-wonsubkim, Jan 8, 2026)
e8760d5  reference torch impl for gpt-oss ops (rebel-jaehwang, Jan 8, 2026)
2c010d1  apply VLLM_RBLN_USE_MOE_TOKENS_MASK to mxfp4 MOE (rebel-ykchoi, Jan 8, 2026)
046b830  adjust available dram size based on target arch (rebel-wonsubkim, Jan 9, 2026)
e38ea08  fix v1 dp online serving (rebel-ykchoi, Jan 12, 2026)
03d7161  add additional params for data_parallel.py script (rebel-wonsubkim, Jan 13, 2026)
cc1b9ca  fix calculation of maximum num blocks (rebel-wonsubkim, Jan 16, 2026)
f0d6056  fix: port v0.12 scheduler code (huijjj, Jan 7, 2026)
3644f37  fix: limit decode bs to (max num seqs // pp size) (huijjj, Jan 7, 2026)
3efcfc1  tmp: pad decode inputs to max_num_seqs // pp_size (huijjj, Jan 7, 2026)
1b1e53a  add: simple offline benchmark script (huijjj, Jan 7, 2026)
e831b96  fix DPMetadata for tokens mask (rebel-ykchoi, Jan 19, 2026)
aa93d30  fix dp with pp dummy run logic (rebel-ykchoi, Jan 19, 2026)
799fa32  fix max_num_blocks calculation (rebel-wonsubkim, Jan 20, 2026)
b67ff4a  add optimized batch attention kernel (rebel-wonsubkim, Jan 20, 2026)
d0c86e7  Merge branch 'gpt-oss-0.12-batched-attention' into dev-0.12 (rebel-wonsubkim, Jan 21, 2026)
8431de6  resolve conflict between bucketing and dp (rebel-wonsubkim, Jan 21, 2026)
8e6bc84  fix num_runtimes (rebel-wonsubkim, Jan 23, 2026)
36e995e  pad seq_idx for batch attention (rebel-wonsubkim, Jan 23, 2026)
5054a91  fixed batched decode func call (rebel-jaehunryu, Jan 21, 2026)
a33092e  remove unused code (rebel-jaehunryu, Jan 23, 2026)
c16066a  fix up RBLN_METRICS (rebel-wonsubkim, Jan 27, 2026)
608cc55  fix typo (rebel-wonsubkim, Jan 27, 2026)
1b8c35e  add specialized MoE decode optimization for DP (rebel-ykchoi, Jan 27, 2026)
94 changes: 37 additions & 57 deletions examples/experimental/data_parallel.py
@@ -54,55 +54,28 @@

os.environ['VLLM_TORCH_PROFILER_DIR'] = './profile'

hf_overrides_kw = {
"num_hidden_layers": 2,
}


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, tp_size, enable_ep, vllm_use_v1):
dp_master_port, tp_size, enable_ep,
max_model_len, block_size, decode_batch, num_hidden_layers):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
# parallel_config.data_parallel_size = envs.VLLM_DP_SIZE
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

if not vllm_use_v1:
# in v0 worker, each process has distinct RBLN_DEVICES
rbln_devices = ""
if os.environ.get("VLLM_RBLN_TP_SIZE") is None:
rsd_size = 1
else:
rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE"))
rsd_tp_size = tp_size * rsd_size
start_index = local_dp_rank * rsd_tp_size
end_index = start_index + rsd_tp_size
for index in range(start_index, end_index):
if rbln_devices:
rbln_devices += ","
rbln_devices += str(index)

os.environ["RBLN_DEVICES"] = rbln_devices
else:
rbln_devices = os.environ.get("RBLN_DEVICES")

print(f"local RBLN_DEVICES = {rbln_devices}")
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
# engine processes.

# Sample prompts.
prompts = [
"Hello, my name is",
"The vLLM is",
"The president of the United States is",
"The future of AI is",
]
] * dp_size

# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
prompts_per_rank = (len(prompts) // dp_size) + 1
prompts_per_rank = (len(prompts) // dp_size)
start = global_dp_rank * prompts_per_rank
end = start + prompts_per_rank
prompts = prompts[start:end]
@@ -119,15 +92,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0.0)

if num_hidden_layers == 0:
hf_overrides_kw = None
else:
hf_overrides_kw = {
"num_hidden_layers": num_hidden_layers,
}

# Create an LLM.
llm = LLM(
model=model,
#hf_overrides=hf_overrides_kw,
max_model_len=8 * 1024,
block_size=1024,
hf_overrides=hf_overrides_kw,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=True,
max_num_batched_tokens=128,
max_num_seqs=1,
max_num_seqs=decode_batch,
trust_remote_code=True,
tensor_parallel_size=tp_size,
enable_expert_parallel=enable_ep,
@@ -166,6 +146,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
parser.add_argument('--ep',
action='store_true',
help="vLLM enable_expert_parallel")
parser.add_argument("--max-model-len",
type=int,
default=8192,
help="Max sequence length")
parser.add_argument("--block-size",
type=int,
default=4096,
help="KV cache block size")
parser.add_argument("--decode-batch",
type=int,
default=1,
help="decode batch size")
parser.add_argument("--num-hidden-layers",
type=int,
default=0,
help="num hidden layers")
parser.add_argument("--node-size",
type=int,
default=1,
@@ -189,6 +185,10 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
node_size = args.node_size
node_rank = args.node_rank
enable_ep = args.ep
max_model_len = args.max_model_len
block_size = args.block_size
decode_batch = args.decode_batch
num_hidden_layers = args.num_hidden_layers

if node_size == 1:
dp_master_ip = "127.0.0.1"
@@ -200,35 +200,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
dp_per_node = dp_size // node_size

vllm_use_v1 = (int(os.environ.get("VLLM_USE_V1", "0")) == 1)
if vllm_use_v1:
print("VLLM_USE_V1")
# in v1 worker, entire processes SHOULD have global RBLN_DEVICES
rbln_devices = ""
if os.environ.get("VLLM_RBLN_TP_SIZE") is None:
rsd_size = 1
else:
rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE"))
start_index = 0
end_index = start_index + tp_size * dp_size * rsd_size
for index in range(start_index, end_index):
if rbln_devices:
rbln_devices += ","
rbln_devices += str(index)

print(f"global RBLN_DEVICES = {rbln_devices}")
os.environ["RBLN_DEVICES"] = rbln_devices
else:
print("VLLM_USE_V0")

from multiprocessing import Process
procs = []
for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size, enable_ep, vllm_use_v1))
tp_size, enable_ep,
max_model_len, block_size, decode_batch, num_hidden_layers))
proc.start()
procs.append(proc)
exit_code = 0
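The core behavioral change in this example is how prompts are split across DP ranks: the base prompt list is now replicated dp_size times and the "+ 1" over-count in prompts_per_rank is dropped, so every rank receives an equally sized, non-empty slice. Below is a minimal standalone sketch of that partitioning arithmetic; the function name is chosen here for illustration and only the slicing logic mirrors the script.

```python
# Minimal sketch of the prompt partitioning used in the updated
# data_parallel.py example: the base prompt list is replicated dp_size
# times, then each DP rank takes a contiguous, equally sized slice.
def partition_prompts(base_prompts: list[str], dp_size: int,
                      global_dp_rank: int) -> list[str]:
    prompts = base_prompts * dp_size
    # With the "+ 1" removed, len(prompts) is an exact multiple of dp_size,
    # so every rank gets the same number of prompts and none runs dry.
    prompts_per_rank = len(prompts) // dp_size
    start = global_dp_rank * prompts_per_rank
    return prompts[start:start + prompts_per_rank]


if __name__ == "__main__":
    base = ["Hello, my name is", "The vLLM is"]
    for rank in range(2):
        print(rank, partition_prompts(base, dp_size=2, global_dp_rank=rank))
```

The new CLI flags (--max-model-len, --block-size, --decode-batch, --num-hidden-layers) feed directly into the LLM(...) constructor, replacing the previously hard-coded 8K context length, 1024-token block size, and max_num_seqs of 1.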
1 change: 1 addition & 0 deletions vllm_rbln/__init__.py
@@ -47,6 +47,7 @@ def register_ops():
import vllm_rbln.forward_context # noqa
import vllm_rbln.lora.layer # noqa
import vllm_rbln.model_executor.layers.fused_moe.layer # noqa
import vllm_rbln.model_executor.layers.fused_moe.shared_fused_moe # noqa
import vllm_rbln.model_executor.layers.logits_processor # noqa
import vllm_rbln.model_executor.layers.quantization.kernels.mixed_precision # noqa
import vllm_rbln.model_executor.layers.quantization.mxfp4 # noqa
172 changes: 106 additions & 66 deletions vllm_rbln/forward_context.py
@@ -15,13 +15,17 @@
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any

import torch
import torch.distributed as dist
import vllm.forward_context as vfc
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import (BatchDescriptor, DPMetadata, ForwardContext,
batchsize_logging_interval, track_batchsize)
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.forward_context import (BatchDescriptor, DPMetadata,
batchsize_logging_interval,
create_forward_context,
override_forward_context, track_batchsize)
from vllm.v1.worker.ubatch_utils import UBatchSlices

import vllm_rbln.rbln_envs as envs
from vllm_rbln.logger import init_logger
@@ -31,99 +35,137 @@

@dataclass
class RBLNDPMetadata(DPMetadata):
max_pads_across_dp: int = 0
max_pads_across_dp: torch.Tensor | None = None

@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
dp_rank: int) -> torch.Tensor:
"""
Gather the num_tokens across all DP ranks and return results in a
CPU tensor of size dp_size.
"""
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
return num_tokens_tensor

@staticmethod
def num_tokens_across_dp_with_max_decode_tokens(
num_tokens: int, dp_size: int, dp_rank: int,
is_prefill: bool) -> tuple[torch.Tensor, int | None]:
pad_flag = 1 << 16
pad_mask = pad_flag - 1
assert num_tokens < pad_flag, \
"num_tokens should be less than pad_flag"

if is_prefill:
num_tokens |= pad_flag

tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_across_dp = torch.max(tokens_across_dp_cpu).item()

if is_prefill or max_across_dp > pad_flag:
mask_tensor = torch.tensor([pad_mask] * dp_size,
device="cpu",
dtype=torch.int32)
num_tokens_across_dp_cpu = tokens_across_dp_cpu & mask_tensor
max_across_dp = None
else:
num_tokens_across_dp_cpu = tokens_across_dp_cpu

return num_tokens_across_dp_cpu, max_across_dp

@staticmethod
def make(
vllm_config: VllmConfig,
attn_metadata: Any,
parallel_config: ParallelConfig,
num_tokens: int,
num_tokens_across_dp_cpu: torch.Tensor
num_tokens_across_dp: torch.Tensor | None = None,
num_padded_tokens: int | None = None,
) -> "RBLNDPMetadata":

parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank

scheduler_config = vllm_config.scheduler_config
max_pad = scheduler_config.max_num_batched_tokens

if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens

disable_dp = dp_size == 1
use_dummy_prefill = envs.VLLM_RBLN_DP_IMPL == "dummy_prefill"
if (disable_dp or use_dummy_prefill) and \
attn_metadata.num_decode_tokens > 0:
max_pad = scheduler_config.max_num_seqs
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens

# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert num_tokens_across_dp_cpu is not None
if dp_size > 1:
assert num_tokens_across_dp is not None, \
"num_tokens_across_dp should be applied for DP case"
assert num_padded_tokens is not None, \
"num_padded_tokens should be applied for DP case"
num_tokens_across_dp_cpu = num_tokens_across_dp
max_pad = num_padded_tokens

max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
max_pads_across_dp = torch.empty(max_pad, device="cpu")
else:
assert num_tokens_across_dp is None, \
"num_tokens_across_dp should not be applied for non-DP case"
assert num_padded_tokens is None, \
"num_padded_tokens should not be applied for non-DP case"
num_tokens_across_dp_cpu = torch.tensor([num_tokens],
device="cpu",
dtype=torch.int32)
max_tokens_across_dp_cpu = num_tokens
max_pads_across_dp = None

max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
return RBLNDPMetadata(max_tokens_across_dp_cpu=max_tokens_across_dp_cpu,
num_tokens_across_dp_cpu=num_tokens_across_dp_cpu,
max_pads_across_dp=max_pad)
return RBLNDPMetadata(max_tokens_across_dp_cpu,
num_tokens_across_dp_cpu,
max_pads_across_dp=max_pads_across_dp)


@contextmanager
def _set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int | None = None,
num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
num_padded_tokens: int | None = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
vfc.forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None

dp_metadata: DPMetadata | None = None
enable_dp = vllm_config.parallel_config.data_parallel_size > 1
use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
if (enable_dp or use_moe_tokens_mask) and (attn_metadata is not None
or num_tokens is not None):
dp_metadata = RBLNDPMetadata.make(vllm_config, attn_metadata,
dp_metadata = RBLNDPMetadata.make(vllm_config.parallel_config,
num_tokens or 0,
num_tokens_across_dp)

prev_context = vfc._forward_context
vfc._forward_context = ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
num_tokens_across_dp,
num_padded_tokens)

forward_context = create_forward_context(
attn_metadata,
vllm_config,
virtual_engine,
dp_metadata,
cudagraph_runtime_mode,
batch_descriptor,
ubatch_slices,
)

try:
yield
with override_forward_context(forward_context):
yield
finally:
if need_to_track_batchsize:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = num_tokens
batchsize = num_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
from vllm.platforms import current_platform

synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
@@ -147,7 +189,5 @@ def _set_forward_context(
"(batchsize, count, median_time(ms)): %s"),
forward_stats)

vfc._forward_context = prev_context


vfc.set_forward_context = _set_forward_context
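RBLNDPMetadata.num_tokens_across_dp_with_max_decode_tokens piggybacks a prefill indicator on the same all-reduce that exchanges per-rank token counts: each rank's count occupies the low 16 bits and bit 16 (pad_flag) is set when that rank is prefilling. The sketch below is a single-process illustration of that encoding only; a plain list stands in for the all-reduced tensor, and it is not the distributed code path.

```python
# Single-process sketch of the bit-packing used by
# num_tokens_across_dp_with_max_decode_tokens: token count in the low
# 16 bits, prefill indicator in bit 16, so one collective carries both.
PAD_FLAG = 1 << 16
PAD_MASK = PAD_FLAG - 1


def pack(num_tokens: int, is_prefill: bool) -> int:
    assert num_tokens < PAD_FLAG, "num_tokens must fit below pad_flag"
    return num_tokens | PAD_FLAG if is_prefill else num_tokens


def unpack(gathered: list[int], self_is_prefill: bool):
    # `gathered` stands in for the all-reduced per-rank CPU tensor.
    max_packed = max(gathered)
    if self_is_prefill or max_packed > PAD_FLAG:
        # Some rank is prefilling: strip flag bits, no decode max available.
        return [v & PAD_MASK for v in gathered], None
    return gathered, max_packed


if __name__ == "__main__":
    # Rank 1 prefills 300 tokens; ranks 0 and 2 decode 4 and 7 tokens.
    gathered = [pack(4, False), pack(300, True), pack(7, False)]
    print(unpack(gathered, self_is_prefill=False))  # ([4, 300, 7], None)
```

If any rank sets the flag, the gathered maximum exceeds pad_flag, so every rank strips the flag bits and reports no max-decode-token count; otherwise the maximum is simply the largest decode batch across ranks.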