
Conversation

@storyicon storyicon commented Dec 15, 2025

Purpose

Add support for GPU tensors in the tensor_data function so that GPU-accelerated multimodal preprocessing works correctly.

Background

In our production deployment, we enable GPU-accelerated multimodal preprocessing with the following configuration options, which moves tasks such as image and video preprocessing onto the GPU and significantly reduces CPU overhead in high-concurrency scenarios:

  • CLI argument: --mm-processor-kwargs '{"device": "cuda"}'
  • Model config: Setting "device": "cuda" in preprocessor_config.json

However, the current implementation of the tensor_data() function in vllm/v1/utils.py fails to handle tensors residing on the GPU, causing errors when GPU preprocessing is enabled.

Problem

The tensor_data() function is used for tensor serialization and hashing, particularly in multimodal input processing. The current implementation directly calls .numpy() on tensors:

return tensor.flatten().contiguous().view(torch.uint8).numpy().data

This fails for GPU tensors because PyTorch's .numpy() method only supports CPU tensors, raising:

 TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
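The failure can be reproduced in isolation with a few lines (a minimal sketch, assuming a CUDA-capable machine; the tensor shape is arbitrary):

import torch

# Same chain of calls as the current tensor_data(); the final .numpy() raises on a CUDA tensor.
t = torch.randn(4, 4, device="cuda")
t.flatten().contiguous().view(torch.uint8).numpy()  # TypeError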

Solution

Add a .cpu() call before .numpy() so that both CPU and GPU tensors are handled:

return tensor.flatten().contiguous().view(torch.uint8).cpu().numpy().data

Performance Impact:

  • CPU tensors: .cpu() is a no-op, so there is no performance impact (see the quick check below)
  • GPU tensors: a device-to-host copy is required, which is the same transfer serialization would need anyway
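A quick way to confirm the no-op claim for CPU tensors (a small sketch, not part of this PR; it relies on .cpu() returning the original tensor when no copy is needed):

import torch

t = torch.randn(8)                          # tensor already on the CPU
assert t.cpu().data_ptr() == t.data_ptr()   # no copy is made: same underlying storage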

This change is critical for multimodal models when GPU preprocessing is enabled, as tensors may reside on GPU devices.

Test Plan

Unit Tests

import torch
from vllm.v1.utils import tensor_data

def test_tensor_data_cpu():
    tensor = torch.randn(10, 20)
    result = tensor_data(tensor)
    assert isinstance(result, memoryview)
    assert len(result) == tensor.numel() * tensor.element_size()
    print("CPU tensor test passed")

def test_tensor_data_gpu():
    if not torch.cuda.is_available():
        print("GPU not available, skipping GPU test")
        return
    tensor = torch.randn(10, 20).cuda()
    result = tensor_data(tensor)
    assert isinstance(result, memoryview)
    assert len(result) == tensor.numel() * tensor.element_size()
    print("GPU tensor test passed")

def test_tensor_data_various_dtypes():
    dtypes = [torch.float32, torch.float16, torch.int32, torch.int64]
    for dtype in dtypes:
        tensor = torch.randn(5, 5).to(dtype)
        result = tensor_data(tensor)
        assert isinstance(result, memoryview)
    print("Various dtypes test passed")

# Run tests
test_tensor_data_cpu()
test_tensor_data_gpu()
test_tensor_data_various_dtypes()

Integration Test

Test with real multimodal preprocessing on the GPU:

# Start vLLM with GPU preprocessing
vllm serve <multimodal-model> \
  --mm-processor-kwargs '{"device": "cuda"}'
  
# Send multimodal inference request
# Should not raise TypeError during tensor serialization
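For example, a request along these lines should complete without the error (a rough sketch: the model name and image URL are placeholders, and the server is assumed to listen on the default port 8000):

import requests

# Hypothetical smoke test against the OpenAI-compatible endpoint started above.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "<multimodal-model>",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
            ],
        }],
    },
)
print(resp.status_code, resp.json())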

Test Result

Before Fix

  • CPU tensors: Works correctly
  • GPU tensors: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
  • Impact: GPU-accelerated preprocessing cannot be used

After Fix

  • CPU tensors: Works correctly (no performance regression)
  • GPU tensors: Works correctly, with an automatic device-to-host copy
  • Impact: GPU-accelerated preprocessing fully functional

Test Output:

CPU tensor test passed
GPU tensor test passed
Various dtypes test passed

Affected Code Paths

The tensor_data() function is called in:

  1. vllm/v1/serial_utils.py - Tensor encoding for serialization
  2. vllm/v1/core/kv_cache_utils.py - Block prompt embeddings hashing
  3. Various test files - Unit testing

All these code paths now work correctly with GPU tensors when GPU preprocessing is enabled.

Documentation Updates

Updated the docstring of vllm/v1/utils.py:tensor_data() to clarify (a rough sketch follows this list):

  • The function now supports both CPU and GPU tensors
  • Device-to-host transfer behavior for GPU tensors
  • No-op behavior for CPU tensors
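
A rough sketch of what the updated function and docstring might look like (the exact wording and type annotations in the PR may differ):

import torch

def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Return the raw bytes of a tensor as a memoryview.

    Supports both CPU and GPU tensors. GPU tensors are first copied to
    host memory; for CPU tensors the .cpu() call is a no-op.
    """
    return tensor.flatten().contiguous().view(torch.uint8).cpu().numpy().data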


@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Dec 15, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request effectively addresses a critical issue where tensor_data() failed to process GPU tensors because PyTorch's .numpy() method requires CPU tensors. The solution, adding a .cpu() call before .numpy(), is correct and ensures compatibility with both CPU and GPU tensors, enabling GPU-accelerated multimodal preprocessing. The updated docstring clearly explains the behavior for both CPU and GPU tensors, including the necessary device-to-host transfer for GPU tensors and the no-op for CPU tensors. The change is well justified and includes a comprehensive test plan and results.

@DarkLight1337 DarkLight1337 (Member) commented Dec 15, 2025

I actually have a PR for that: #22070

The reason this hasn't been merged into the main branch is that it goes against the design principle that GPU memory management is done inside the engine core. Let me rebase the PR to keep it up to date.
