
[Model Runner V2] Add Support for XD-RoPE#36817

Open
santiramos27 wants to merge 9 commits into vllm-project:main from santiramos27:santinor/mrv2-xdrope

Conversation


@santiramos27 santiramos27 commented Mar 11, 2026

Purpose

Add XD-RoPE support to Model Runner V2, enabling models like HunyuanVL (tencent/HunyuanOCR). Mostly follows the existing M-RoPE support PR, with the main differences being:

  • Variable dimension count: M-RoPE is always 3D (T/H/W). XD-RoPE supports 3 or 4 dimensions (W/H/T or P/W/H/T), parameterized via uses_xdrope_dim from model config. The Triton kernel uses USES_XDROPE_DIM as a tl.constexpr instead of hardcoding tl.static_range(3)
  • No position delta: M-RoPE stores a per-request mrope_position_delta used during decode to offset positions. XD-RoPE doesn't need this — during decode, all dimensions use the same sequential position

Also fixes intermediate_tensors not being passed in model_inputs for the first PP rank and during CUDA graph capture, which caused a TypeError on models (like HunyuanVL) whose forward() signature requires it as a positional argument.
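The "no position delta" difference above can be sketched in plain Python (illustrative names only, not vLLM's actual API):

```python
def mrope_decode_positions(seq_pos: int, delta: int, num_dims: int = 3) -> list[int]:
    # M-RoPE: during decode, every dimension uses the sequential position
    # shifted by the stored per-request mrope_position_delta.
    return [seq_pos + delta] * num_dims


def xdrope_decode_positions(seq_pos: int, num_dims: int) -> list[int]:
    # XD-RoPE: no delta is stored; all 3 or 4 dimensions use the same
    # raw sequential position.
    return [seq_pos] * num_dims
```

This is why XDRopeState needs no per-request delta buffer, unlike the M-RoPE state.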

Test Plan

  • Server starts successfully with VLLM_USE_V2_MODEL_RUNNER=1 and tencent/HunyuanOCR (both eager and CUDA graph)

Test Result

Identical OCRBench result (856) on V1 vs V2 model runner.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@mergify mergify bot added the v1 label Mar 11, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for XD-RoPE. The changes include a new XDRopeState class with a Triton kernel to manage XD-RoPE positions, and integration into the DefaultModelState. The core logic seems correct, but I've identified an opportunity to improve the clarity and maintainability of the new Triton kernel by refactoring its indexing logic to be more conventional. The rest of the changes for integrating XD-RoPE support are well-structured.

Comment on lines +56 to +128
    def prepare_xdrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        _prepare_xdrope_positions_kernel[(num_reqs,)](
            self.xdrope_positions,
            self.xdrope_positions.stride(0),
            self.prefill_xdrope_positions.gpu,
            self.uses_xdrope_dim * self.max_model_len,
            self.max_model_len,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            USES_XDROPE_DIM=self.uses_xdrope_dim,
        )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    prefill_xdrope_positions_stride1,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + req_state_idx * prefill_xdrope_positions_stride0
                    + j * prefill_xdrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
Contributor


Severity: high

The current implementation of _prepare_xdrope_positions_kernel and its call site in prepare_xdrope_positions is functionally correct but hard to understand due to unconventional argument passing for tensor indexing. The arguments prefill_xdrope_positions_stride0 and prefill_xdrope_positions_stride1 are not standard tensor strides, making the address calculation logic within the kernel confusing.

For better readability and maintainability, I suggest refactoring to use standard stride-based indexing. This involves passing the actual stride of the prefill_xdrope_positions tensor and adjusting the kernel to calculate the row index explicitly. This change does not alter the logic but makes the code much clearer and aligned with common practices for writing Triton kernels.

    def prepare_xdrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        prefill_positions = self.prefill_xdrope_positions.gpu
        _prepare_xdrope_positions_kernel[(num_reqs,)](
            self.xdrope_positions,
            self.xdrope_positions.stride(0),
            prefill_positions,
            prefill_positions.stride(0),
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            USES_XDROPE_DIM=self.uses_xdrope_dim,
        )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                row_idx = req_state_idx * USES_XDROPE_DIM + j
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + row_idx * prefill_xdrope_positions_stride0
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )

Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@mergify mergify bot added the nvidia label Mar 11, 2026
@santiramos27 santiramos27 marked this pull request as ready for review March 11, 2026 20:56
Member

@njhill njhill left a comment


Thanks @santiramos27!

:, : input_batch.num_tokens_after_padding
]
return {"positions": mrope_positions}
if self.uses_mrope:
Member


May be good to keep common case at the top, i.e.

if not self.uses_mrope and not self.uses_xdrope:
    # Common case (1D positions).
    return {}

if self.uses_mrope:
    # ...
    return ...

# xdrope logic
return ...

Collaborator


+1. I intentionally put the early exit for the common case at the top. It'd be nice to keep it!

Author


Early exit makes sense. Updated.

max_model_len=self.max_model_len,
device=self.device,
)
self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
Member


This can be a bool

Suggested change
self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
self.uses_xdrope = self.model_config.uses_xdrope_dim > 0

I actually like using `self.xdrope_state: XDRopeState | None` for this rather than a separate bool, but I don't think @WoosukKwon likes that :)

Or, I feel it might be better/clearer to instead have something like a pe_type enum (MROPE, XDROPE or None)
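A minimal sketch of what such an enum could look like (hypothetical — `PEType` and `pe_type_from_config` are illustrative names, not existing vLLM API):

```python
from enum import Enum, auto


class PEType(Enum):
    # Hypothetical enum for the positional-encoding variant handled
    # specially by the model runner.
    NONE = auto()
    MROPE = auto()
    XDROPE = auto()


def pe_type_from_config(uses_mrope: bool, uses_xdrope_dim: int) -> PEType:
    # A single enum makes the mutual exclusivity of the two variants
    # explicit, unlike two independent bools that should never both be true.
    if uses_mrope:
        return PEType.MROPE
    if uses_xdrope_dim > 0:
        return PEType.XDROPE
    return PEType.NONE
```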

Collaborator

@WoosukKwon WoosukKwon Mar 11, 2026


+1. I'm also ok with if self.xdrope_state is not None.

Author


I will change to using if self.xdrope_state is not None. From what I can tell, these two PE variants are the only ones getting special treatment in the model runner. Should we wait to introduce enums until we have more PEs to support? There's already a fair amount of duplication between mrope/xdrope that can probably be consolidated once we have more of them to worry about.

Member


Sure, I just thought it might be clearer to have a single enum than two bools that can't both be true.

@@ -960,7 +961,6 @@ def execute_model(
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
Member


We could consider adding this for clarity:

Suggested change
model_inputs["inputs_embeds"] = None
model_inputs["inputs_embeds"] = None
assert intermediate_tensors is not None


mergify bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @santiramos27.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 11, 2026
Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@mergify mergify bot removed the needs-rebase label Mar 11, 2026

mergify bot commented Mar 11, 2026

Hi @santiramos27, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@santiramos27 santiramos27 requested a review from njhill March 11, 2026 22:37
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor


class XDRopeState:
Member


Looking at the code now ... it would simplify things to have a shared RopeState interface with init_prefill_positions, prepare_positions and get_positions(num_tokens) methods
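A rough sketch of what that shared interface could look like (hypothetical — the method names come from the comment above, the signatures are guesses, and nothing here is existing vLLM API):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class RopeState(Protocol):
    """Hypothetical shared interface that both MRopeState and XDRopeState
    could satisfy, letting the model runner treat them uniformly."""

    def init_prefill_positions(self, req_state_idx: int, positions: Any) -> None:
        # Populate the pre-computed per-request prefill positions.
        ...

    def prepare_positions(
        self,
        idx_mapping: Any,
        query_start_loc: Any,
        prefill_lens: Any,
        num_computed_tokens: Any,
    ) -> None:
        # Fill the position buffer for the current batch.
        ...

    def get_positions(self, num_tokens: int) -> Any:
        # Return the positions slice passed to the model's forward().
        ...
```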

