diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 3b44d580db94..2b94362a808f 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -320,6 +320,9 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: model_inputs = { "input_ids": input_buffers.input_ids[:num_tokens], "positions": input_buffers.positions[:num_tokens], + # TODO: Pass intermediate_tensors for PP CUDA graph + # support (https://github.com/vllm-project/vllm/pull/35162). + "intermediate_tensors": None, **model_state.prepare_dummy_inputs(num_reqs, num_tokens), } model_output = model(**model_inputs) diff --git a/vllm/v1/worker/gpu/mm/mrope_utils.py b/vllm/v1/worker/gpu/mm/mrope_utils.py deleted file mode 100644 index 7e27f28bab93..000000000000 --- a/vllm/v1/worker/gpu/mm/mrope_utils.py +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.model_executor.models.interfaces import SupportsMRoPE -from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor - - -class MRopeState: - def __init__( - self, - max_num_reqs: int, - max_num_tokens: int, - max_model_len: int, - device: torch.device, - ): - self.max_num_reqs = max_num_reqs - self.max_num_tokens = max_num_tokens - self.max_model_len = max_model_len - self.device = device - - # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) - # wasting a lot of CPU memory. - self.prefill_mrope_positions = StagedWriteTensor( - (max_num_reqs * 3, max_model_len), - dtype=torch.int32, - device=device, - uva_instead_of_gpu=True, - ) - self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) - - # NOTE: `mrope_positions` is implemented with one additional dummy - # position on purpose to make it non-contiguous so that it can work - # with torch compile. - # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 - # NOTE: When M-RoPE is enabled, position ids are 3D regardless of - # the modality of inputs. For text-only inputs, each dimension has - # identical position IDs, making M-RoPE functionally equivalent to - # 1D-RoPE. - # See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros( - (3, max_num_tokens + 1), dtype=torch.int64, device=device - ) - - def init_prefill_mrope_positions( - self, - req_idx: int, - mrope_model: SupportsMRoPE, - prefill_token_ids: list[int], - mm_features: list, - ) -> None: - prefill_mrope_positions, prefill_mrope_delta = ( - mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features) - ) - for i in range(3): - pos = prefill_mrope_positions[i].tolist() - self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos) - self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta - - def apply_staged_writes(self) -> None: - self.prefill_mrope_positions.apply_write() - self.prefill_mrope_delta.copy_to_uva() - - def prepare_mrope_positions( - self, - idx_mapping: torch.Tensor, - query_start_loc: torch.Tensor, - prefill_lens: torch.Tensor, - num_computed_tokens: torch.Tensor, - ) -> None: - num_reqs = idx_mapping.shape[0] - _prepare_mrope_positions_kernel[(num_reqs,)]( - self.mrope_positions, - self.mrope_positions.stride(0), - self.prefill_mrope_positions.gpu, - 3 * self.max_model_len, - self.max_model_len, - self.prefill_mrope_delta.gpu, - idx_mapping, - query_start_loc, - prefill_lens, - num_computed_tokens, - BLOCK_SIZE=1024, - ) - - -@triton.jit -def _prepare_mrope_positions_kernel( - mrope_positions_ptr, - mrope_positions_stride, - prefill_mrope_positions_ptr, - prefill_mrope_positions_stride0, - prefill_mrope_positions_stride1, - prefill_mrope_delta_ptr, - idx_mapping_ptr, - query_start_loc_ptr, - prefill_lens_ptr, - num_computed_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - batch_idx = tl.program_id(0) - req_state_idx = tl.load(idx_mapping_ptr + batch_idx) - - prefill_len = tl.load(prefill_lens_ptr + req_state_idx) - num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) - is_prefill = num_computed < prefill_len - - query_start = tl.load(query_start_loc_ptr + batch_idx) - query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - query_len = query_end - query_start - - mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx) - for i in range(0, query_len, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - mask = block < query_len - orig_pos = num_computed + block - - for j in tl.static_range(3): - if is_prefill: - # Read from pre-computed M-RoPE positions. - pos = tl.load( - prefill_mrope_positions_ptr - + req_state_idx * prefill_mrope_positions_stride0 - + j * prefill_mrope_positions_stride1 - + orig_pos, - mask=mask, - ) - else: - # Apply M-RoPE delta. - pos = orig_pos + mrope_delta - tl.store( - mrope_positions_ptr + j * mrope_positions_stride + query_start + block, - pos, - mask=mask, - ) diff --git a/vllm/v1/worker/gpu/mm/rope.py b/vllm/v1/worker/gpu/mm/rope.py new file mode 100644 index 000000000000..712f58af578f --- /dev/null +++ b/vllm/v1/worker/gpu/mm/rope.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + +import torch +import torch.nn as nn + +from vllm.config import ModelConfig +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor + + +class RopeState: + """Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). + + M-RoPE: 3 dims, uses position delta for decode. + XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims). + + NOTE: `positions` is implemented with one additional dummy position on + purpose to make it non-contiguous so that it can work with torch compile. + See detailed explanation in + https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + NOTE: When M-RoPE is enabled, position ids are 3D regardless of the + modality of inputs. For text-only inputs, each dimension has identical + position IDs, making M-RoPE functionally equivalent to 1D-RoPE. + See page 5 of https://arxiv.org/abs/2409.12191 + """ + + def __init__( + self, + num_dims: int, + has_delta: bool, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, + ): + self.num_dims = num_dims + self.has_delta = has_delta + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.max_model_len = max_model_len + self.device = device + + # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) + # wasting a lot of CPU memory. + self.prefill_positions = StagedWriteTensor( + (max_num_reqs * num_dims, max_model_len), + dtype=torch.int32, + device=device, + uva_instead_of_gpu=True, + ) + self.positions = torch.zeros( + (num_dims, max_num_tokens + 1), dtype=torch.int64, device=device + ) + + # Delta is non-zero for M-RoPE, always 0 for XD-RoPE. + self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) + + def init_prefill_positions( + self, + req_idx: int, + model: nn.Module, + prefill_token_ids: list[int], + mm_features: list, + ) -> None: + if self.has_delta: + mrope_model = cast(SupportsMRoPE, model) + prefill_positions, delta = mrope_model.get_mrope_input_positions( + prefill_token_ids, mm_features + ) + self.prefill_delta.np[req_idx] = delta + else: + xdrope_model = cast(SupportsXDRoPE, model) + prefill_positions = xdrope_model.get_xdrope_input_positions( + prefill_token_ids, mm_features + ) + + for i in range(self.num_dims): + pos = prefill_positions[i].tolist() + self.prefill_positions.stage_write(self.num_dims * req_idx + i, 0, pos) + + def apply_staged_writes(self) -> None: + self.prefill_positions.apply_write() + if self.has_delta: + self.prefill_delta.copy_to_uva() + + def get_positions(self, num_tokens: int) -> torch.Tensor: + return self.positions[:, :num_tokens] + + def prepare_positions( + self, + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + prefill_lens: torch.Tensor, + num_computed_tokens: torch.Tensor, + ) -> None: + num_reqs = idx_mapping.shape[0] + _prepare_rope_positions_kernel[(num_reqs,)]( + self.positions, + self.positions.stride(0), + self.prefill_positions.gpu, + self.num_dims * self.max_model_len, + self.max_model_len, + self.prefill_delta.gpu, + idx_mapping, + query_start_loc, + prefill_lens, + num_computed_tokens, + BLOCK_SIZE=1024, + NUM_DIMS=self.num_dims, + ) + + +def get_rope_state( + model_config: ModelConfig, + model: nn.Module, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, +) -> RopeState | None: + """Create a RopeState if the model uses multi-dimensional RoPE.""" + if model_config.uses_mrope: + assert isinstance(model, SupportsMRoPE) + return RopeState( + num_dims=3, + has_delta=True, + max_num_reqs=max_num_reqs, + max_num_tokens=max_num_tokens, + max_model_len=max_model_len, + device=device, + ) + if model_config.uses_xdrope_dim > 0: + assert isinstance(model, SupportsXDRoPE) + return RopeState( + num_dims=model_config.uses_xdrope_dim, + has_delta=False, + max_num_reqs=max_num_reqs, + max_num_tokens=max_num_tokens, + max_model_len=max_model_len, + device=device, + ) + return None + + +@triton.jit +def _prepare_rope_positions_kernel( + positions_ptr, + positions_stride, + prefill_positions_ptr, + prefill_positions_stride0, + prefill_positions_stride1, + prefill_delta_ptr, + idx_mapping_ptr, + query_start_loc_ptr, + prefill_lens_ptr, + num_computed_tokens_ptr, + BLOCK_SIZE: tl.constexpr, + NUM_DIMS: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + prefill_len = tl.load(prefill_lens_ptr + req_state_idx) + num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) + is_prefill = num_computed < prefill_len + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + delta = tl.load(prefill_delta_ptr + req_state_idx) + + for i in range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + orig_pos = num_computed + block + + for j in tl.static_range(NUM_DIMS): + if is_prefill: + pos = tl.load( + prefill_positions_ptr + + req_state_idx * prefill_positions_stride0 + + j * prefill_positions_stride1 + + orig_pos, + mask=mask, + ) + else: + pos = orig_pos + delta + tl.store( + positions_ptr + j * positions_stride + query_start + block, + pos, + mask=mask, + ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 7268b8ac191f..57f170b59000 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -992,6 +992,7 @@ def execute_model( "input_ids": input_batch.input_ids, "positions": input_batch.positions, "inputs_embeds": inputs_embeds, + "intermediate_tensors": intermediate_tensors, # NOTE: Values returned by `prepare_inputs` will override the default # values above. **self.model_state.prepare_inputs(input_batch, self.req_states), @@ -1000,7 +1001,7 @@ def execute_model( # Update for non-first PP ranks. model_inputs["input_ids"] = None model_inputs["inputs_embeds"] = None - model_inputs["intermediate_tensors"] = intermediate_tensors + assert intermediate_tensors is not None # Run model. if batch_desc.cg_mode == CUDAGraphMode.FULL: diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 783d225c4a90..104e4c1948b5 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner -from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState +from vllm.v1.worker.gpu.mm.rope import get_rope_state from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup @@ -52,29 +52,28 @@ def __init__( device=self.device, ) - self.uses_mrope = self.model_config.uses_mrope - if self.uses_mrope: - self.mrope_state = MRopeState( - max_num_reqs=self.max_num_reqs, - max_num_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, - device=self.device, - ) + self.rope_state = get_rope_state( + self.model_config, + model, + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + ) def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: - if self.uses_mrope: - # Pre-compute M-RoPE positions for prefill. + if self.rope_state is not None: assert new_req_data.prefill_token_ids is not None - self.mrope_state.init_prefill_mrope_positions( + self.rope_state.init_prefill_positions( req_index, - self.model, # type: ignore + self.model, new_req_data.prefill_token_ids, mm_features=new_req_data.mm_features, ) def apply_staged_writes(self) -> None: - if self.uses_mrope: - self.mrope_state.apply_staged_writes() + if self.rope_state is not None: + self.rope_state.apply_staged_writes() def get_mm_embeddings( self, @@ -109,31 +108,26 @@ def get_mm_embeddings( def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState - ) -> dict[str, Any]: - if not self.uses_mrope: - # Common case (1D positions). - return {} + ) -> dict[str, torch.Tensor | None]: + if self.rope_state is None: + return {} # Common case (1D positions). - # Prepare M-RoPE positions. - self.mrope_state.prepare_mrope_positions( + self.rope_state.prepare_positions( input_batch.idx_mapping, input_batch.query_start_loc, req_states.prefill_len.gpu, req_states.num_computed_tokens.gpu, ) - mrope_positions = self.mrope_state.mrope_positions[ - :, : input_batch.num_tokens_after_padding - ] - return {"positions": mrope_positions} + positions = self.rope_state.get_positions(input_batch.num_tokens_after_padding) + return {"positions": positions} def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]: model_inputs = {} if self.supports_mm_inputs: inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens] model_inputs["inputs_embeds"] = inputs_embeds - if self.uses_mrope: - mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] - model_inputs["positions"] = mrope_positions + if self.rope_state is not None: + model_inputs["positions"] = self.rope_state.get_positions(num_tokens) return model_inputs def prepare_attn(