
Commit 1911d49

rebel-wonsubkim, rebel-ykchoi, rebel-jaehwang, huijjj, and rebel-jaehunryu committed
update gpt-oss-0.12 changes (#293)
* fix: bump up v0 moe dp implementation to v1
  - remove DP padding support in v1 worker
  - add validation for DP implementation constraints in v1 worker
  - apply token mask to custom MOE kernel router logits
  - update default environment variables:
    - VLLM_RBLN_DP_IMPL: "dummy_prefill" -> "padded_decode"
    - VLLM_RBLN_USE_MOE_TOKENS_MASK: False -> True
  - fix DP metadata handling in forward context
  - add is_prefills field to RBLNFlashAttentionMetadata

* fix: mxfp4 kernel for model parallel
  + add expert_map to handle vllm model parallel

* modify expert_map position

* fix gpt_oss tensor parallel all_reduce + gpt_oss MLPBlock tp missing

* disable shared fused moe overlap for RBLN

* reference torch impl for gpt-oss ops

* apply VLLM_RBLN_USE_MOE_TOKENS_MASK to mxfp4 MOE

* adjust available dram size based on target arch
  + change available dram size per architecture
    - ATOM: 16GB
    - REBEL: 140GB

* fix v1 dp online serving
  refactor: improve intermediate tensors management and dummy run logic
  - add prepare_dummy_run and dummy_run methods for v1 dp online serving
  - remove unused sync_and_slice_intermediate_tensors method
  - separate intermediate_tensors into prefill_intermediate_tensors and decode_intermediate_tensors
  - improve RBLNWorker device environment initialization
    - add support for Ray backend
    - add local_world_size calculation
    - improve device environment variable setup logic
  - make RBLN_DEVICES not coupled with VLLM_RBLN_TP_SIZE
  - change LOCAL_RANK to rank in init_worker_distributed_environment

* add additional params for data_parallel.py script
  + add necessary parameters: --max-model-len, --block-size, --num-hidden-layers, --decode-batch

* fix calculation of maximum num blocks
  + consider sliding window attention
    - do NOT count sliding window attention blocks, since they share KV cache blocks with full attention
  + calculate max num blocks under the assumption that all layers use full attention
    - when calculating available memory, count full attention layers, not sliding window attention layers

* fix: port v0.12 scheduler code

* fix: limit decode bs to (max num seqs // pp size)

* tmp: pad decode inputs to max_num_seqs // pp_size

* add: simple offline benchmark script

* fix DPMetadata for tokens mask
  - remove unused attn_metadata parameter from RBLNDPMetadata.make()
  - remove is_prefills field and related logic from DP metadata
  - fix get_tokens_mask() for non-DP case

* fix dp with pp dummy run logic
  - refactor dummy run execution with DummyRunState and prepare_dummy_run
  - update batch size calculation to account for pipeline parallel size
  - add batch_pad parameter to attention metadata builder for PP support

* fix max_num_blocks calculation
  + consider the following issues when calculating max_num_blocks:
    - gpt-oss-20b scale merge for the dequantized version
    - SWA (sliding window attention) block sharing with full attention
    - the word_embedding param when calculating kernel size, since it is not included on the device

* add optimized batch attention kernel
  + the batch_attention kernel is an optimized version of the flash attention kernel for large batches
    - the batch attention kernel takes the original sequence index
    - in compiler lowering, the original sequence index is lowered into the following items:
      - seq_idx: cache target block index
      - seq_offset: cache target block offset
      - dyn_batch: valid batch count for each partition

* resolve conflict between bucketing and dp
  + replace max_batch_size with decode_batch_bucket size
  + disable batch bucketing by default
    - change limit of bucket

* fix num_runtimes
  + num_runtimes fix up
    - ATOM: num_runtimes = 2 * VLLM_RBLN_TP_SIZE
    - REBEL: num_runtimes = 2 * 4 (quad chiplet)

* pad seq_idx for batch attention
  + seq_idx SHOULD be padded if num_reqs < decode_batch size

* fixed batched decode func call

* remove unused code

* fix up RBLN_METRICS
  + do NOT count model warm up (prefill & decode batch bucket)

* fix typo

* add specialized MoE decode optimization for DP
  - implement a specialized decode path that uses optimized padding when all requests are in the decode stage
  - add VLLM_RBLN_SPECIALIZE_MOE_DECODE environment variable to enable specialized handling for decode-only batches in MoE models
  - refactor RBLNDPMetadata.max_pads_across_dp from int to torch.Tensor to differentiate specialized decode from normal decode
  - add num_padded_tokens parameter to RBLNDPMetadata.make() and _set_forward_context()
  - add specialized decode path to batch bucketing

---------

Signed-off-by: wonsub kim <subang0@rebellions.ai>
Co-authored-by: Youngkyu Choi <youngkyu.choi@rebellions.ai>
Co-authored-by: Jaehwang Jung <jaehwang.jung@rebellions.ai>
Co-authored-by: Huijong JEONG <huijong.jeong@squeezebits.com>
Co-authored-by: JaehunRyu <jaehun.ryu@rebellions.ai>
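The max-num-blocks reasoning above (SWA layers share KV-cache blocks with full-attention layers, so only full-attention layers consume distinct blocks) reduces to a short calculation. A minimal sketch, assuming hypothetical names and a uniform per-layer KV-cache cost; this is not the repository's actual helper:

def max_num_blocks(available_bytes: int, num_full_attn_layers: int,
                   block_size: int, kv_bytes_per_token: int) -> int:
    # Cost of one KV cache block in one attention layer.
    bytes_per_block_per_layer = block_size * kv_bytes_per_token
    # SWA layers reuse the full-attention blocks, so they add no cost:
    # only full-attention layers appear in the denominator.
    return available_bytes // (bytes_per_block_per_layer *
                               num_full_attn_layers)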
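Likewise, the num_runtimes rule above is simple arithmetic; a hedged sketch (the function name and arch strings are illustrative, the constants come straight from the message):

def num_runtimes(arch: str, tp_size: int) -> int:
    if arch == "ATOM":
        return 2 * tp_size  # 2 * VLLM_RBLN_TP_SIZE
    if arch == "REBEL":
        return 2 * 4  # quad-chiplet device
    raise ValueError(f"unknown target arch: {arch}")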
1 parent 9cc2ca5 commit 1911d49

13 files changed

Lines changed: 1501 additions & 422 deletions

File tree

examples/experimental/data_parallel.py

Lines changed: 37 additions & 57 deletions
@@ -54,55 +54,28 @@
 
 os.environ['VLLM_TORCH_PROFILER_DIR'] = './profile'
 
-hf_overrides_kw = {
-    "num_hidden_layers": 2,
-}
-
-
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, tp_size, enable_ep, vllm_use_v1):
+         dp_master_port, tp_size, enable_ep,
+         max_model_len, block_size, decode_batch, num_hidden_layers):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     # paralle_config.data_parallel_size = envs.VLLM_DP_SIZE
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
     os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
     os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
 
-    if not vllm_use_v1:
-        # in v0 worker, each process has distinct RBLN_DEVICES
-        rbln_devices = ""
-        if os.environ.get("VLLM_RBLN_TP_SIZE") is None:
-            rsd_size = 1
-        else:
-            rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE"))
-        rsd_tp_size = tp_size * rsd_size
-        start_index = local_dp_rank * rsd_tp_size
-        end_index = start_index + rsd_tp_size
-        for index in range(start_index, end_index):
-            if rbln_devices:
-                rbln_devices += ","
-            rbln_devices += str(index)
-
-        os.environ["RBLN_DEVICES"] = rbln_devices
-    else:
-        rbln_devices = os.environ.get("RBLN_DEVICES")
-
-    print(f"local RBLN_DEVICES = {rbln_devices}")
-    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
-    # engine processes.
-
     # Sample prompts.
     prompts = [
         "Hello, my name is",
         "The vLLM is",
         "The president of the United States is",
         "The future of AI is",
-    ]
+    ] * dp_size
 
     # with DP, each rank should process different prompts.
     # usually all the DP ranks process a full dataset,
     # and each rank processes a different part of the dataset.
-    prompts_per_rank = (len(prompts) // dp_size) + 1
+    prompts_per_rank = (len(prompts) // dp_size)
     start = global_dp_rank * prompts_per_rank
     end = start + prompts_per_rank
     prompts = prompts[start:end]
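The two changes in this hunk work together: replicating the four base prompts dp_size times makes the prompt count exactly divisible by dp_size, so the old + 1 over-allocation (which let high ranks slice past the end of the list) is no longer needed. A standalone worked example for dp_size = 2:

prompts = ["p0", "p1", "p2", "p3"] * 2  # 8 prompts total
prompts_per_rank = len(prompts) // 2  # 4 per rank, no remainder
rank0 = prompts[0 * prompts_per_rank:1 * prompts_per_rank]
rank1 = prompts[1 * prompts_per_rank:2 * prompts_per_rank]
assert rank0 + rank1 == prompts  # full coverage, no overlap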
@@ -119,15 +92,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     # sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     sampling_params = SamplingParams(temperature=0.0)
 
+    if num_hidden_layers == 0:
+        hf_overrides_kw = None
+    else:
+        hf_overrides_kw = {
+            "num_hidden_layers": num_hidden_layers,
+        }
+
     # Create an LLM.
     llm = LLM(
         model=model,
-        #hf_overrides=hf_overrides_kw,
-        max_model_len=8 * 1024,
-        block_size=1024,
+        hf_overrides=hf_overrides_kw,
+        max_model_len=max_model_len,
+        block_size=block_size,
         enable_chunked_prefill=True,
         max_num_batched_tokens=128,
-        max_num_seqs=1,
+        max_num_seqs=decode_batch,
         trust_remote_code=True,
         tensor_parallel_size=tp_size,
         enable_expert_parallel=enable_ep,
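hf_overrides here is vLLM's standard mechanism for overriding HuggingFace config fields before the model loads: a nonzero --num-hidden-layers truncates the model for quick functional runs, while the default of 0 passes None and leaves the checkpoint untouched. A minimal illustration (the model name is a placeholder):

from vllm import LLM

# Load only 2 hidden layers for a fast smoke test.
llm = LLM(model="some/model", hf_overrides={"num_hidden_layers": 2})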
@@ -166,6 +146,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     parser.add_argument('--ep',
                         action='store_true',
                         help="vLLM enable_expert_parallel")
+    parser.add_argument("--max-model-len",
+                        type=int,
+                        default=8192,
+                        help="Max sequence length")
+    parser.add_argument("--block-size",
+                        type=int,
+                        default=4096,
+                        help="KV cache block size")
+    parser.add_argument("--decode-batch",
+                        type=int,
+                        default=1,
+                        help="decode batch size")
+    parser.add_argument("--num-hidden-layers",
+                        type=int,
+                        default=0,
+                        help="num hidden layers")
     parser.add_argument("--node-size",
                         type=int,
                         default=1,
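For reference, an invocation exercising the new flags might look like the line below. Only the four flags added in this hunk are confirmed by the diff; --model, --dp-size, and --tp-size are assumed from the script's existing argument handling:

python examples/experimental/data_parallel.py --model <model> \
    --dp-size 2 --tp-size 1 --max-model-len 8192 --block-size 4096 \
    --decode-batch 2 --num-hidden-layers 0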
@@ -189,6 +185,10 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     node_size = args.node_size
     node_rank = args.node_rank
     enable_ep = args.ep
+    max_model_len = args.max_model_len
+    block_size = args.block_size
+    decode_batch = args.decode_batch
+    num_hidden_layers = args.num_hidden_layers
 
     if node_size == 1:
         dp_master_ip = "127.0.0.1"
@@ -200,35 +200,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
     dp_per_node = dp_size // node_size
 
-    vllm_use_v1 = (int(os.environ.get("VLLM_USE_V1", "0")) == 1)
-    if vllm_use_v1:
-        print("VLLM_USE_V1")
-        # in v1 worker, entire processes SHOULD have global RBLN_DEVICES
-        rbln_devices = ""
-        if os.environ.get("VLLM_RBLN_TP_SIZE") is None:
-            rsd_size = 1
-        else:
-            rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE"))
-        start_index = 0
-        end_index = start_index + tp_size * dp_size * rsd_size
-        for index in range(start_index, end_index):
-            if rbln_devices:
-                rbln_devices += ","
-            rbln_devices += str(index)
-
-        print(f"global RBLN_DEVICES = {rbln_devices}")
-        os.environ["RBLN_DEVICES"] = rbln_devices
-    else:
-        print("VLLM_USE_V0")
-
     from multiprocessing import Process
     procs = []
     for local_dp_rank, global_dp_rank in enumerate(
             range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size, enable_ep, vllm_use_v1))
+                             tp_size, enable_ep,
+                             max_model_len, block_size, decode_batch,
+                             num_hidden_layers))
         proc.start()
         procs.append(proc)
     exit_code = 0

vllm_rbln/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def register_ops():
     import vllm_rbln.forward_context  # noqa
     import vllm_rbln.lora.layer  # noqa
     import vllm_rbln.model_executor.layers.fused_moe.layer  # noqa
+    import vllm_rbln.model_executor.layers.fused_moe.shared_fused_moe  # noqa
     import vllm_rbln.model_executor.layers.logits_processor  # noqa
     import vllm_rbln.model_executor.layers.quantization.kernels.mixed_precision  # noqa
     import vllm_rbln.model_executor.layers.quantization.mxfp4  # noqa

vllm_rbln/forward_context.py

Lines changed: 104 additions & 58 deletions
@@ -15,13 +15,17 @@
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 import torch
+import torch.distributed as dist
 import vllm.forward_context as vfc
-from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.forward_context import (BatchDescriptor, DPMetadata, ForwardContext,
-                                  batchsize_logging_interval, track_batchsize)
+from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
+from vllm.forward_context import (BatchDescriptor, DPMetadata,
+                                  batchsize_logging_interval,
+                                  create_forward_context,
+                                  override_forward_context, track_batchsize)
+from vllm.v1.worker.ubatch_utils import UBatchSlices
 
 import vllm_rbln.rbln_envs as envs
 from vllm_rbln.logger import init_logger
@@ -31,93 +35,137 @@
 
 @dataclass
 class RBLNDPMetadata(DPMetadata):
-    max_pads_across_dp: int = 0
+    max_pads_across_dp: torch.Tensor | None = None
+
+    @staticmethod
+    def num_tokens_across_dp(num_tokens: int, dp_size: int,
+                             dp_rank: int) -> torch.Tensor:
+        """
+        Gather the num_tokens across all DP ranks and return results in a
+        CPU tensor of size dp_size.
+        """
+        num_tokens_across_dp = [0] * dp_size
+        num_tokens_across_dp[dp_rank] = num_tokens
+        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
+                                         device="cpu",
+                                         dtype=torch.int32)
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        return num_tokens_tensor
+
+    @staticmethod
+    def num_tokens_across_dp_with_max_decode_tokens(
+            num_tokens: int, dp_size: int, dp_rank: int,
+            is_prefill: bool) -> tuple[torch.Tensor, int | None]:
+        pad_flag = 1 << 16
+        pad_mask = pad_flag - 1
+        assert num_tokens < pad_flag, \
+            "num_tokens should be less than pad_flag"
+
+        if is_prefill:
+            num_tokens |= pad_flag
+
+        tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_across_dp = torch.max(tokens_across_dp_cpu).item()
+
+        if is_prefill or max_across_dp > pad_flag:
+            mask_tensor = torch.tensor([pad_mask] * dp_size,
+                                       device="cpu",
+                                       dtype=torch.int32)
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu & mask_tensor
+            max_across_dp = None
+        else:
+            num_tokens_across_dp_cpu = tokens_across_dp_cpu
+
+        return num_tokens_across_dp_cpu, max_across_dp
 
     @staticmethod
     def make(
-        vllm_config: VllmConfig,
-        attn_metadata: Any,
+        parallel_config: ParallelConfig,
         num_tokens: int,
-        num_tokens_across_dp_cpu: torch.Tensor
+        num_tokens_across_dp: torch.Tensor | None = None,
+        num_padded_tokens: int | None = None,
     ) -> "RBLNDPMetadata":
-
-        parallel_config = vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size
-        dp_rank = parallel_config.data_parallel_rank
 
-        scheduler_config = vllm_config.scheduler_config
-        max_pad = scheduler_config.max_num_batched_tokens
+        if dp_size > 1:
+            assert num_tokens_across_dp is not None, \
+                "num_tokens_across_dp should be applied for DP case"
+            assert num_padded_tokens is not None, \
+                "num_padded_tokens should be applied for DP case"
+            num_tokens_across_dp_cpu = num_tokens_across_dp
+            max_pad = num_padded_tokens
 
-        if attn_metadata is not None and hasattr(attn_metadata,
-                                                 "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = attn_metadata.num_prefill_tokens + \
-                attn_metadata.num_decode_tokens
+            max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+            max_pads_across_dp = torch.empty(max_pad, device="cpu")
         else:
-            # for v1 attention backends or no attn_metadata
-            batchsize = num_tokens
-
-        # If num_tokens_across_dp is None, it will be computed by all_reduce
-        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert num_tokens_across_dp_cpu is not None
+            assert num_tokens_across_dp is None, \
+                "num_tokens_across_dp should not be applied for non-DP case"
+            assert num_padded_tokens is None, \
+                "num_padded_tokens should not be applied for non-DP case"
+            num_tokens_across_dp_cpu = torch.tensor([num_tokens],
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            max_tokens_across_dp_cpu = num_tokens
+            max_pads_across_dp = None
 
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
-        return RBLNDPMetadata(max_tokens_across_dp_cpu=max_tokens_across_dp_cpu,
-                              num_tokens_across_dp_cpu=num_tokens_across_dp_cpu,
-                              max_pads_across_dp=max_pad)
+        return RBLNDPMetadata(max_tokens_across_dp_cpu,
+                              num_tokens_across_dp_cpu,
+                              max_pads_across_dp=max_pads_across_dp)
 
 
 @contextmanager
 def _set_forward_context(
-        attn_metadata: Any,
-        vllm_config: VllmConfig,
-        virtual_engine: int = 0,
-        num_tokens: Optional[int] = None,
-        num_tokens_across_dp: Optional[torch.Tensor] = None,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+    attn_metadata: Any,
+    vllm_config: VllmConfig,
+    virtual_engine: int = 0,
+    num_tokens: int | None = None,
+    num_tokens_across_dp: torch.Tensor | None = None,
+    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+    batch_descriptor: BatchDescriptor | None = None,
+    ubatch_slices: UBatchSlices | None = None,
+    num_padded_tokens: int | None = None,
+):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         vfc.forward_start_time = time.perf_counter()
-    dp_metadata: Optional[DPMetadata] = None
+
+    dp_metadata: DPMetadata | None = None
     enable_dp = vllm_config.parallel_config.data_parallel_size > 1
     use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
     if (enable_dp or use_moe_tokens_mask) and (attn_metadata is not None
                                                or num_tokens is not None):
-        dp_metadata = RBLNDPMetadata.make(vllm_config, attn_metadata,
+        dp_metadata = RBLNDPMetadata.make(vllm_config.parallel_config,
                                           num_tokens or 0,
-                                          num_tokens_across_dp)
-
-    prev_context = vfc._forward_context
-    vfc._forward_context = ForwardContext(
-        no_compile_layers=vllm_config.compilation_config.
-        static_forward_context,
-        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata,
-        cudagraph_runtime_mode=cudagraph_runtime_mode,
-        batch_descriptor=batch_descriptor,
+                                          num_tokens_across_dp,
+                                          num_padded_tokens)
+
+    forward_context = create_forward_context(
+        attn_metadata,
+        vllm_config,
+        virtual_engine,
+        dp_metadata,
+        cudagraph_runtime_mode,
+        batch_descriptor,
+        ubatch_slices,
     )
 
     try:
-        yield
+        with override_forward_context(forward_context):
+            yield
     finally:
         if need_to_track_batchsize:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = num_tokens
+            batchsize = num_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
             from vllm.platforms import current_platform
+
             synchronize = current_platform.synchronize
             if synchronize is not None:
                 synchronize()
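The bit-packing in num_tokens_across_dp_with_max_decode_tokens above merits a worked example: a prefilling rank ORs bit 16 (pad_flag) into its token count before the all-reduce, so every rank can detect "some rank is prefilling" from the reduced maximum alone, then strip the flag to recover the real counts. A standalone illustration with plain ints in place of the CPU tensors:

pad_flag = 1 << 16  # 65536; the assert above keeps token counts below this
pad_mask = pad_flag - 1
counts = [130, 7 | pad_flag]  # rank 0 decoding, rank 1 prefilling
assert max(counts) > pad_flag  # prefill detected somewhere across DP
recovered = [c & pad_mask for c in counts]
assert recovered == [130, 7]  # original token counts restored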
@@ -141,7 +189,5 @@ def _set_forward_context(
                     "(batchsize, count, median_time(ms)): %s"),
                     forward_stats)
 
-    vfc._forward_context = prev_context
-
 
 vfc.set_forward_context = _set_forward_context
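The removals in this last hunk delegate save/restore of the global forward context to vLLM's own override_forward_context context manager, which restores the previous context on exit whether the body returns normally or raises. A toy sketch of that pattern (not vLLM's actual implementation):

from contextlib import contextmanager

_ctx = {"current": None}  # stand-in for vfc._forward_context

@contextmanager
def override(new_ctx):
    # Swap in the new context; the finally block guarantees restoration.
    prev, _ctx["current"] = _ctx["current"], new_ctx
    try:
        yield
    finally:
        _ctx["current"] = prev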
