diff --git a/examples/experimental/data_parallel.py b/examples/experimental/data_parallel.py index a66e51d17..e10b35ace 100644 --- a/examples/experimental/data_parallel.py +++ b/examples/experimental/data_parallel.py @@ -54,13 +54,9 @@ os.environ['VLLM_TORCH_PROFILER_DIR'] = './profile' -hf_overrides_kw = { - "num_hidden_layers": 2, -} - - def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, tp_size, enable_ep, vllm_use_v1): + dp_master_port, tp_size, enable_ep, + max_model_len, block_size, decode_batch, num_hidden_layers): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) # paralle_config.data_parallel_size = envs.sVLLM_DP_SIZE @@ -68,41 +64,18 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) - if not vllm_use_v1: - # in v0 worker, each process has distinct RBLN_DEVICES - rbln_devices = "" - if os.environ.get("VLLM_RBLN_TP_SIZE") is None: - rsd_size = 1 - else: - rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE")) - rsd_tp_size = tp_size * rsd_size - start_index = local_dp_rank * rsd_tp_size - end_index = start_index + rsd_tp_size - for index in range(start_index, end_index): - if rbln_devices: - rbln_devices += "," - rbln_devices += str(index) - - os.environ["RBLN_DEVICES"] = rbln_devices - else: - rbln_devices = os.environ.get("RBLN_DEVICES") - - print(f"local RBLN_DEVICES = {rbln_devices}") - # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the - # engine processes. - # Sample prompts. prompts = [ "Hello, my name is", "The vLLM is", "The president of the United States is", "The future of AI is", - ] + ] * dp_size # with DP, each rank should process different prompts. # usually all the DP ranks process a full dataset, # and each rank processes a different part of the dataset. - prompts_per_rank = (len(prompts) // dp_size) + 1 + prompts_per_rank = (len(prompts) // dp_size) start = global_dp_rank * prompts_per_rank end = start + prompts_per_rank prompts = prompts[start:end] @@ -119,15 +92,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, # sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.0) + if num_hidden_layers == 0: + hf_overrides_kw = None + else: + hf_overrides_kw = { + "num_hidden_layers": num_hidden_layers, + } + # Create an LLM. 
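(Reviewer note, illustrative only and not part of the patch: the hunk above replicates the prompt list dp_size times so that len(prompts) divides evenly across ranks, which is why the old "(len // dp_size) + 1" over-allocation is dropped. A minimal standalone sketch of that per-rank sharding, assuming dp_size evenly divides the replicated list:)

def shard_prompts(prompts: list[str], dp_size: int, dp_rank: int) -> list[str]:
    # Each DP rank takes a contiguous, equally sized slice of the prompt list.
    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    return prompts[start:start + prompts_per_rank]

base = ["Hello, my name is", "The vLLM is"]
replicated = base * 2                      # dp_size = 2
assert shard_prompts(replicated, 2, 0) == base
assert shard_prompts(replicated, 2, 1) == base

(End of note; the patch continues below.)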
llm = LLM( model=model, - #hf_overrides=hf_overrides_kw, - max_model_len=8 * 1024, - block_size=1024, + hf_overrides=hf_overrides_kw, + max_model_len=max_model_len, + block_size=block_size, enable_chunked_prefill=True, max_num_batched_tokens=128, - max_num_seqs=1, + max_num_seqs=decode_batch, trust_remote_code=True, tensor_parallel_size=tp_size, enable_expert_parallel=enable_ep, @@ -166,6 +146,22 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, parser.add_argument('--ep', action='store_true', help="vLLM enable_expert_parallel") + parser.add_argument("--max-model-len", + type=int, + default=8192, + help="Max sequence length") + parser.add_argument("--block-size", + type=int, + default=4096, + help="KV cache block size") + parser.add_argument("--decode-batch", + type=int, + default=1, + help="decode batch size") + parser.add_argument("--num-hidden-layers", + type=int, + default=0, + help="num hidden layers") parser.add_argument("--node-size", type=int, default=1, @@ -189,6 +185,10 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, node_size = args.node_size node_rank = args.node_rank enable_ep = args.ep + max_model_len = args.max_model_len + block_size = args.block_size + decode_batch = args.decode_batch + num_hidden_layers = args.num_hidden_layers if node_size == 1: dp_master_ip = "127.0.0.1" @@ -200,27 +200,6 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, assert dp_size % node_size == 0, "dp_size should be divisible by node_size" dp_per_node = dp_size // node_size - vllm_use_v1 = (int(os.environ.get("VLLM_USE_V1", "0")) == 1) - if vllm_use_v1: - print("VLLM_USE_V1") - # in v1 worker, entire processes SHOULD have global RBLN_DEVICES - rbln_devices = "" - if os.environ.get("VLLM_RBLN_TP_SIZE") is None: - rsd_size = 1 - else: - rsd_size = int(os.environ.get("VLLM_RBLN_TP_SIZE")) - start_index = 0 - end_index = start_index + tp_size * dp_size * rsd_size - for index in range(start_index, end_index): - if rbln_devices: - rbln_devices += "," - rbln_devices += str(index) - - print(f"global RBLN_DEVICES = {rbln_devices}") - os.environ["RBLN_DEVICES"] = rbln_devices - else: - print("VLLM_USE_V0") - from multiprocessing import Process procs = [] for local_dp_rank, global_dp_rank in enumerate( @@ -228,7 +207,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size, enable_ep, vllm_use_v1)) + tp_size, enable_ep, + max_model_len, block_size, decode_batch, num_hidden_layers)) proc.start() procs.append(proc) exit_code = 0 diff --git a/vllm_rbln/__init__.py b/vllm_rbln/__init__.py index 34ef39129..adb844fcc 100644 --- a/vllm_rbln/__init__.py +++ b/vllm_rbln/__init__.py @@ -47,6 +47,7 @@ def register_ops(): import vllm_rbln.forward_context # noqa import vllm_rbln.lora.layer # noqa import vllm_rbln.model_executor.layers.fused_moe.layer # noqa + import vllm_rbln.model_executor.layers.fused_moe.shared_fused_moe # noqa import vllm_rbln.model_executor.layers.logits_processor # noqa import vllm_rbln.model_executor.layers.quantization.kernels.mixed_precision # noqa import vllm_rbln.model_executor.layers.quantization.mxfp4 # noqa diff --git a/vllm_rbln/forward_context.py b/vllm_rbln/forward_context.py index ae02f3702..42bd8715f 100644 --- a/vllm_rbln/forward_context.py +++ b/vllm_rbln/forward_context.py @@ -15,13 +15,17 @@ import time from contextlib import contextmanager from dataclasses import 
dataclass -from typing import Any, Optional +from typing import Any import torch +import torch.distributed as dist import vllm.forward_context as vfc -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.forward_context import (BatchDescriptor, DPMetadata, ForwardContext, - batchsize_logging_interval, track_batchsize) +from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.forward_context import (BatchDescriptor, DPMetadata, + batchsize_logging_interval, + create_forward_context, + override_forward_context, track_batchsize) +from vllm.v1.worker.ubatch_utils import UBatchSlices import vllm_rbln.rbln_envs as envs from vllm_rbln.logger import init_logger @@ -31,57 +35,98 @@ @dataclass class RBLNDPMetadata(DPMetadata): - max_pads_across_dp: int = 0 + max_pads_across_dp: torch.Tensor | None = None + + @staticmethod + def num_tokens_across_dp(num_tokens: int, dp_size: int, + dp_rank: int) -> torch.Tensor: + """ + Gather the num_tokens across all DP ranks and return results in a + CPU tensor of size dp_size. + """ + num_tokens_across_dp = [0] * dp_size + num_tokens_across_dp[dp_rank] = num_tokens + num_tokens_tensor = torch.tensor(num_tokens_across_dp, + device="cpu", + dtype=torch.int32) + from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + return num_tokens_tensor + + @staticmethod + def num_tokens_across_dp_with_max_decode_tokens( + num_tokens: int, dp_size: int, dp_rank: int, + is_prefill: bool) -> tuple[torch.Tensor, int | None]: + pad_flag = 1 << 16 + pad_mask = pad_flag - 1 + assert num_tokens < pad_flag, \ + "num_tokens should be less than pad_flag" + + if is_prefill: + num_tokens |= pad_flag + + tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_across_dp = torch.max(tokens_across_dp_cpu).item() + + if is_prefill or max_across_dp > pad_flag: + mask_tensor = torch.tensor([pad_mask] * dp_size, + device="cpu", + dtype=torch.int32) + num_tokens_across_dp_cpu = tokens_across_dp_cpu & mask_tensor + max_across_dp = None + else: + num_tokens_across_dp_cpu = tokens_across_dp_cpu + + return num_tokens_across_dp_cpu, max_across_dp @staticmethod def make( - vllm_config: VllmConfig, - attn_metadata: Any, + parallel_config: ParallelConfig, num_tokens: int, - num_tokens_across_dp_cpu: torch.Tensor + num_tokens_across_dp: torch.Tensor | None = None, + num_padded_tokens: int | None = None, ) -> "RBLNDPMetadata": - - parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size - dp_rank = parallel_config.data_parallel_rank - - scheduler_config = vllm_config.scheduler_config - max_pad = scheduler_config.max_num_batched_tokens - - if attn_metadata is not None and hasattr(attn_metadata, - "num_prefill_tokens"): - # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens - - disable_dp = dp_size == 1 - use_dummy_prefill = envs.VLLM_RBLN_DP_IMPL == "dummy_prefill" - if (disable_dp or use_dummy_prefill) and \ - attn_metadata.num_decode_tokens > 0: - max_pad = scheduler_config.max_num_seqs - else: - # for v1 attention backends or no attn_metadata - batchsize = num_tokens - # If num_tokens_across_dp is None, it will be computed by all_reduce - # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert num_tokens_across_dp_cpu is not None + if dp_size > 1: + assert num_tokens_across_dp is not None, \ + "num_tokens_across_dp should be applied for DP case" + 
assert num_padded_tokens is not None, \ + "num_padded_tokens should be applied for DP case" + num_tokens_across_dp_cpu = num_tokens_across_dp + max_pad = num_padded_tokens + + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) + max_pads_across_dp = torch.empty(max_pad, device="cpu") + else: + assert num_tokens_across_dp is None, \ + "num_tokens_across_dp should not be applied for non-DP case" + assert num_padded_tokens is None, \ + "num_padded_tokens should not be applied for non-DP case" + num_tokens_across_dp_cpu = torch.tensor([num_tokens], + device="cpu", + dtype=torch.int32) + max_tokens_across_dp_cpu = num_tokens + max_pads_across_dp = None - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) - return RBLNDPMetadata(max_tokens_across_dp_cpu=max_tokens_across_dp_cpu, - num_tokens_across_dp_cpu=num_tokens_across_dp_cpu, - max_pads_across_dp=max_pad) + return RBLNDPMetadata(max_tokens_across_dp_cpu, + num_tokens_across_dp_cpu, + max_pads_across_dp=max_pads_across_dp) @contextmanager def _set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: int | None = None, + num_tokens_across_dp: torch.Tensor | None = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: BatchDescriptor | None = None, + ubatch_slices: UBatchSlices | None = None, + num_padded_tokens: int | None = None, +): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -89,41 +134,38 @@ def _set_forward_context( need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: vfc.forward_start_time = time.perf_counter() - dp_metadata: Optional[DPMetadata] = None + + dp_metadata: DPMetadata | None = None enable_dp = vllm_config.parallel_config.data_parallel_size > 1 use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK if (enable_dp or use_moe_tokens_mask) and (attn_metadata is not None or num_tokens is not None): - dp_metadata = RBLNDPMetadata.make(vllm_config, attn_metadata, + dp_metadata = RBLNDPMetadata.make(vllm_config.parallel_config, num_tokens or 0, - num_tokens_across_dp) - - prev_context = vfc._forward_context - vfc._forward_context = ForwardContext( - no_compile_layers=vllm_config.compilation_config. 
- static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + num_tokens_across_dp, + num_padded_tokens) + + forward_context = create_forward_context( + attn_metadata, + vllm_config, + virtual_engine, + dp_metadata, + cudagraph_runtime_mode, + batch_descriptor, + ubatch_slices, ) try: - yield + with override_forward_context(forward_context): + yield finally: if need_to_track_batchsize: - if hasattr(attn_metadata, "num_prefill_tokens"): - # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens - else: - # for v1 attention backends - batchsize = num_tokens + batchsize = num_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch from vllm.platforms import current_platform + synchronize = current_platform.synchronize if synchronize is not None: synchronize() @@ -147,7 +189,5 @@ def _set_forward_context( "(batchsize, count, median_time(ms)): %s"), forward_stats) - vfc._forward_context = prev_context - vfc.set_forward_context = _set_forward_context diff --git a/vllm_rbln/model_executor/layers/fused_moe/layer.py b/vllm_rbln/model_executor/layers/fused_moe/layer.py index 2c5279f88..0647ae3e9 100644 --- a/vllm_rbln/model_executor/layers/fused_moe/layer.py +++ b/vllm_rbln/model_executor/layers/fused_moe/layer.py @@ -238,13 +238,17 @@ def unquantized_fused_moe_method_rbln( return final_hidden_states.reshape(orig_shape) -def _get_tokens_mask(): - num_tokens = \ +def get_tokens_mask(num_tokens: int, left=1.0, right=float('-inf')): + num_tokens_across_dp = \ get_forward_context().dp_metadata.num_tokens_across_dp_cpu - num_tokens = num_tokens.unsqueeze(1) - max_pad = get_forward_context().dp_metadata.max_pads_across_dp + num_tokens_across_dp = num_tokens_across_dp.unsqueeze(1) + if num_tokens_across_dp.size(0) == 1: + max_pad = num_tokens + else: + max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0] pos = torch.arange(max_pad, dtype=torch.int32).unsqueeze(0) # [1, max_pad] - tokens_mask = torch.where(pos < num_tokens, 1.0, 0.0) # [dp_size, max_pad] + tokens_mask = torch.where(pos < num_tokens_across_dp, left, + right) # [dp_size, max_pad] tokens_mask = tokens_mask.reshape(-1, 1) #[dp_size * max_pad, 1] return tokens_mask @@ -268,7 +272,7 @@ def get_masked_routing_weights(router_logits, top_k, renormalize, expert_map): use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK if use_moe_tokens_mask: - tokens_mask = _get_tokens_mask() + tokens_mask = get_tokens_mask(router_logits.shape[0], 1.0, 0.0) selected_weights = selected_weights * tokens_mask n_expert = router_logits.shape[1] @@ -393,6 +397,11 @@ def unquantized_fused_optimize_moe_method_custom( expert_map_list = expert_map.tolist() expert_map_const = torch.tensor(expert_map_list, dtype=torch.int32) + use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK + if use_moe_tokens_mask: + tokens_mask = get_tokens_mask(num_tokens) + router_logits = router_logits * tokens_mask + # optimum-rbln/src/optimum/rbln/transformers/models/qwen3_moe/ # qwen3_moe_architecture.py final_hidden_states = torch.ops.rbln_custom_ops.custom_moe_glu( @@ -455,7 +464,7 @@ def fused_moe_forward_rbln(self, hidden_states: torch.Tensor, hidden_shape_dp = (-1, 1, org_hidden_shape[-1]) final_hidden_states = all_hidden_states.reshape(hidden_shape_dp) - max_pad = 
get_forward_context().dp_metadata.max_pads_across_dp + max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0] num_tokens = org_hidden_shape[:-1].numel() # noqa: F841 start = self.dp_rank * max_pad end = start + num_tokens @@ -474,7 +483,7 @@ def fused_moe_naive_multicast_rbln(self, x: torch.Tensor): # assert len(x.shape) == 3 x = x.reshape(1, -1, x.size(-1)) - max_pad = get_forward_context().dp_metadata.max_pads_across_dp + max_pad = get_forward_context().dp_metadata.max_pads_across_dp.shape[0] num_tokens = x.size(1) num_repeat = max_pad // num_tokens # TODO: evaluate various padding approaches diff --git a/vllm_rbln/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm_rbln/model_executor/layers/fused_moe/shared_fused_moe.py new file mode 100644 index 000000000..262abad50 --- /dev/null +++ b/vllm_rbln/model_executor/layers/fused_moe/shared_fused_moe.py @@ -0,0 +1,41 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE + +import vllm_rbln.rbln_envs as envs +from vllm_rbln.logger import init_logger + +logger = init_logger(__name__) + + +def __shared_fused_moe_init_rbln( + self, + shared_experts: torch.nn.Module | None, + gate: torch.nn.Module | None = None, + use_overlapped: bool = True, + **kwargs, +): + FusedMoE.__init__(self, **kwargs) + self._shared_experts = shared_experts + + # FIXME(RBLN) - disable use overlapped, not supported + self.use_overlapped = False + + self._gate = gate + + +SharedFusedMoE.__init__ = __shared_fused_moe_init_rbln diff --git a/vllm_rbln/model_executor/layers/quantization/mxfp4.py b/vllm_rbln/model_executor/layers/quantization/mxfp4.py index c0805780a..e6480d1ed 100644 --- a/vllm_rbln/model_executor/layers/quantization/mxfp4.py +++ b/vllm_rbln/model_executor/layers/quantization/mxfp4.py @@ -2,83 +2,234 @@ import torch import vllm.model_executor.layers.quantization.mxfp4 as upstream -from torch import Tensor from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.utils import set_weight_attrs +import vllm_rbln.rbln_envs as envs from vllm_rbln.logger import init_logger +from vllm_rbln.model_executor.layers.fused_moe.layer import get_tokens_mask logger = init_logger(__name__) +def _dequantize_mxfp4(blocks: torch.Tensor, scales: torch.Tensor, + dtype: torch.dtype) -> torch.Tensor: + """ + Args: + blocks: uint8 [..., K // 2] containing packed FP4 values + scales: uint8 [..., K // 32] containing E8M0 scales + dtype: output dtype + + Returns: + Dequantized tensor of shape [..., K] + """ + # fmt: off + FP4_VALUES = [ + +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + ] + # fmt: on + 
lut = torch.tensor(FP4_VALUES, dtype=dtype) + + # Convert E8M0 scales to exponents (subtract bias of 127) + exponents = scales.to(torch.int32) - 127 + + # Unpack FP4 nibbles + idx_lo = (blocks & 0x0F).to(torch.long) + idx_hi = (blocks >> 4).to(torch.long) + + # Look up FP4 values + val_lo = lut[idx_lo] + val_hi = lut[idx_hi] + + # Interleave low and high nibbles + *prefix_shape, B = blocks.shape + out = torch.empty(*prefix_shape, B * 2, dtype=dtype) + out[..., 0::2] = val_lo + out[..., 1::2] = val_hi + + # Apply scales: each scale covers 32 elements + # scales shape: [..., K // 32], out shape: [..., K] + # Expand scales to match output shape + exponents_expanded = exponents.unsqueeze(-1).expand(*exponents.shape, 32) + exponents_expanded = exponents_expanded.reshape(*prefix_shape, -1) + + # ldexp: out * 2^exponents + out = torch.ldexp(out, exponents_expanded[..., :out.shape[-1]]) + + return out + + +def _swigluoai(gate: torch.Tensor, up: torch.Tensor, alpha: float, + limit: float) -> torch.Tensor: + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1) * glu + + # kernel for gpt_oss, with built-in swigluoai activation @torch.library.custom_op( "rbln_custom_ops::custom_moe_glu_mxfp4", mutates_args=(), ) def custom_moe_glu_mxfp4( - hidden_states: Tensor, - gate_proj_blocks: Tensor, - gate_proj_scales: Tensor, - gate_proj_bias: Tensor, - up_proj_blocks: Tensor, - up_proj_scales: Tensor, - up_proj_bias: Tensor, - down_proj_blocks: Tensor, - down_proj_scales: Tensor, - down_proj_bias: Tensor, - router_logits: Tensor, - alpha: Tensor, - limit: Tensor, + hidden_states: torch.Tensor, + gate_proj_blocks: torch.Tensor, + gate_proj_scales: torch.Tensor, + gate_proj_bias: torch.Tensor, + up_proj_blocks: torch.Tensor, + up_proj_scales: torch.Tensor, + up_proj_bias: torch.Tensor, + down_proj_blocks: torch.Tensor, + down_proj_scales: torch.Tensor, + down_proj_bias: torch.Tensor, + router_logits: torch.Tensor, + alpha: torch.Tensor, + limit: torch.Tensor, k: int, post_norm: bool = True, -) -> Tensor: + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ - Customized MoE GLU operation. + MoE GLU operation for GPT-OSS with mxfp4 quantization and swigluoai activation. Expected tensor shapes: - - hidden_states: [batch*seq_len, hidden_size] - - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2] + - hidden_states: [num_tokens, hidden_size] + - gate_proj_blocks: uint8 [num_experts, intermediate_size, hidden_size // 2] - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32] - gate_proj_bias: [num_experts, intermediate_size] - - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2] + - up_proj_blocks: uint8 [num_experts, intermediate_size, hidden_size // 2] - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32] - up_proj_bias: [num_experts, intermediate_size] - - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2] + - down_proj_blocks: uint8 [num_experts, hidden_size, intermediate_size // 2] - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32] - - masked_routing_weight: [batch * seq_len, num_experts] - - expert_select_count: [num_experts] - - alpha: [] - - limit: [] + - down_proj_bias: [num_experts, hidden_size] + - router_logits: [num_tokens, num_experts] + - alpha: [], constant + - limit: [], constant + - expert_map: [num_experts], + Mapping from global expert index to local expert index (in num_experts). 
+ Contains -1 for experts not assigned to the current rank. Returns: - Tensor: [batch * seq_len, hidden_size] + torch.Tensor: [num_tokens, hidden_size] """ - return torch.empty_like(hidden_states) + if envs.VLLM_RBLN_COMPILE_MODEL: + return torch.empty_like(hidden_states) + + # Reference torch native implementation + + num_tokens, hidden_size = hidden_states.shape + num_local_experts = gate_proj_blocks.shape[0] + # num_global_experts = router_logits.shape[1] + dtype = hidden_states.dtype + + alpha_val = alpha.item() + limit_val = limit.item() + + # Compute top-k routing + # router_logits: [num_tokens, num_global_experts] + top_k_values, top_k_indices = torch.topk(router_logits, k, dim=-1) + + # Apply softmax to get routing weights (only over selected experts) + if post_norm: + routing_weights = torch.softmax(top_k_values, dim=-1) + else: + # Pre-norm: softmax over all experts, then select top-k + all_weights = torch.softmax(router_logits, dim=-1) + routing_weights = torch.gather(all_weights, + dim=-1, + index=top_k_indices) + + # Initialize output + output = torch.zeros(num_tokens, hidden_size, dtype=dtype) + + # Dequantize all expert weights once + # gate_proj: [num_local_experts, intermediate_size, hidden_size] + gate_proj_weights = _dequantize_mxfp4(gate_proj_blocks, + gate_proj_scales, + dtype=dtype) + up_proj_weights = _dequantize_mxfp4(up_proj_blocks, + up_proj_scales, + dtype=dtype) + down_proj_weights = _dequantize_mxfp4(down_proj_blocks, + down_proj_scales, + dtype=dtype) + + # Process each local expert + for local_expert_idx in range(num_local_experts): + # Determine which global expert this local expert corresponds to + if expert_map is not None: + # Find global expert index that maps to this local expert + global_expert_idx = (expert_map == local_expert_idx).nonzero( + as_tuple=True)[0] + if len(global_expert_idx) == 0: + continue + global_expert_idx = global_expert_idx[0].item() + else: + global_expert_idx = local_expert_idx + + # Find tokens routed to this expert + # top_k_indices: [num_tokens, k] + expert_mask = (top_k_indices == global_expert_idx) # [num_tokens, k] + token_indices, k_indices = expert_mask.nonzero(as_tuple=True) + + if len(token_indices) == 0: + continue + + # Get routing weights for these tokens + weights = routing_weights[token_indices, + k_indices] # [num_selected_tokens] + + # Get hidden states for selected tokens + selected_hidden = hidden_states[ + token_indices] # [num_selected, hidden_size] + + # Get expert weights + gate_w = gate_proj_weights[local_expert_idx] + gate_b = gate_proj_bias[local_expert_idx] + up_w = up_proj_weights[local_expert_idx] + up_b = up_proj_bias[local_expert_idx] + down_w = down_proj_weights[local_expert_idx] + down_b = down_proj_bias[local_expert_idx] + + # Forward pass through expert MLP + gate = selected_hidden @ gate_w.T + gate_b + up = selected_hidden @ up_w.T + up_b + activated = _swigluoai(gate, up, alpha_val, limit_val) + expert_out = activated @ down_w.T + down_b # [num_selected, hidden_size] + + # Apply routing weights and accumulate + weighted_out = expert_out * weights.unsqueeze(-1) + output.index_add_(0, token_indices, weighted_out.to(dtype)) + + return output @custom_moe_glu_mxfp4.register_fake def custom_moe_glu_mxfp4_fake( - hidden_states: Tensor, - gate_proj_blocks: Tensor, - gate_proj_scales: Tensor, - gate_proj_bias: Tensor, - up_proj_blocks: Tensor, - up_proj_scales: Tensor, - up_proj_bias: Tensor, - down_proj_blocks: Tensor, - down_proj_scales: Tensor, - down_proj_bias: Tensor, - router_logits: Tensor, - 
alpha: Tensor, - limit: Tensor, + hidden_states: torch.Tensor, + gate_proj_blocks: torch.Tensor, + gate_proj_scales: torch.Tensor, + gate_proj_bias: torch.Tensor, + up_proj_blocks: torch.Tensor, + up_proj_scales: torch.Tensor, + up_proj_bias: torch.Tensor, + down_proj_blocks: torch.Tensor, + down_proj_scales: torch.Tensor, + down_proj_bias: torch.Tensor, + router_logits: torch.Tensor, + alpha: torch.Tensor, + limit: torch.Tensor, k: int, post_norm: bool = True, -) -> Tensor: + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -89,6 +240,12 @@ def __init__(self, moe: FusedMoEConfig): self.moe = moe self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + # swigluoai constant value + # gemm1_alpha = 1.702, gemm1_beta = 1.0, gemm1_clamp_limit = 7.0 + # gemm1_alpha = 1.702 + self.swiglu_alpha = torch.tensor(1.702, dtype=torch.float32) + # gemm1_clamp_limit = 7.0 + self.swiglu_limit = torch.tensor(7.0, dtype=torch.float32) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -239,14 +396,32 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - orig_shape = x.shape - x = x.view(-1, self.hidden_size) - router_logits = router_logits.view(-1, self.num_experts) + # refer to custom_moe_glu + orig_shape = x.shape # noqa: F841 + num_tokens = orig_shape[:-1].numel() # noqa: F841 + hidden_states = x.reshape(num_tokens, -1) + router_logits = router_logits.reshape(num_tokens, -1) + # x = x.view(-1, self.hidden_size) + # router_logits = router_logits.view(-1, self.num_experts) + # router_logits = router_logits.view(-1, self.moe.num_experts) if activation == "swigluoai": # TODO: use expert_map - output = torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4( - x, + # FIXME(RBLN) - expert_map SHOULD be processed + expert_map_const = None + if expert_map is not None: + # Extract numpy array and create a fresh constant tensor + expert_map_list = expert_map.tolist() + expert_map_const = torch.tensor(expert_map_list, + dtype=torch.int32) + + use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK + if use_moe_tokens_mask: + tokens_mask = get_tokens_mask(num_tokens) + router_logits = router_logits * tokens_mask + + final_hidden_states = torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4( + hidden_states, layer.gate_proj_blocks, layer.gate_proj_scales, layer.gate_proj_bias, @@ -257,16 +432,16 @@ def apply( layer.down_proj_scales, layer.down_proj_bias, router_logits, - # alpha (hardcoded in vllm as well?) - torch.tensor(1.702, dtype=x.dtype), - # swiglu_limit - torch.tensor(7.0, dtype=x.dtype), - k=top_k, + self.swiglu_alpha, + self.swiglu_limit, + top_k, + renormalize, + expert_map_const, ) else: raise NotImplementedError(activation) - return output.view(orig_shape) + return final_hidden_states.reshape(orig_shape) # We do this because upstream uses Mxfp4MoEMethod for all non-xpu platforms diff --git a/vllm_rbln/models/gpt_oss.py b/vllm_rbln/models/gpt_oss.py index 87e78b715..13a8fe2e5 100644 --- a/vllm_rbln/models/gpt_oss.py +++ b/vllm_rbln/models/gpt_oss.py @@ -1,7 +1,24 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + from collections.abc import Iterable import torch from vllm.distributed import (get_dp_group, get_pcp_group, + tensor_model_parallel_all_reduce, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -213,4 +230,16 @@ def _load_weights_mxfp4_custom( loaded_params.add(name) return loaded_params + + def __gpt_oss_moe_forward_rsd(self, hidden_states: torch.Tensor) -> torch.Tensor: + g = self.router(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=g) + tp_size = get_tensor_model_parallel_world_size() + if tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states + + gpt_oss.GptOssModel._load_weights_mxfp4 = _load_weights_mxfp4_custom + gpt_oss.MLPBlock.forward = __gpt_oss_moe_forward_rsd diff --git a/vllm_rbln/platform.py b/vllm_rbln/platform.py index 34c96d19e..c89f7fe5c 100644 --- a/vllm_rbln/platform.py +++ b/vllm_rbln/platform.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import TYPE_CHECKING, Optional import torch @@ -75,6 +74,14 @@ def get_device_name(cls, device_id: int = 0) -> str: def inference_mode(): return torch.no_grad() + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. 
+ """ + logger.warning("set_device is not supported on RBLN.") + pass + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on RBLN.") @@ -271,26 +278,3 @@ def get_punica_wrapper(cls) -> str: @classmethod def can_update_inplace(cls) -> bool: return False - - @classmethod - def device_id_to_physical_device_id(cls, device_id: int): - # overrides for RSD (rbln scalable devices) - # dp device ids for RBLN SHOULD consider rsd size - rsd_size = envs.VLLM_RBLN_TP_SIZE - assert rsd_size >= 1 - if cls.device_control_env_var in os.environ and os.environ[ - cls.device_control_env_var] != "": - device_ids = os.environ[cls.device_control_env_var].split(",") - physical_device_ids = "" - start_device_id = device_id * rsd_size - for rsd_id in range(rsd_size - 1): - physical_device_ids += str(device_ids[start_device_id + - rsd_id]) - physical_device_ids += "," - physical_device_ids += str(device_ids[start_device_id + rsd_size - - 1]) - logger.info("RBLN DP physical_device_ids = %s", - physical_device_ids) - return physical_device_ids - else: - return device_id diff --git a/vllm_rbln/rbln_envs.py b/vllm_rbln/rbln_envs.py index 0140ebb92..8a18cfc12 100644 --- a/vllm_rbln/rbln_envs.py +++ b/vllm_rbln/rbln_envs.py @@ -24,10 +24,12 @@ VLLM_RBLN_SAMPLER: bool = True VLLM_RBLN_ENABLE_WARM_UP: bool = True VLLM_RBLN_USE_VLLM_MODEL: bool = False + VLLM_RBLN_SPECIALIZE_MOE_DECODE: bool = True VLLM_RBLN_FLASH_CAUSAL_ATTN: bool = True + VLLM_RBLN_BATCH_ATTN_OPT: bool = False VLLM_RBLN_DISABLE_MM: bool = False - VLLM_RBLN_DP_IMPL: str = "dummy_prefill" - VLLM_RBLN_USE_MOE_TOKENS_MASK: bool = False + VLLM_RBLN_DP_IMPL: str = "padded_decode" + VLLM_RBLN_USE_MOE_TOKENS_MASK: bool = True VLLM_RBLN_ENFORCE_MODEL_FP32: bool = False VLLM_RBLN_MOE_CUSTOM_KERNEL: bool = True VLLM_RBLN_MOE_USE_OPT_KERNEL: bool = False @@ -44,8 +46,9 @@ def get_dp_impl(): dp_impl = os.environ.get("VLLM_RBLN_DP_IMPL") if dp_impl is None: - return "dummy_prefill" - # default is dummy_prefill + return "padded_decode" + # default is padded_decode + # dummy_prefill will be deprecated in the future choices = set(["padded_decode", "dummy_prefill"]) current_impl = dp_impl.lower() if current_impl not in choices: @@ -85,6 +88,10 @@ def get_dp_impl(): "VLLM_RBLN_FLASH_CAUSAL_ATTN": (lambda: os.environ.get("VLLM_RBLN_FLASH_CAUSAL_ATTN", "True").lower() in ("true", "1")), + # Use batch attention optimization for paged attention + "VLLM_RBLN_BATCH_ATTN_OPT": + (lambda: os.environ.get("VLLM_RBLN_BATCH_ATTN_OPT", "False").lower() in + ("true", "1")), # Disable multimodal input "VLLM_RBLN_DISABLE_MM": (lambda: os.environ.get("VLLM_RBLN_DISABLE_MM", "False").lower() in @@ -93,8 +100,12 @@ def get_dp_impl(): "VLLM_RBLN_DP_IMPL": get_dp_impl, # If true, it uses the tokens mask applied to moe expert kernel - "VLLM_RBLN_USE_MOE_TOKENS_MASK": (lambda: os.environ.get( - "VLLM_RBLN_USE_MOE_TOKENS_MASK", "False").lower() in ("true", "1")), + "VLLM_RBLN_USE_MOE_TOKENS_MASK": + (lambda: os.environ.get("VLLM_RBLN_USE_MOE_TOKENS_MASK", "True").lower() in + ("true", "1")), + # If true, it specializes the cases where all instances are at decode stage + "VLLM_RBLN_SPECIALIZE_MOE_DECODE": (lambda: os.environ.get( + "VLLM_RBLN_SPECIALIZE_MOE_DECODE", "True").lower() in ("true", "1")), # enforce model data type into fp32 not model_config.dtype "VLLM_RBLN_ENFORCE_MODEL_FP32": (lambda: os.environ.get("VLLM_RBLN_ENFORCE_MODEL_FP32", "False").lower() in @@ -134,7 +145,7 @@ def get_dp_impl(): lambda: 
int(os.environ.get("VLLM_RBLN_DECODE_BATCH_BUCKET_STEP", 2)), # Decode batch bucket limit "VLLM_RBLN_DECODE_BATCH_BUCKET_LIMIT": - lambda: int(os.environ.get("VLLM_RBLN_DECODE_BATCH_BUCKET_LIMIT", 32)), + lambda: int(os.environ.get("VLLM_RBLN_DECODE_BATCH_BUCKET_LIMIT", 1)), } diff --git a/vllm_rbln/v1/attention/backends/flash_attention.py b/vllm_rbln/v1/attention/backends/flash_attention.py index f2ed8f68f..6e4050388 100644 --- a/vllm_rbln/v1/attention/backends/flash_attention.py +++ b/vllm_rbln/v1/attention/backends/flash_attention.py @@ -51,6 +51,28 @@ def flash_attention_naive_prefill_impl( slot_mapping: torch.Tensor, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """ + Expected tensor shapes: + - q: [batch, n_kv_heads, n_groups, seq_len, head_dim] + Query states for multiple tokens + - k: [batch, n_kv_heads, 1, seq_len, head_dim] + Key states for current input + - v: [batch, n_kv_heads, 1, seq_len, head_dim] + Value states for current input + - kv_cache: [2, num_blocks, n_kv_heads, 1, partition_size, head_dim] + Key and value cache + - mask: [batch, 1, 1, seq_len, max_seq_len] + - seq_idx: [batch, num_partitions] + number of already cached tokens in each partition + - block_tables: [num_partitions,] for prefill, + [batch, num_partitions] for decode + - sinks: [n_heads, sink_len] (optional) + + Returns: + Tensor: attn_output: [batch, n_kv_heads, n_groups, seq_len, head_dim] + + batch size is assumed to be 1 for prefill. + """ if not envs.VLLM_RBLN_COMPILE_MODEL: # attn_weights = MM(q,kt) * scale # attn_weights = add(attn_weights + mask) @@ -174,39 +196,155 @@ def flash_causal_attention_naive_prefill_impl( slot_mapping: torch.Tensor, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if not envs.VLLM_RBLN_COMPILE_MODEL: - # attn_weights = MM(q,kt) * scale - # attn_weights = causal masked softmax(attn_weights) - # MM(attn_weights, v) - seq_len = q.size(-2) - s = seq_idx[0][0] - e = s + seq_len - # NOTE: this reference impl works only for single partition - block = block_tables[0].to(torch.int32) - k_state = kv_cache[0][block].unsqueeze(0).slice_scatter(k, - dim=3, - start=s, - end=e) - v_state = kv_cache[1][block].unsqueeze(0).slice_scatter(v, - dim=3, - start=s, - end=e) - kv_cache[0][block] = k_state.squeeze(0) - kv_cache[1][block] = v_state.squeeze(0) - attn_weights = torch.matmul(q, k_state.transpose(3, 4)) * scale - block_size = kv_cache.size(-2) - causal_mask = torch.triu(torch.ones(1, 1, 1, block_size, block_size), - diagonal=1) - causal_mask = causal_mask[:, :, :, s:e, :] - causal_mask = torch.where(causal_mask > 0, float('-inf'), - 0.0).to(attn_weights.dtype) - attn_weights = attn_weights + causal_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, v_state) - return attn_output - else: + """ + Expected tensor shapes: + - q: [batch, n_kv_heads, n_groups, seq_len, head_dim] + Query states for multiple tokens + - k: [batch, n_kv_heads, 1, seq_len, head_dim] + Key states for current input + - v: [batch, n_kv_heads, 1, seq_len, head_dim] + Value states for current input + - kv_cache: [2, num_blocks, n_kv_heads, 1, partition_size, head_dim] + Key and value cache + - seq_idx: [batch, num_partitions] + number of already cached tokens in each partition + - block_tables: [num_partitions,] for prefill, + [batch, num_partitions] for decode + - sinks: [n_heads, sink_len] (optional) + + Returns: + Tensor: attn_output: [batch, n_kv_heads, n_groups, seq_len, head_dim] + + batch size is assumed to be 1 for 
prefill. + """ + if envs.VLLM_RBLN_COMPILE_MODEL: return torch.empty_like(q) + # This is just reference to test vllm-rbln independently of the actual RBLN + # custom op implementation, so it implements simple non-flash attention. + + batch_size, n_kv_heads, n_groups, seq_len, head_dim = q.shape + partition_size = kv_cache.size(-2) + num_partitions = block_tables.shape[0] + + # Calculate the starting position (number of tokens already in cache) + # seq_idx contains tokens per partition that are already cached + cache_start_pos = int(seq_idx[0].sum().item()) + total_seq_len = cache_start_pos + seq_len + + # Step 1: Write KV cache + # We need to write seq_len new tokens starting at cache_start_pos + for p in range(num_partitions): + block_idx = block_tables[p].to(torch.int32) + + # Calculate how many tokens to write to this partition + partition_start = p * partition_size + partition_end = (p + 1) * partition_size + + # Tokens we're writing go from cache_start_pos to + # cache_start_pos + seq_len + write_start = max(cache_start_pos, partition_start) + write_end = min(cache_start_pos + seq_len, partition_end) + + if write_start >= write_end: + continue + + num_tokens_to_write = write_end - write_start + offset_in_partition = write_start - partition_start + offset_in_input = write_start - cache_start_pos + + k_slice = k[:, :, :, + offset_in_input:offset_in_input + num_tokens_to_write, :] + v_slice = v[:, :, :, + offset_in_input:offset_in_input + num_tokens_to_write, :] + + kv_cache[0, block_idx, :, :, offset_in_partition:offset_in_partition + + num_tokens_to_write, :] = k_slice.squeeze(0) + kv_cache[1, block_idx, :, :, offset_in_partition:offset_in_partition + + num_tokens_to_write, :] = v_slice.squeeze(0) + + # Step 2: Gather KV cache for the entire sequence + k_gathered = torch.zeros(batch_size, + n_kv_heads, + 1, + total_seq_len, + head_dim, + dtype=k.dtype, + device=k.device) + v_gathered = torch.zeros(batch_size, + n_kv_heads, + 1, + total_seq_len, + head_dim, + dtype=v.dtype, + device=v.device) + + gathered_pos = 0 + for p in range(num_partitions): + block_idx = block_tables[p].to(torch.int32) + + # Calculate how many tokens are in this partition after writing + partition_start = p * partition_size + tokens_in_partition = min(total_seq_len - partition_start, + partition_size) + + if tokens_in_partition <= 0: + break + + k_gathered[:, :, :, gathered_pos:gathered_pos + + tokens_in_partition, :] = kv_cache[ + 0, block_idx, :, :, :tokens_in_partition, :] + v_gathered[:, :, :, gathered_pos:gathered_pos + + tokens_in_partition, :] = kv_cache[ + 1, block_idx, :, :, :tokens_in_partition, :] + gathered_pos += tokens_in_partition + + # Step 3: Compute causal attention (with sinks, if any) + # attn_weights: [batch, n_kv_heads, n_groups, seq_len, total_seq_len] + attn_weights = torch.matmul(q, k_gathered.transpose(3, 4)) * scale + + # Create causal mask + # Query positions are from cache_start_pos to cache_start_pos + seq_len + query_positions = torch.arange(cache_start_pos, + cache_start_pos + seq_len, + device=q.device) + key_positions = torch.arange(total_seq_len, device=q.device) + + # Causal mask: query can only attend to keys at positions <= query position + causal_mask = query_positions.unsqueeze(1) >= key_positions.unsqueeze(0) + + # Convert to attention mask format + causal_mask = torch.where(causal_mask, 0.0, + float('-inf')).to(attn_weights.dtype) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + attn_weights = attn_weights + causal_mask + + # Apply attention sinks if 
provided + # sinks shape: [n_heads, sink_len] -> + # expand to [batch, n_kv_heads, n_groups, seq_len, sink_len] + if sinks is not None: + # sinks: [n_heads, sink_len] where n_heads = n_kv_heads * n_groups + sink_len = sinks.size(-1) + # Reshape sinks to match attention weight dimensions + sinks_expanded = sinks.view(n_kv_heads, n_groups, 1, sink_len) + sinks_expanded = sinks_expanded.expand(batch_size, n_kv_heads, + n_groups, seq_len, sink_len) + # Concatenate sink logits to attention weights + combined_logits = torch.cat([attn_weights, sinks_expanded], dim=-1) + # Stabilize softmax by subtracting max + combined_logits = combined_logits - combined_logits.max( + dim=-1, keepdim=True).values + probs = torch.nn.functional.softmax(combined_logits, dim=-1) + # Drop the sink probabilities before matmul with values + attn_weights = probs[..., :-sink_len] + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_output = torch.matmul(attn_weights, v_gathered) + + return attn_output + @torch.library.register_fake( "rbln_custom_ops::flash_causal_attention_naive_prefill") @@ -238,38 +376,124 @@ def flash_causal_attention_naive_decode_impl( slot_mapping: torch.Tensor, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if not envs.VLLM_RBLN_COMPILE_MODEL: - # NOTE: this reference impl works only for batch_size=1 - assert q.size(0) == 1 - seq_len = q.size(-2) - s = seq_idx[0][0] - e = s + seq_len - # NOTE: this reference impl works only for single partition - block = block_tables[0][0].to(torch.int32) - k_state = kv_cache[0][block].unsqueeze(0).slice_scatter(k, - dim=3, - start=s, - end=e) - v_state = kv_cache[1][block].unsqueeze(0).slice_scatter(v, - dim=3, - start=s, - end=e) - kv_cache[0][block] = k_state.squeeze(0) - kv_cache[1][block] = v_state.squeeze(0) - attn_weights = torch.matmul(q, k_state.transpose(3, 4)) * scale - block_size = kv_cache.size(-2) - causal_mask = torch.triu(torch.ones(1, 1, 1, block_size, block_size), - diagonal=1) - causal_mask = causal_mask[:, :, :, s:e, :] - causal_mask = torch.where(causal_mask > 0, float('-inf'), - 0.0).to(attn_weights.dtype) - attn_weights = attn_weights + causal_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, v_state) - return attn_output - else: + if envs.VLLM_RBLN_COMPILE_MODEL: return torch.empty_like(q) + # This is just reference to test vllm-rbln independently of the actual RBLN + # custom op implementation, so it implements simple non-flash attention. 
+ + batch_size, n_kv_heads, n_groups, seq_len, head_dim = q.shape + partition_size = kv_cache.size(-2) + num_partitions = block_tables.shape[1] + + outputs = [] + for b in range(batch_size): + # Calculate the starting position (number of tokens already in cache) + cache_start_pos = int(seq_idx[b].sum().item()) + total_seq_len = cache_start_pos + seq_len # seq_len is 1 for decode + + if total_seq_len == 0: + outputs.append(torch.zeros_like(q[b:b + 1])) + continue + + # Step 1: Write KV cache + # Find which partition to write to + for p in range(num_partitions): + block_idx = block_tables[b, p].to(torch.int32) + + partition_start = p * partition_size + partition_end = (p + 1) * partition_size + + write_start = max(cache_start_pos, partition_start) + write_end = min(cache_start_pos + seq_len, partition_end) + + if write_start >= write_end: + continue + + num_tokens_to_write = write_end - write_start + offset_in_partition = write_start - partition_start + offset_in_input = write_start - cache_start_pos + + k_slice = k[b:b + 1, :, :, offset_in_input:offset_in_input + + num_tokens_to_write, :] + v_slice = v[b:b + 1, :, :, offset_in_input:offset_in_input + + num_tokens_to_write, :] + + kv_cache[0, block_idx, :, :, + offset_in_partition:offset_in_partition + + num_tokens_to_write, :] = k_slice.squeeze(0) + kv_cache[1, block_idx, :, :, + offset_in_partition:offset_in_partition + + num_tokens_to_write, :] = v_slice.squeeze(0) + + # Step 2: Gather KV cache for the entire sequence + k_gathered = torch.zeros(1, + n_kv_heads, + 1, + total_seq_len, + head_dim, + dtype=k.dtype, + device=k.device) + v_gathered = torch.zeros(1, + n_kv_heads, + 1, + total_seq_len, + head_dim, + dtype=v.dtype, + device=v.device) + + gathered_pos = 0 + for p in range(num_partitions): + block_idx = block_tables[b, p].to(torch.int32) + + partition_start = p * partition_size + tokens_in_partition = min(total_seq_len - partition_start, + partition_size) + + if tokens_in_partition <= 0: + break + + k_gathered[:, :, :, gathered_pos:gathered_pos + + tokens_in_partition, :] = kv_cache[ + 0, block_idx, :, :, :tokens_in_partition, :] + v_gathered[:, :, :, gathered_pos:gathered_pos + + tokens_in_partition, :] = kv_cache[ + 1, block_idx, :, :, :tokens_in_partition, :] + gathered_pos += tokens_in_partition + + # Step 3: Compute causal attention (with sinks, if any) + q_b = q[b:b + 1] + attn_weights = torch.matmul(q_b, k_gathered.transpose(3, 4)) * scale + + # For decode, query is at position total_seq_len - 1 + # It can attend to all previous positions (0 to total_seq_len - 1) + # So no causal masking needed for decode + + # Apply attention sinks if provided + # sinks shape: [n_heads, sink_len] -> + # expand to [1, n_kv_heads, n_groups, seq_len, sink_len] + if sinks is not None: + sink_len = sinks.size(-1) + # Reshape sinks to match attention weight dimensions + sinks_expanded = sinks.view(n_kv_heads, n_groups, 1, sink_len) + sinks_expanded = sinks_expanded.expand(1, n_kv_heads, n_groups, + seq_len, sink_len) + # Concatenate sink logits to attention weights + combined_logits = torch.cat([attn_weights, sinks_expanded], dim=-1) + # Stabilize softmax by subtracting max + combined_logits = combined_logits - combined_logits.max( + dim=-1, keepdim=True).values + probs = torch.nn.functional.softmax(combined_logits, dim=-1) + # Drop the sink probabilities before matmul with values + attn_weights = probs[..., :-sink_len] + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_output = torch.matmul(attn_weights, 
v_gathered) + outputs.append(attn_output) + + return torch.cat(outputs, dim=0) + @torch.library.register_fake( "rbln_custom_ops::flash_causal_attention_naive_decode") @@ -304,13 +528,13 @@ def sliding_window_attention_naive_prefill_impl( ) -> torch.Tensor: """ Expected tensor shapes: - - q: [batch, n_heads, n_groups, seq_len, head_dim] + - q: [batch, n_kv_heads, n_groups, seq_len, head_dim] Query states for multiple tokens - - k: [batch, n_heads, 1, seq_len, head_dim] + - k: [batch, n_kv_heads, 1, seq_len, head_dim] Key states for current input - - v: [batch, n_heads, 1, seq_len, head_dim] + - v: [batch, n_kv_heads, 1, seq_len, head_dim] Value states for current input - - kv_cache: [2, num_blocks, n_heads, 1, window_size, head_dim] + - kv_cache: [2, num_blocks, n_kv_heads, 1, window_size, head_dim] Key and value cache - cache_seq_len: [batch, 1] number of tokens already cached @@ -318,9 +542,10 @@ def sliding_window_attention_naive_prefill_impl( ending position after insertion (cache_seq_len + query_len) - scale: []. Attention scale factor - block_tables: [batch] for prefill, [batch, 1] for decode + - sinks: [n_heads, sink_len] (optional) Returns: - Tensor: attn_output: [batch, n_heads, n_groups, seq_len, head_dim] + Tensor: attn_output: [batch, n_kv_heads, n_groups, seq_len, head_dim] batch size is assumed to be 1 for prefill. """ @@ -373,7 +598,21 @@ def sliding_window_attention_naive_prefill_impl( mask = torch.where(mask > 0, 0.0, float('-inf')).to(attn_weights.dtype) attn_weights = attn_weights + mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + if sinks is not None: + sink_len = sinks.size(-1) + n_kv_heads = q.size(1) + n_groups = q.size(2) + sinks_expanded = sinks.view(n_kv_heads, n_groups, 1, sink_len) + sinks_expanded = sinks_expanded.expand(1, n_kv_heads, n_groups, + seq_len, sink_len) + combined_logits = torch.cat([attn_weights, sinks_expanded], dim=-1) + combined_logits = combined_logits - combined_logits.max( + dim=-1, keepdim=True).values + probs = torch.nn.functional.softmax(combined_logits, dim=-1) + attn_weights = probs[..., :-sink_len] + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) attn_output = torch.matmul(attn_weights, v_cache_curr) @@ -475,9 +714,25 @@ def sliding_window_attention_naive_decode_impl( mask = torch.where(mask > 0, 0.0, float('-inf')).to(attn_weights.dtype) attn_weights = attn_weights + mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + if sinks is not None: + sink_len = sinks.size(-1) + n_kv_heads = q.size(1) + n_groups = q.size(2) + sinks_expanded = sinks.view(n_kv_heads, n_groups, 1, sink_len) + sinks_expanded = sinks_expanded.expand(1, n_kv_heads, n_groups, 1, + sink_len) + combined_logits = torch.cat([attn_weights, sinks_expanded], dim=-1) + combined_logits = combined_logits - combined_logits.max( + dim=-1, keepdim=True).values + probs = torch.nn.functional.softmax(combined_logits, dim=-1) + attn_weights = probs[..., :-sink_len] + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, v_cache_curr) outputs.append(attn_output) + return torch.cat(outputs, dim=0) @@ -633,6 +888,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], get_current_vllm_config().model_config.enforce_eager) self.is_causal = envs.VLLM_RBLN_FLASH_CAUSAL_ATTN + self.is_batch_attention_opt = envs.VLLM_RBLN_BATCH_ATTN_OPT def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ 
-645,6 +901,7 @@ def build( fast_build: bool = False, num_tokens=None, positions=None, + batch_pad=None, ) -> RBLNFlashAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -680,6 +937,8 @@ def build( assert num_tokens is not None, ( "num_tokens is required for RBLN Attention Backend") + assert batch_pad is not None, ( + "batch_pad is required for RBLN Attention Backend") is_prefills = ( common_attn_metadata.num_computed_tokens_cpu[:num_reqs].numpy() < num_tokens[:num_reqs] - 1) @@ -718,13 +977,13 @@ def build( attn_masks = attn_masks.to(self.device) else: # batch padding - seq_lens_tensor = rbln_utils.pad( - seq_lens_tensor, 0, self.scheduler_config.max_num_seqs) - block_tables_tensor = rbln_utils.pad( - block_tables_tensor, 0, self.scheduler_config.max_num_seqs) + seq_idx = rbln_utils.pad(seq_idx, 0, batch_pad) + seq_lens_tensor = rbln_utils.pad(seq_lens_tensor, 0, batch_pad) + block_tables_tensor = rbln_utils.pad(block_tables_tensor, 0, + batch_pad) if not self.is_causal: decode_attention_mask = torch.zeros( - self.scheduler_config.max_num_seqs, + batch_pad, 1, 1, 1, @@ -751,18 +1010,19 @@ def build( max=sliding_window) cache_offsets = cache_seq_lens + query_lens if not is_prefills[0]: - cache_seq_lens = rbln_utils.pad( - cache_seq_lens, 0, self.scheduler_config.max_num_seqs) - cache_offsets = rbln_utils.pad( - cache_offsets, 0, self.scheduler_config.max_num_seqs) + cache_seq_lens = rbln_utils.pad(cache_seq_lens, 0, batch_pad) + cache_offsets = rbln_utils.pad(cache_offsets, 0, batch_pad) local_block_tables = block_tables_tensor[..., :1] + # seq_idx(batch attention opt decode) - [B, 1], for each batch, have sequence offset + # seq_lens_tensor(otherwise) - [B, P], have dynamic size for each partition attn_metadata = RBLNFlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, max_seq_len=query_max_seq_len, - seq_lens=seq_lens_tensor.to(self.device), + seq_lens=seq_lens_tensor.to(self.device) + if not self.is_batch_attention_opt or is_prefills[0] else seq_idx.to(self.device), block_tables=block_tables_tensor.to(self.device), slot_mapping=slot_mapping, use_cascade=False, @@ -859,6 +1119,9 @@ def __init__( if len(self.sinks.size()) == 1: self.sinks = self.sinks[:, None] + self.is_causal = envs.VLLM_RBLN_FLASH_CAUSAL_ATTN + self.is_batch_attention_opt = envs.VLLM_RBLN_BATCH_ATTN_OPT + def forward( self, layer: torch.nn.Module, @@ -973,7 +1236,9 @@ def forward( self.sinks, ) # actually non-flash paged attention DOES NOT use slot_mapping - elif envs.VLLM_RBLN_FLASH_CAUSAL_ATTN: + elif self.is_causal: + # batched attention - seq_lens[B, 1] == seq_idx, original sequence index + # otherwise - seq_lens[B, P] == dyn_size_for_partitions, dynamic size for each partition if q_len == 1: attn_output = torch.ops.rbln_custom_ops.flash_causal_attention_naive_decode( # noqa: E501 query, diff --git a/vllm_rbln/v1/worker/rbln_model_runner.py b/vllm_rbln/v1/worker/rbln_model_runner.py index a9bc2d871..41417e60b 100644 --- a/vllm_rbln/v1/worker/rbln_model_runner.py +++ b/vllm_rbln/v1/worker/rbln_model_runner.py @@ -90,6 +90,7 @@ import vllm_rbln.rbln_envs as envs import vllm_rbln.utils as rbln_utils +from vllm_rbln.forward_context import RBLNDPMetadata from vllm_rbln.logger import init_logger from vllm_rbln.lora.inputs import LoRAInputs from vllm_rbln.lora.mask import LoRAMask @@ -167,6 +168,15 @@ class ExecuteModelState(NamedTuple): kv_connector_output: KVConnectorOutput | None 
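(Reviewer note, illustrative only and not part of the patch: in the attention metadata builder above, per-request tensors are now padded to the decode batch bucket passed in as batch_pad rather than to scheduler_config.max_num_seqs. A minimal sketch of that padding step, assuming a helper that right-pads dimension 0 with a fill value; the actual rbln_utils.pad signature may differ:)

import torch

def pad_dim0(x: torch.Tensor, fill: int, target: int) -> torch.Tensor:
    # Right-pad the first dimension of `x` with `fill` until it has `target` rows.
    if x.size(0) >= target:
        return x
    pad_shape = (target - x.size(0), *x.shape[1:])
    return torch.cat([x, x.new_full(pad_shape, fill)], dim=0)

# e.g. 3 live decode requests padded up to a decode batch bucket of 4
seq_lens = torch.tensor([17, 5, 92], dtype=torch.int32)
block_tables = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
assert pad_dim0(seq_lens, 0, 4).tolist() == [17, 5, 92, 0]
assert pad_dim0(block_tables, 0, 4).shape == (4, 2)

(End of note; the patch continues below.)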
+class DummyRunState(NamedTuple): + """Input state for dummy run.""" + + attn_metadata: dict[int, dict[str, Any]] + num_input_tokens: int + input_ids: dict[int, torch.Tensor] + positions: dict[int, torch.Tensor] + + class RBLNModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( @@ -383,7 +393,8 @@ def __init__( (3, self.max_num_tokens + 1), dtype=torch.int64) # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.prefill_intermediate_tensors: Optional[IntermediateTensors] = None + self.decode_intermediate_tensors: dict[int, IntermediateTensors] = {} # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context @@ -457,6 +468,13 @@ def __init__( self.bucketing_manager.decode_batch_buckets) self.performance_tracker = None + + self.dummy_run_state: DummyRunState | None = None + + self.specialized_moe_decode = parallel_config.data_parallel_size > 1 \ + and envs.VLLM_RBLN_SPECIALIZE_MOE_DECODE + + def _enable_performance_tracker(self): if envs.VLLM_RBLN_METRICS: self.performance_tracker = PerformanceTracker() self.performance_tracker.register_cleanup() @@ -931,6 +949,7 @@ def _prepare_inputs( self, scheduler_output: SchedulerOutput, num_scheduled_tokens: np.ndarray, + num_padded_tokens: int | None = None, ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], np.ndarray, Optional[CommonAttentionMetadata], int]: """ @@ -1077,6 +1096,12 @@ def _prepare_inputs( self.num_accepted_tokens.copy_to_gpu() max_num_scheduled_tokens = int(num_scheduled_tokens.max()) + batch_bucket_size = \ + self.bucketing_manager.find_decode_batch_bucket(num_reqs) + + (batch_bucket_size, num_padded_tokens, num_tokens_across_dp) = \ + self.get_dp_padding(total_num_scheduled_tokens, batch_bucket_size, + num_padded_tokens, bool(self.is_prefills()[0])) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1158,6 +1183,7 @@ def _prepare_inputs( extra_attn_metadata_args["num_tokens"] = \ self.input_batch.num_tokens_no_spec extra_attn_metadata_args["positions"] = self.positions.cpu + extra_attn_metadata_args["batch_pad"] = batch_bucket_size attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, @@ -1176,7 +1202,8 @@ def _prepare_inputs( return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + max_num_scheduled_tokens, batch_bucket_size, num_padded_tokens, + num_tokens_across_dp) def _compile_model(self, model): TP = get_tp_group() @@ -1385,28 +1412,47 @@ def sync_and_slice_intermediate_tensors( for k, v in intermediate_tensors.items() }) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding( + self, + num_tokens: int, + batch_bucket_size: int, + num_padded_tokens: int | None = None, + is_prefill: bool = False + ) -> tuple[int, Optional[int], Optional[torch.Tensor]]: dp_size = self.vllm_config.parallel_config.data_parallel_size - # dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. 
- # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - max_tokens_across_dp_cpu = num_tokens - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + if dp_size == 1: + assert num_padded_tokens is None, \ + "num_padded_tokens should not be applied for non-DP case" + return batch_bucket_size, num_padded_tokens, None + + if num_padded_tokens is not None: + assert self.specialized_moe_decode, \ + "num_padded_tokens is only supported when " \ + "specialized MOE decode is enabled" + assert num_padded_tokens == self.max_num_batched_tokens, \ + "num_padded_tokens should be equal to max_num_batched_tokens" + assert not is_prefill, \ + "num_padded_tokens is only supported for decode stage" + num_tokens_across_dp_cpu = RBLNDPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + return (batch_bucket_size, num_padded_tokens, + num_tokens_across_dp_cpu) + + num_tokens_across_dp_cpu, max_decode_tokens = \ + RBLNDPMetadata.num_tokens_across_dp_with_max_decode_tokens( + num_tokens, dp_size, dp_rank, is_prefill) + + any_prefill = max_decode_tokens is None + if any_prefill or not self.specialized_moe_decode: + num_padded_tokens = self.max_num_batched_tokens + else: + batch_bucket_size = self.bucketing_manager.find_decode_batch_bucket( + max_decode_tokens) + num_padded_tokens = batch_bucket_size + + return batch_bucket_size, num_padded_tokens, num_tokens_across_dp_cpu def _pool( self, @@ -1453,7 +1499,6 @@ def _pool( def _preprocess( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[int, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: @@ -1470,8 +1515,7 @@ def _preprocess( num_input_tokens = num_scheduled_tokens # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad + # NOTE(RBLN): RBLN handles DP padding in _prepare_inputs # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1512,12 +1556,6 @@ def _preprocess( else: positions = self.positions.gpu[:num_input_tokens] - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - -1, -1, intermediate_tensors, True) - if (self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs): encoder_inputs = self._extract_encoder_inputs(scheduler_output) @@ -1525,11 +1563,9 @@ def _preprocess( return ( num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, - intermediate_tensors, model_kwargs, ) @@ -1584,14 +1620,14 @@ def warm_up_model(self) -> None: task=self.get_supported_pooling_tasks()[0]) if self.is_pooling_model else None, ) - self._execute_dummy_requests(dummy_prefill_requests, - dummy_prefill_num_scheduled_tokens, - num_kv_cache_groups) + so, cso = self._make_dummy_scheduler_outputs( + dummy_prefill_requests, dummy_prefill_num_scheduled_tokens, + num_kv_cache_groups) + self._execute_dummy_requests(so, cso, + self.prefill_intermediate_tensors) # compile 
decode graph for batch_bucket_size in self.bucketing_manager.decode_batch_buckets: - logger.info("Warm up decode graph with batch_size: %d", - batch_bucket_size) decode_max_seq_len = self.max_model_len dummy_decode_requests = [] @@ -1615,10 +1651,21 @@ def warm_up_model(self) -> None: if self.is_pooling_model else None, num_speculative_tokens=num_speculative_tokens, ) - - self._execute_dummy_requests(dummy_decode_requests, - dummy_decode_num_scheduled_tokens, - num_kv_cache_groups) + so, cso = self._make_dummy_scheduler_outputs( + dummy_decode_requests, dummy_decode_num_scheduled_tokens, + num_kv_cache_groups) + current_intermediate_tensors = \ + self.decode_intermediate_tensors.get(batch_bucket_size) + assert current_intermediate_tensors is not None + + if self.specialized_moe_decode: + self._execute_dummy_requests( + so, + cso, + current_intermediate_tensors, + num_padded_tokens=self.max_num_batched_tokens) + + self._execute_dummy_requests(so, cso, current_intermediate_tensors) def _add_dummy_requests( self, @@ -1630,6 +1677,7 @@ def _add_dummy_requests( sampling_params: Optional[SamplingParams] = None, pooling_params: Optional[PoolingParams] = None, num_speculative_tokens: int = 0, + block_id: int = 0, ) -> None: num_blocks = round_up( total_tokens, @@ -1642,7 +1690,7 @@ def _add_dummy_requests( mm_features=[], sampling_params=sampling_params, pooling_params=pooling_params, - block_ids=([0] * num_blocks, ) * num_kv_cache_groups, + block_ids=([block_id] * num_blocks, ) * num_kv_cache_groups, num_computed_tokens=num_computed_tokens, lora_request=None, ) @@ -1651,9 +1699,10 @@ def _add_dummy_requests( if total_tokens - num_computed_tokens == 0 \ else total_tokens - num_computed_tokens - def _execute_dummy_requests(self, requests: list[NewRequestData], - num_scheduled_tokens: dict[str, int], - num_kv_cache_groups: int) -> None: + def _make_dummy_scheduler_outputs( + self, requests: list[NewRequestData], + num_scheduled_tokens: dict[str, int], num_kv_cache_groups: int + ) -> tuple[SchedulerOutput, SchedulerOutput]: sched_output = SchedulerOutput( scheduled_new_reqs=requests, scheduled_cached_reqs=CachedRequestData.make_empty(), @@ -1679,40 +1728,395 @@ def _execute_dummy_requests(self, requests: list[NewRequestData], free_encoder_mm_hashes=[], kv_connector_metadata=None, ) + return sched_output, cleanup_sched_output + + def _execute_dummy_requests( + self, + sched_output: SchedulerOutput, + cleanup_sched_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors, + num_padded_tokens: int | None = None, + ) -> None: if get_pp_group().is_first_rank: intermediate_tensors = None - else: - # make RBLN decode dummy intermediate tensors - # FIXME - based on assumption, multiple batch decode - batch_size = len(requests) - assert batch_size >= 1 - num_computed_tokens = requests[0].num_computed_tokens - if num_computed_tokens == 0: - # FIXME(RBLN) - # prefill - single batch prefill - assert batch_size == 1 - prompt_token_ids = requests[0].prompt_token_ids - seq_len = len(prompt_token_ids) - else: - # decode - single token prompt - seq_len = 1 - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=batch_size * seq_len, - dtype=self.model_config.dtype, - device=self.device)) - - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - batch_size, seq_len, None, False) - - output = self.execute_model(sched_output, intermediate_tensors) + + output = self.execute_model(sched_output, intermediate_tensors, + 
num_padded_tokens) if output is None: self.sample_tokens(None) - output = self.execute_model(cleanup_sched_output, intermediate_tensors) + output = self.execute_model(cleanup_sched_output, intermediate_tensors, + num_padded_tokens) if output is None: self.sample_tokens(None) - self.intermediate_tensors = None + + def _update_dummy_states(self, scheduler_output: SchedulerOutput, + input_batch: InputBatch) -> None: + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + generator = None + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_features=new_req_data.mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + reqs_to_add.append(req_state) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + input_batch.add_request(request) + + # Refresh batch metadata with any pending updates. + input_batch.refresh_metadata() + + def _prepare_dummy_inputs( + self, + scheduler_output: SchedulerOutput, + input_batch: InputBatch, + ) -> DummyRunState: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_reqs = input_batch.num_reqs + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + # Get positions. + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add(input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + positions_np = positions_np.copy() + positions = self.positions.cpu.clone() + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * input_batch.token_ids_cpu.shape[1]) + + input_ids = self.input_ids.cpu.clone() + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=input_ids[:total_num_scheduled_tokens]) + + input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + query_start_loc_np = self.query_start_loc.np.copy() + query_start_loc_np[0] = 0 + query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1]) + query_start_loc = torch.tensor(query_start_loc_np, + dtype=torch.int32)[:num_reqs + 1] + + seq_lens_np = self.seq_lens.np.copy() + seq_lens_np[:num_reqs] = ( + input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + # Fill unused with 0 for full cuda graph mode. + seq_lens_np[num_reqs:].fill(0) + seq_lens = torch.tensor(seq_lens_np, dtype=torch.int32)[:num_reqs] + max_seq_len = seq_lens_np[:num_reqs].max().item() + + # TODO: support spec_decode + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if use_spec_decode: + raise NotImplementedError( + "Spec decode is not supported for DP dummy run") + + logits_indices = query_start_loc[1:] - 1 + + logits_indices_padded = None + + # Used in the below loop. + query_start_loc_cpu = query_start_loc + seq_lens_cpu = seq_lens + num_computed_tokens_cpu = ( + input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + + attn_metadata_bucket: dict[int, dict[str, Any]] = {} + input_ids_bucket: dict[int, torch.Tensor] = {} + positions_bucket: dict[int, torch.Tensor] = {} + + for batch_bucket_size in self.bucketing_manager.decode_batch_buckets: + attn_metadata: dict[str, Any] = {} + # Prepare the attention metadata for each KV cache group and + # make layers in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + encoder_seq_lens = None + + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + raise NotImplementedError( + "Encoder-only attention is not supported for " + "DP dummy run") + else: + blk_table = input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor(num_reqs) + slot_mapping = \ + blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] + + # Fill unused with -1. + # Needed for reshape_and_cache in full cuda graph mode. + # `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[total_num_scheduled_tokens: + total_num_scheduled_tokens].fill_(-1) + blk_table_tensor[ + num_reqs:total_num_scheduled_tokens].fill_(-1) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), + causal=True, + encoder_seq_lens=encoder_seq_lens, + ) + + for attn_group in self.attn_groups[kv_cache_group_id]: + # Prepare for cascade attention if enabled & beneficial. 
+ common_prefix_len = 0 + builder = attn_group.get_metadata_builder() + if self.cascade_attn_enabled: + raise NotImplementedError( + "Cascade attention is not supported " + "for DP dummy run") + + extra_attn_metadata_args = {} + + if isinstance(builder, RBLNFlashAttentionMetadataBuilder): + extra_attn_metadata_args["num_tokens"] = \ + input_batch.num_tokens + extra_attn_metadata_args["positions"] = positions + extra_attn_metadata_args[ + "batch_pad"] = batch_bucket_size + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + for attn_metadatum in attn_metadata.values(): + attn_metadatum.kv_caches = self.kv_caches + num_input_tokens = total_num_scheduled_tokens + input_ids = input_ids[:num_input_tokens] + positions = positions[:num_input_tokens] + + input_ids = input_ids.view(num_reqs, -1).to(torch.long) + positions = positions.view(num_reqs, -1) + + # decode batch padding + input_ids = rbln_utils.pad(input_ids, 0, batch_bucket_size) + positions = rbln_utils.pad(positions, -2, batch_bucket_size) + + attn_metadata_bucket[batch_bucket_size] = attn_metadata + input_ids_bucket[batch_bucket_size] = input_ids + positions_bucket[batch_bucket_size] = positions + + return DummyRunState(attn_metadata=attn_metadata_bucket, + num_input_tokens=num_input_tokens, + input_ids=input_ids_bucket, + positions=positions_bucket) + + def _prepare_dummy_input_batch(self) -> InputBatch: + logits_processors = self.model_config.logits_processors + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = ( + tuple(logits_processors) if logits_processors is not None else ()) + dummy_input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + max_model_len=max(self.max_model_len, self.max_encoder_len), + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, + self.device, + self.pin_memory, + self.is_pooling_model, + custom_logitsprocs, + ), + # We currently don't know whether a particular custom logits + # processor uses output token ids so we set this conservatively. + logitsprocs_need_output_token_ids=bool(custom_logitsprocs), + is_pooling_model=self.is_pooling_model, + cp_kv_cache_interleave_size=self.parallel_config. + cp_kv_cache_interleave_size, + ) + + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in self.kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec) + ] + kernel_block_sizes = self.kernel_block_sizes + + if block_sizes != [ + self.cache_config.block_size + ] or kernel_block_sizes != [self.cache_config.block_size]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. 
See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + dummy_input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=max(self.max_model_len, self.max_encoder_len), + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=dummy_input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), + ) + + return dummy_input_batch + + @torch.inference_mode() + def prepare_dummy_run(self) -> None: + # TODO: support spec_decode, pooling, mrope, lora, + # and encoder-only attention + if self.is_pooling_model: + raise NotImplementedError( + "Pooling model is not supported for DP dummy run") + if self.uses_mrope: + raise NotImplementedError( + "M-RoPE is not supported for DP dummy run") + if self.lora_config: + raise NotImplementedError("LoRA is not supported for DP dummy run") + + num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups) + dummy_run_requests = [] + dummy_run_num_scheduled_tokens = {} + self._add_dummy_requests( + requests=dummy_run_requests, + num_scheduled_tokens=dummy_run_num_scheduled_tokens, + total_tokens=1, + num_computed_tokens=1, + num_kv_cache_groups=num_kv_cache_groups, + sampling_params=None if self.is_pooling_model else SamplingParams( + temperature=0.0), + pooling_params=PoolingParams( + task=self.get_supported_pooling_tasks()[0]) + if self.is_pooling_model else None, + block_id=self.cache_config.num_gpu_blocks - 1, + ) + dummy_run_scheduler_output, _ = self._make_dummy_scheduler_outputs( + dummy_run_requests, dummy_run_num_scheduled_tokens, + num_kv_cache_groups) + + dummy_input_batch = self._prepare_dummy_input_batch() + self._update_dummy_states(dummy_run_scheduler_output, + dummy_input_batch) + self.dummy_run_state = self._prepare_dummy_inputs( + dummy_run_scheduler_output, dummy_input_batch) + + @torch.inference_mode() + def dummy_run(self) -> None: + (attn_metadata, num_input_tokens, input_ids, + positions) = self.dummy_run_state + + (batch_bucket_size, num_padded_tokens, + num_tokens_across_dp) = self.get_dp_padding( + num_input_tokens, self.bucketing_manager.decode_batch_buckets[0]) + + attn_metadata = attn_metadata.get(batch_bucket_size) + input_ids = input_ids.get(batch_bucket_size) + positions = positions.get(batch_bucket_size) + assert attn_metadata is not None \ + and input_ids is not None \ + and positions is not None, \ + "attn_metadata, input_ids, and positions should be defined" \ + f" for batch_bucket_size: {batch_bucket_size}" + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = \ + self.decode_intermediate_tensors.get(batch_bucket_size) + assert intermediate_tensors is not None + + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + num_padded_tokens=num_padded_tokens, + ): + token_indices = None + inputs_embeds = None + model_kwargs = dict[str, Any]({}) + + _ = self.model_executable( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + selected_token_indices=token_indices, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) def _bookkeeping_sync( 
self, scheduler_output: "SchedulerOutput", @@ -1852,6 +2256,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, + num_padded_tokens: int | None = None, ) -> Union[ModelRunnerOutput, IntermediateTensors, None]: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " @@ -1883,18 +2288,17 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = self._prepare_inputs(scheduler_output, - num_scheduled_tokens_np) + max_query_len, batch_bucket_size, num_padded_tokens, + num_tokens_across_dp) = self._prepare_inputs( + scheduler_output, num_scheduled_tokens_np, num_padded_tokens) ( num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, - intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors) + ) = self._preprocess(scheduler_output) # Padding for speculative decoding # in case of that all requests are not scheduled equally. @@ -1920,6 +2324,7 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, + num_padded_tokens=num_padded_tokens, ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output): @@ -1931,6 +2336,7 @@ def execute_model( # we must resolve the batch dimension. input_ids = input_ids.view(num_reqs, -1).to(torch.long) positions = positions.view(num_reqs, -1) + is_prefills = self.is_prefills() token_indices = None @@ -1951,9 +2357,6 @@ def execute_model( input_ids = rbln_utils.pad(input_ids, -1, prefill_size) positions = rbln_utils.pad(positions, -1, prefill_size) else: - batch_bucket_size = \ - self.bucketing_manager.find_decode_batch_bucket( - self.input_batch.num_reqs) # decode batch padding input_ids = rbln_utils.pad(input_ids, 0, batch_bucket_size) positions = rbln_utils.pad(positions, -2, batch_bucket_size) @@ -1983,7 +2386,7 @@ def execute_model( sampler_indices_padded = create_sampler_indices_padded( lora_ids, self.lora_manager._adapter_manager.lora_index_to_id, - self.max_num_seqs, + batch_bucket_size, is_prefills[0], self.lora_config.max_loras, self.device, @@ -2530,6 +2933,61 @@ def model_wrapper( compiled_graph = self._compile_model(model_wrapper) self.model_executable = compiled_graph + distributed_executor_backend = \ + self.vllm_config.parallel_config.distributed_executor_backend + if distributed_executor_backend == "ray": + self._prepare_prefill_intermediate_tensors() + for batch_bucket_size \ + in self.bucketing_manager.decode_batch_buckets: + self._prepare_decode_intermediate_tensors(batch_bucket_size) + else: + with torch.inference_mode(): + self._prepare_prefill_intermediate_tensors() + for batch_bucket_size \ + in self.bucketing_manager.decode_batch_buckets: + self._prepare_decode_intermediate_tensors( + batch_bucket_size) + + def _prepare_prefill_intermediate_tensors(self) -> None: + + def _reshape( + batch_size: int, seq_len: int, + intermediate_tensors: IntermediateTensors + ) -> IntermediateTensors: + return IntermediateTensors({ + k: v.view(batch_size, seq_len, -1) + for k, v in intermediate_tensors.items() + }) + + batch_size = self.max_prefill_batch_size + seq_len = self.max_num_batched_tokens + self.prefill_intermediate_tensors = _reshape( + batch_size, seq_len, + self.model.make_empty_intermediate_tensors( + batch_size=batch_size 
* seq_len, + dtype=self.model_config.dtype, + device=self.device)) + + def _prepare_decode_intermediate_tensors(self, batch_bucket_size) -> None: + + def _reshape( + batch_size: int, seq_len: int, + intermediate_tensors: IntermediateTensors + ) -> IntermediateTensors: + return IntermediateTensors({ + k: v.view(batch_size, seq_len, -1) + for k, v in intermediate_tensors.items() + }) + + batch_size = batch_bucket_size + seq_len = 1 + self.decode_intermediate_tensors[batch_bucket_size] = _reshape( + batch_size, seq_len, + self.model.make_empty_intermediate_tensors( + batch_size=batch_size * seq_len, + dtype=self.model_config.dtype, + device=self.device)) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", @@ -3189,6 +3647,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # we will return kernel_block_size 64 and split the 256-token-block to # 4 blocks with 64 tokens each. kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + self.kernel_block_sizes = kernel_block_sizes # create metadata builders self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) diff --git a/vllm_rbln/v1/worker/rbln_worker.py b/vllm_rbln/v1/worker/rbln_worker.py index 7a80fd745..ed64fe8d3 100644 --- a/vllm_rbln/v1/worker/rbln_worker.py +++ b/vllm_rbln/v1/worker/rbln_worker.py @@ -30,7 +30,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec, FullAttentionSpec from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) from vllm.v1.utils import report_usage_stats @@ -67,11 +67,10 @@ def __init__( ) self.device = torch.device(current_platform.device_type) - if self.parallel_config.distributed_executor_backend == "ray": - logger.info( - "Running on Ray backend. 
Skipping device env var setup.") - else: - self._init_device_env() + self.local_world_size = (self.parallel_config.world_size // + envs.VLLM_RBLN_NUM_RAY_NODES) + + self._init_device_env() if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -125,28 +124,34 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_cpu_blocks = num_cpu_blocks def _init_device_env(self) -> None: - world_size = self.parallel_config.world_size + world_size = self.local_world_size env_var = current_platform.device_control_env_var total_device_count = world_size * envs.VLLM_RBLN_TP_SIZE if env_var not in os.environ: - device_ids = [str(i) for i in range(total_device_count)] + dev_begin = total_device_count * \ + self.parallel_config.data_parallel_rank + dev_end = dev_begin + total_device_count + device_ids = [str(i) for i in range(dev_begin, dev_end)] + start_idx = self.local_rank * envs.VLLM_RBLN_TP_SIZE + end_idx = start_idx + envs.VLLM_RBLN_TP_SIZE + selected_devices = ",".join(device_ids[start_idx:end_idx]) else: device_ids = os.environ[env_var].split(",") - - # This check is only valid for single node mp backends, invalid for ray - # ex) node#0 : RBLN_DEVICES=0,1 - # node#1 : RBLN_DEVICES=2,3 - distributed_backend = self.parallel_config.distributed_executor_backend - if distributed_backend == "mp" and len( - device_ids) < total_device_count: - raise RuntimeError(f"{env_var} has devices {device_ids}" - f" but required {total_device_count}") - - start_idx = self.local_rank * envs.VLLM_RBLN_TP_SIZE - end_idx = start_idx + envs.VLLM_RBLN_TP_SIZE - selected_devices = ",".join(device_ids[start_idx:end_idx]) + assert len(device_ids) == world_size, \ + f"device_ids: {device_ids} " \ + f"should have device count: {world_size}" + try: + device_id = int(device_ids[self.local_rank]) + start_idx = device_id * envs.VLLM_RBLN_TP_SIZE + end_idx = start_idx + envs.VLLM_RBLN_TP_SIZE + device_ids = [str(i) for i in range(start_idx, end_idx)] + selected_devices = ",".join(device_ids) + except ValueError as e: + raise ValueError( + f"device_ids: {device_ids} should be a list of integers") \ + from e os.environ[env_var] = selected_devices logger.info( @@ -177,6 +182,52 @@ def load_model(self): @torch.inference_mode() def determine_available_memory(self) -> int: + params_dict = dict(self.model_runner.model.named_parameters()) + n_model_attentions = 0 + n_model_experts = 0 + device_name = current_platform.get_device_name().lower() + assert "rbln" in device_name + if "ca" in device_name: + # consider RSD size for ATOM + num_runtimes = 2 * envs.VLLM_RBLN_TP_SIZE + elif "cr" in device_name: + # single device == Quad chiplet + num_runtimes = 2 * 4 + else: + assert False, "invalid RBLN architecture, candidates = [ATOM(ca), REBEL(cr)]" + + if self.model_config.quantization is not None: + # FIXME(RBLN) - for now, mxfp4 quantization is only supported + assert self.model_config.quantization == "mxfp4" + if "ca" in device_name: + # ATOM DOES NOT support mxfp4 quantization, handled by bf16 + nbits_per_param = 16 + # mlp weight scale is merged into params + # FIXME(RBLN) - expert scale merged into expert weight param + # ratio scale vs weight = 1 : 16 + ratio = 16 / 17 + elif "cr" in device_name: + # REBEL can support mxfp4 quantization + nbits_per_param = 4 + ratio = 1 + else: + assert False, "invalid RBLN architecture, candidates = [ATOM(ca), REBEL(cr)]" + + # pack 2 mxfp4 elems into single uint8 elem + packed_num_elems = 8 // 4 + else: + nbits_per_param = 16 + 
packed_num_elems = 1 + ratio = 1 + for key, value in params_dict.items(): + if value.dtype == torch.bfloat16: + n_model_attentions += value.numel() + else: + # quantized params is handled + n_model_experts += value.numel() * packed_num_elems * ratio + + # NOTE - model parallel(tp, dp, ep, pp) already applied into model params + n_model_params = n_model_attentions + n_model_experts block_size = self.cache_config.block_size # This function comes from optimum-rbln. @@ -186,11 +237,21 @@ def determine_available_memory(self) -> int: parallel_config=self.parallel_config, kvcache_block_size=block_size, # quantization : 4 (This is an ad-hoc value. Need to fix it) - nbits_per_param=16 if not self.model_config.quantization else 4, - n_model_params=sum(p.numel() - for p in self.model_runner.model.parameters()), - # 2 : 1 for prefill and decode each - num_runtimes=2) + nbits_per_param=nbits_per_param, + n_model_params=n_model_params, + num_runtimes=num_runtimes) + + # NOTE - adjust max_num_blocks considering swa block sharing + # max_num_blocks - based on FullAttentionSpec for model + # SHOULD adjust num blocks considering non full attent + kv_cache_spec = self.model_runner.get_kv_cache_spec() + page_size = max(spec.page_size_bytes + for spec in kv_cache_spec.values()) + num_layers = len(kv_cache_spec) + num_attn_layers = 0 + for spec in kv_cache_spec.values(): + num_attn_layers += int(isinstance(spec, FullAttentionSpec)) + max_num_blocks = max_num_blocks * num_layers / num_attn_layers # for partition skip, we need dummy block slot. no_dummy_slots = 1 @@ -200,17 +261,16 @@ def determine_available_memory(self) -> int: num_gpu_blocks = min( int(max_num_blocks * self.cache_config.gpu_memory_utilization), max_required_num_blocks) + logger.info("max_num_blocks(%s), required_num_blocks(%s), num_blocks(%s)", + max_num_blocks, max_required_num_blocks, num_gpu_blocks) if npu_num_blocks := os.environ.get("VLLM_RBLN_NPU_NUM_BLOCKS"): num_gpu_blocks = int(npu_num_blocks) - kv_cache_spec = self.model_runner.get_kv_cache_spec() - num_layers = len(kv_cache_spec) - # TODO: Consider SWA hybrid models. - # Sync get_maximum_num_blocks with latest optimum-rbln. 
- page_size = max(spec.page_size_bytes - for spec in kv_cache_spec.values()) - return num_gpu_blocks * page_size * num_layers + # NOTE - consider SWA hybrid models + # SWA shares blocks with Full Attention, DO NOT count SWA layers + available_memory = num_gpu_blocks * page_size * num_attn_layers + return available_memory def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() @@ -220,12 +280,27 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: + if self.parallel_config.data_parallel_size > 1: + if envs.VLLM_RBLN_DP_IMPL == "padded_decode": + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # TODO: consider relaxing this constraint + assert max_num_batched_tokens % max_num_seqs == 0, \ + "max_num_batched_tokens must be divisible by max_num_seqs" + elif envs.VLLM_RBLN_DP_IMPL == "dummy_prefill": + raise ValueError("dummy_prefill is not supported in v1 worker" \ + "and will be deprecated in the future") + self.model_runner.prepare_dummy_run() + if (self.model_config.enforce_eager or not envs.VLLM_RBLN_COMPILE_MODEL or not envs.VLLM_RBLN_ENABLE_WARM_UP): logger.warning("skipping compile_or_warm_up_model") return self.model_runner.warm_up_model() + # after completing model warm up, enable RBLN performance tracker + self.model_runner._enable_performance_tracker() def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -293,7 +368,7 @@ def profile(self, is_start: bool = True): sort_by="self_cuda_time_total")) def execute_dummy_batch(self) -> None: - return + self.model_runner.dummy_run() def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) @@ -330,7 +405,7 @@ def init_worker_distributed_environment( world_size = parallel_config.world_size # Set envs for RCCL - os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['LOCAL_RANK'] = str(rank) os.environ['WORLD_SIZE'] = str(world_size) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) @@ -339,7 +414,7 @@ def init_worker_distributed_environment( world_size_across_dp = parallel_config.world_size_across_dp dp_rank = parallel_config.data_parallel_rank rank_across_dp = dp_rank * world_size - rank_across_dp += local_rank + rank_across_dp += rank logger.info("world_size_across_dp = %s, rank_across_dp = %s", world_size_across_dp, rank_across_dp) # consider across_dp diff --git a/vllm_rbln/worker/utils.py b/vllm_rbln/worker/utils.py index 446948501..775f545af 100644 --- a/vllm_rbln/worker/utils.py +++ b/vllm_rbln/worker/utils.py @@ -17,6 +17,7 @@ from typing import Optional from vllm.config import ModelConfig, ParallelConfig +from vllm.platforms import current_platform import vllm_rbln.rbln_envs as envs @@ -31,7 +32,7 @@ def get_maximum_num_blocks( buffer: Optional[int] = None, num_runtimes: int = 2, ) -> int: - # We are finding max_n_blocks(x) that satisfies the following equation: + # We are finding max_num_blocks(x) that satisfies the following equation: # available_dram - kernel_size - buffer # - num_layers * 2 * tensor_parallel_size @@ -72,17 +73,39 @@ def align_2MB(x: int) -> int: vocab_size = model_config.get_vocab_size() hidden_size = model_config.get_hidden_size() num_key_value_heads = model_config.get_num_kv_heads(parallel_config) - tensor_parallel_size = parallel_config.tensor_parallel_size * \ - 
envs.VLLM_RBLN_TP_SIZE + tp_size = parallel_config.tensor_parallel_size + dp_size = parallel_config.data_parallel_size + ep = parallel_config.enable_expert_parallel + rsd_size = envs.VLLM_RBLN_TP_SIZE # TODO(jongho): Update if target npu is REBEL. - ATOM_DRAM_NBYTES = 16 * 2**30 - ATOM_SYS_DRAM_NBYTES = 288 * 2**20 - available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - - ATOM_SYS_DRAM_NBYTES) - def check_oom(available_dram: int) -> None: - if available_dram <= 0: + device_name = current_platform.get_device_name().lower() + assert "rbln" in device_name + if "ca" in device_name: + # ATOM - RBLN-CA[xxx] + # ATOM DRAM - 16GB (single chip) + ATOM_DRAM_NBYTES = 16 * 2**30 + ATOM_SYS_DRAM_NBYTES = 288 * 2**20 + available_dram_bytes = rsd_size * (ATOM_DRAM_NBYTES - + ATOM_SYS_DRAM_NBYTES) + # ATOM - basic data type fp16 + default_bits_per_param = 16 + elif "cr" in device_name: + assert rsd_size == 1 + # REBEL - RBLN-CR[xxx] + # REBEL DRAM - 144GB (quad chips, chiplet) - system(4G) = 140GB + REBEL_DRAM_NBYTES = 144 * 2**30 + REBEL_SYS_DRAM_NBYTES = 4 * 2**30 + REBEL_DRAM_NBYTES -= REBEL_SYS_DRAM_NBYTES + available_dram_bytes = REBEL_DRAM_NBYTES + # FIXME(RBLN) - basic data type fp8 for REBEL, for now fp16 + default_bits_per_param = 16 + else: + assert False, "invalid RBLN architecture, candidates = [ATOM(ca), REBEL(cr)]" + + def check_oom(available_dram_bytes: int) -> None: + if available_dram_bytes <= 0: raise MemoryError("Insufficient DRAM during block calculation.") if kernel_size is None: @@ -90,35 +113,39 @@ def check_oom(available_dram: int) -> None: raise ValueError("`n_model_params` should be specified \ to estimate the kernel memory.") # Get estimated kernel size (approximated) + # kernel_size + # - QKV params - model parallel (tp) sharded + # - MLP or expert - model parallel (ep) sharded + # - word embedding- non sharded, not included into device, hidden_size * vocab_size + # - lm head - model parallel (tp) sharded, hidden_size * vocab_size lm_heads_params = align(vocab_size, 64) * hidden_size lm_heads_nbytes = (align_2MB( - lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * - tensor_parallel_size) - params = n_model_params - lm_heads_params - layer_nbytes = (align_2MB(params * nbits_per_param // 8 / num_layers / - tensor_parallel_size) * num_layers * - tensor_parallel_size) + lm_heads_params * default_bits_per_param // 8 / tp_size) * tp_size) + word_embedding_params = lm_heads_params + params = n_model_params - lm_heads_params - word_embedding_params + layer_nbytes = (align_2MB(params * nbits_per_param // 8 / num_layers) * num_layers) kernel_size = layer_nbytes + lm_heads_nbytes elif n_model_params is not None: raise ValueError( "Both `n_model_params` and `kernel_size` cannot be specified.") - available_dram -= kernel_size + available_dram_bytes -= kernel_size if buffer is None: # TODO: Accurate buffer estimation buffer_per_runtime_per_core = 2**28 # 256MB per runtime # 1 for prefill, 1 for decoder - buffer_per_core = buffer_per_runtime_per_core * num_runtimes - buffer = buffer_per_core * tensor_parallel_size - available_dram -= buffer - - check_oom(available_dram) - - b = kvcache_block_size * align(head_dim, 64) * math.ceil( - num_key_value_heads / tensor_parallel_size) * 2 - c = num_layers * 2 * tensor_parallel_size - k = available_dram / c - max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21)) - - return max_n_blocks + buffer = buffer_per_runtime_per_core * num_runtimes + available_dram_bytes -= buffer + + check_oom(available_dram_bytes) + + kv = 2 + 
kv_bytes = 2
+    num_kv_heads = math.ceil(num_key_value_heads / rsd_size) * rsd_size
+    head_dim = align(head_dim, 64)
+    # [2(=kv), H(=num_kv_heads), 1, B(=block_size), D(=head_dim)]
+    kv_cache_block_bytes = (kv * kvcache_block_size * head_dim *
+                            num_kv_heads * kv_bytes * num_layers)
+    # Each block holds both K and V for every layer, so the block budget is
+    # the remaining DRAM divided by this per-block footprint.
+    max_num_blocks = available_dram_bytes / kv_cache_block_bytes
+    return max_num_blocks
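Editor's note: to make the block-budget arithmetic above concrete, here is a hedged worked example. The model-shape numbers and the kernel/buffer reservation are hypothetical placeholders; only the ATOM DRAM constants (16 GiB per chip minus the 288 MiB system reserve) and the per-block formula come from this patch.

import math

def align(x: int, n: int) -> int:
    # Round x up to the next multiple of n.
    return math.ceil(x / n) * n

rsd_size = 1                                   # envs.VLLM_RBLN_TP_SIZE, assumed 1
available_dram = rsd_size * (16 * 2**30 - 288 * 2**20)
available_dram -= 4 * 2**30                    # hypothetical kernel + runtime buffers

num_layers = 32                                # hypothetical model shape
num_key_value_heads = 8
head_dim = 128
block_size = 4096                              # kvcache_block_size in tokens

num_kv_heads = math.ceil(num_key_value_heads / rsd_size) * rsd_size
kv_cache_block_bytes = (2                      # K and V tensors
                        * block_size
                        * align(head_dim, 64)
                        * num_kv_heads
                        * 2                    # bytes per fp16 element
                        * num_layers)
max_num_blocks = available_dram // kv_cache_block_bytes  # floor to whole blocks
print(kv_cache_block_bytes, max_num_blocks)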