
[XPU] Support cpu kv offloading on XPU platform #36423

Open

chaojun-zhang wants to merge 1 commit into vllm-project:main from chaojun-zhang:cpu_offload

Conversation

@chaojun-zhang
Contributor

@chaojun-zhang chaojun-zhang commented Mar 9, 2026

Purpose

Support CPU KV offloading via the XPU `swap_blocks` kernel on the XPU platform.

Test Plan

```shell
pytest -s -v tests/v1/kv_offload
pytest -s -v tests/v1/kv_connector/unit/offloading_connector/test_worker.py
```

Test Result

Qwen-0.6B latency configuration:
Command: `vllm bench latency --model=meta-llama/Llama-3.1-8B -tp 2`

Median latency:

| Configuration | eager | compile |
| --- | --- | --- |
| `--kv_transfer_config={"kv_connector": "OffloadingConnector", "kv_role": "kv_both", "kv_connector_extra_config": {"cpu_bytes_to_use": 524288000, "block_size": 64}}` | 5.58920 s | 5.57366 s |
| default | 5.20960 s | 5.17703 s |

lm_eval:

1. With KV offloading:

```python
from lm_eval import evaluator
from lm_eval.models.vllm_causallms import VLLM
from vllm.config import KVTransferConfig

if __name__ == '__main__':
    kv_transfer_config = KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": 500 << 20,  # 500 MiB of CPU cache
            "block_size": 64,
        },
    )

    model = VLLM(
        pretrained="meta-llama/Llama-3.1-8B",
        dtype="bfloat16",
        tensor_parallel_size=2,
        add_bos_token=True,
        trust_remote_code=True,
        kv_transfer_config=kv_transfer_config,
    )

    results = evaluator.simple_evaluate(
        model=model,
        tasks=["gsm8k"],
        num_fewshot=5,
        limit=250,
        batch_size=20,
    )
    print(results["results"])
```
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
| --- | --- | --- | --- | --- | --- | --- |
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.536 | ± 0.0316 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.536 | ± 0.0316 |
2. Without KV offloading:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
| --- | --- | --- | --- | --- | --- | --- |
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.528 | ± 0.0316 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.528 | ± 0.0316 |
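A side note on the two CPU budgets used above: the `cpu_bytes_to_use` value `500 << 20` in the lm_eval script is the same budget as the `524288000` bytes passed to the latency benchmark, since a left shift by 20 bits multiplies by 2**20:

```python
# 500 << 20 shifts 500 left by 20 bits, i.e. 500 * 2**20 bytes (500 MiB)
cpu_bytes = 500 << 20
assert cpu_bytes == 524_288_000  # matches the bench configuration
```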

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Mar 9, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for CPU KV offloading on the XPU platform. The changes primarily involve making the existing CUDA offloading logic and tests device-agnostic by using current_platform.device_type and conditional logic for XPU-specific calls. A new CpuXpuOffloadingHandlers class is introduced, which cleverly reuses the CpuGpuOffloadingHandlers logic by monkey-patching torch.cuda functions with their torch.xpu counterparts within a context manager.

However, I've identified a critical issue in the implementation of the _torch_cuda_wrapper context manager in the new vllm/v1/kv_offload/worker/cpu_xpu.py file. The monkey-patching of torch.cuda attributes is not reverted in the finally block. This can lead to persistent, unintended side effects across the application, potentially causing hard-to-debug issues in other parts of the code that expect the original torch.cuda behavior. I've provided a code suggestion to fix this by properly restoring the original attributes.

@mergify
Contributor

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chaojun-zhang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify
Contributor

mergify bot commented Mar 13, 2026

Documentation preview: https://vllm--36423.org.readthedocs.build/en/36423/

@mergify mergify bot added documentation Improvements or additions to documentation and removed needs-rebase labels Mar 13, 2026
@chaojun-zhang chaojun-zhang force-pushed the cpu_offload branch 3 times, most recently from c8a24ba to c58e02f Compare March 13, 2026 02:12
@mergify
Contributor

mergify bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chaojun-zhang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 13, 2026
@mergify mergify bot added intel-gpu Related to Intel GPU and removed needs-rebase labels Mar 29, 2026
@mergify
Contributor

mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chaojun-zhang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify
Contributor

mergify bot commented Apr 8, 2026

Hi @chaojun-zhang, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:

```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```

@chaojun-zhang chaojun-zhang force-pushed the cpu_offload branch 3 times, most recently from dca4c3d to f0403d4 Compare April 8, 2026 12:27
@zhenwei-intel
Contributor

There are some newly added tests related to KV offloading. https://github.com/vllm-project/vllm/tree/main/tests/v1/kv_connector/unit/offloading_connector
Can we add them?

Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
@chaojun-zhang
Contributor Author

This PR depends on vllm-project/vllm-xpu-kernels#265.

@chaojun-zhang
Contributor Author

> There are some newly added tests related to KV offloading. https://github.com/vllm-project/vllm/tree/main/tests/v1/kv_connector/unit/offloading_connector Can we add them?

Added.


Labels

documentation, intel-gpu, kv-connector, v1
