[Model Runner V2] Add Support for XD-RoPE #36817

New file: XD-RoPE position utilities.

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.model_executor.models.interfaces import SupportsXDRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor


class XDRopeState:

    def __init__(
        self,
        uses_xdrope_dim: int,
        max_num_reqs: int,
        max_num_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        self.uses_xdrope_dim = uses_xdrope_dim
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.max_model_len = max_model_len
        self.device = device

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs),
        # wasting a lot of CPU memory.
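        # Layout: row (req_idx * uses_xdrope_dim + dim) holds the pre-computed
        # positions of one XD-RoPE dimension for one request, padded to
        # max_model_len. For example (hypothetical sizes), 1024 requests x
        # 4 dimensions x 131072 positions x 4 bytes (int32) is already 2 GiB.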
        self.prefill_xdrope_positions = StagedWriteTensor(
            (max_num_reqs * uses_xdrope_dim, max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
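        # Per-token XD-RoPE positions for the current batch, one row per
        # XD-RoPE dimension; filled on the GPU by prepare_xdrope_positions().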
        self.xdrope_positions = torch.zeros(
            (uses_xdrope_dim, max_num_tokens + 1), dtype=torch.int64, device=device
        )

    def init_prefill_xdrope_positions(
        self,
        req_idx: int,
        xdrope_model: SupportsXDRoPE,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
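        # The model is expected to return one row of positions per XD-RoPE
        # dimension, i.e. a tensor of shape
        # (uses_xdrope_dim, len(prefill_token_ids)), which is staged row by row
        # into prefill_xdrope_positions below.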
        prefill_xdrope_positions = xdrope_model.get_xdrope_input_positions(
            prefill_token_ids, mm_features
        )
        for i in range(self.uses_xdrope_dim):
            pos = prefill_xdrope_positions[i].tolist()
            self.prefill_xdrope_positions.stage_write(
                self.uses_xdrope_dim * req_idx + i, 0, pos
            )

    def apply_staged_writes(self) -> None:
        self.prefill_xdrope_positions.apply_write()

    def prepare_xdrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        _prepare_xdrope_positions_kernel[(num_reqs,)](
            self.xdrope_positions,
            self.xdrope_positions.stride(0),
            self.prefill_xdrope_positions.gpu,
            self.uses_xdrope_dim * self.max_model_len,
            self.max_model_len,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            USES_XDROPE_DIM=self.uses_xdrope_dim,
        )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    prefill_xdrope_positions_stride1,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
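    # One program per request in the batch: copy (prefill) or synthesize
    # (decode) the XD-RoPE positions for this request's query tokens into the
    # flat per-token output buffer, one row per XD-RoPE dimension.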
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + req_state_idx * prefill_xdrope_positions_stride0
                    + j * prefill_xdrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
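
For readers less familiar with Triton, here is a rough pure-PyTorch sketch of what `_prepare_xdrope_positions_kernel` computes. This is not part of the PR: the real prefill buffer is a flat staged 2-D tensor, viewed here as 3-D for clarity, and the function and argument names are made up for illustration.

```python
import torch


def prepare_xdrope_positions_reference(
    prefill_xdrope_positions: torch.Tensor,  # (num_req_states, xd_dim, max_model_len)
    idx_mapping: torch.Tensor,         # (num_reqs,): batch slot -> request-state index
    query_start_loc: torch.Tensor,     # (num_reqs + 1,): cumulative query lengths
    prefill_lens: torch.Tensor,        # (num_req_states,)
    num_computed_tokens: torch.Tensor,  # (num_req_states,)
    num_tokens: int,
) -> torch.Tensor:
    xd_dim = prefill_xdrope_positions.shape[1]
    out = torch.zeros((xd_dim, num_tokens), dtype=torch.int64)
    for b in range(idx_mapping.shape[0]):  # one Triton program per batch slot
        r = int(idx_mapping[b])
        start, end = int(query_start_loc[b]), int(query_start_loc[b + 1])
        num_computed = int(num_computed_tokens[r])
        orig_pos = num_computed + torch.arange(end - start)
        if num_computed < int(prefill_lens[r]):
            # Prefill: copy the pre-computed per-dimension positions.
            out[:, start:end] = prefill_xdrope_positions[r][:, orig_pos]
        else:
            # Decode: every XD-RoPE dimension just follows the 1-D position.
            out[:, start:end] = orig_pos
    return out


# Toy example: request 0 is mid-prefill (1 of 4 tokens computed, 3 scheduled),
# request 1 is decoding at position 5.
prefill = torch.arange(2 * 2 * 8, dtype=torch.int32).reshape(2, 2, 8)
out = prepare_xdrope_positions_reference(
    prefill,
    idx_mapping=torch.tensor([0, 1]),
    query_start_loc=torch.tensor([0, 3, 4]),
    prefill_lens=torch.tensor([4, 2]),
    num_computed_tokens=torch.tensor([1, 5]),
    num_tokens=4,
)
# out[:, :3] comes from prefill[0][:, 1:4]; out[:, 3] == 5 for every dimension.
```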
|

Changes to `execute_model` in the model runner:

@@ -973,6 +973,7 @@ def execute_model(
| "input_ids": input_batch.input_ids, | ||||||||
| "positions": input_batch.positions, | ||||||||
| "inputs_embeds": inputs_embeds, | ||||||||
| "intermediate_tensors": intermediate_tensors, | ||||||||
| # NOTE: Values returned by `prepare_inputs` will override the default | ||||||||
| # values above. | ||||||||
| **self.model_state.prepare_inputs(input_batch, self.req_states), | ||||||||
|
|
@@ -981,7 +982,7 @@ def execute_model( | |||||||
| # Update for non-first PP ranks. | ||||||||
| model_inputs["input_ids"] = None | ||||||||
| model_inputs["inputs_embeds"] = None | ||||||||
|
> **Member review comment:** we could consider adding this for clarity. (Suggested change not shown in this view.)
-            model_inputs["intermediate_tensors"] = intermediate_tensors
+            assert intermediate_tensors is not None

         # Run model.
         if batch_desc.cg_mode == CUDAGraphMode.FULL:
|

Changes to the model state that handles M-RoPE (and now XD-RoPE) positions:

@@ -14,6 +14,7 @@
 from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
+from vllm.v1.worker.gpu.mm.xdrope_utils import XDRopeState
 from vllm.v1.worker.gpu.model_states.interface import ModelState
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.utils import AttentionGroup

@@ -60,6 +61,15 @@ def __init__(
             max_model_len=self.max_model_len,
             device=self.device,
         )
+        self.xdrope_state: XDRopeState | None = None
+        if self.model_config.uses_xdrope_dim > 0:
+            self.xdrope_state = XDRopeState(
+                uses_xdrope_dim=self.model_config.uses_xdrope_dim,
+                max_num_reqs=self.max_num_reqs,
+                max_num_tokens=self.max_num_tokens,
+                max_model_len=self.max_model_len,
+                device=self.device,
+            )

     def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
         if self.uses_mrope:

@@ -71,10 +81,21 @@ def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
                 new_req_data.prefill_token_ids,
                 mm_features=new_req_data.mm_features,
             )
+        elif self.xdrope_state is not None:
+            # Pre-compute XD-RoPE positions for prefill.
+            assert new_req_data.prefill_token_ids is not None
+            self.xdrope_state.init_prefill_xdrope_positions(
+                req_index,
+                self.model,  # type: ignore
+                new_req_data.prefill_token_ids,
+                mm_features=new_req_data.mm_features,
+            )

     def apply_staged_writes(self) -> None:
         if self.uses_mrope:
             self.mrope_state.apply_staged_writes()
+        elif self.xdrope_state is not None:
+            self.xdrope_state.apply_staged_writes()

     def get_mm_embeddings(
         self,
|
@@ -109,22 +130,35 @@ def get_mm_embeddings(
|
|
||
     def prepare_inputs(
         self, input_batch: InputBatch, req_states: RequestState
-    ) -> dict[str, Any]:
-        if not self.uses_mrope:
-            # Common case (1D positions).
-            return {}
+    ) -> dict[str, torch.Tensor | None]:
+        if not self.uses_mrope and self.xdrope_state is None:
+            return {}  # Common case (1D positions).

+        if self.uses_mrope:
+            # Prepare M-RoPE positions.
+            self.mrope_state.prepare_mrope_positions(
+                input_batch.idx_mapping,
+                input_batch.query_start_loc,
+                req_states.prefill_len.gpu,
+                req_states.num_computed_tokens.gpu,
+            )
+            mrope_positions = self.mrope_state.mrope_positions[
+                :, : input_batch.num_tokens_after_padding
+            ]
+            return {"positions": mrope_positions}

-        # Prepare M-RoPE positions.
-        self.mrope_state.prepare_mrope_positions(
+        # Prepare XD-RoPE positions.
+        assert self.xdrope_state is not None
+        self.xdrope_state.prepare_xdrope_positions(
             input_batch.idx_mapping,
             input_batch.query_start_loc,
             req_states.prefill_len.gpu,
             req_states.num_computed_tokens.gpu,
         )
-        mrope_positions = self.mrope_state.mrope_positions[
+        xdrope_positions = self.xdrope_state.xdrope_positions[
             :, : input_batch.num_tokens_after_padding
         ]
-        return {"positions": mrope_positions}
+        return {"positions": xdrope_positions}

     def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
         model_inputs = {}
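
The hunk above turns `prepare_inputs` into a three-way dispatch. A toy, standalone illustration of that control flow (not vLLM code; the dimension counts are assumptions, e.g. M-RoPE using 3 rows):

```python
import torch


def pick_positions(uses_mrope: bool, uses_xdrope_dim: int, num_tokens: int) -> dict:
    """Mirror of the new prepare_inputs dispatch, with toy position tensors."""
    if not uses_mrope and uses_xdrope_dim == 0:
        return {}  # Common case: keep the default 1-D positions.
    dims = 3 if uses_mrope else uses_xdrope_dim
    return {"positions": torch.zeros((dims, num_tokens), dtype=torch.int64)}


print(pick_positions(False, 0, 8))                     # {}
print(pick_positions(True, 0, 8)["positions"].shape)   # torch.Size([3, 8])
print(pick_positions(False, 4, 8)["positions"].shape)  # torch.Size([4, 8])
```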
|
|
@@ -134,6 +168,9 @@ def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]
         if self.uses_mrope:
             mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
             model_inputs["positions"] = mrope_positions
+        elif self.xdrope_state is not None:
+            xdrope_positions = self.xdrope_state.xdrope_positions[:, :num_tokens]
+            model_inputs["positions"] = xdrope_positions
         return model_inputs

     def prepare_attn(
|
|
||