perf(multimodal): avoid feature hashing for pad values #1030

Open

hashkanna wants to merge 4 commits into sgl-project:main from hashkanna:kanna/mm-pad-value-content-hash

Conversation

@hashkanna

Motivation

Fixes #1029.

Single-image VLM requests currently derive MultimodalDataItem.pad_value by hashing the processed feature tensor when no precomputed hash is present. In a v6e-4 profile of Qwen/Qwen2.5-VL-7B-Instruct, this tokenizer-side feature hashing showed up as a measurable host-side request cost:

Event                               Calls  Mean        Total
MultimodalDataItem.set_pad_value    65     3412.899us  221.838ms
hash_feature                        65     3316.494us  215.572ms

The tokenizer already has access to the raw image/video/audio payload bytes before feature preprocessing. This PR carries a stable content hash into MultimodalDataItem.hash, allowing set_pad_value() to avoid walking large processed feature tensors on the request path.
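
A minimal sketch of the idea, assuming a hypothetical payload_hash helper (MultimodalDataItem.hash, set_pad_value(), and hash_feature are the names used in this PR; everything else below is illustrative):

    import hashlib

    def payload_hash(raw_bytes: bytes, modality: str) -> int:
        # Mix the modality into the digest so identical bytes supplied as
        # different modalities still produce distinct cache keys.
        digest = hashlib.sha256(modality.encode("utf-8") + raw_bytes).digest()
        return int.from_bytes(digest[:8], "big", signed=True)

    # Computed while the raw bytes are still available, before preprocessing:
    # item.hash = payload_hash(raw_image_bytes, "image")
    # item.set_pad_value()  # uses item.hash; falls back to hash_feature otherwise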

Related inspiration: Modal's multimodal inference optimization write-up discusses the same general principle of carrying stable media-content identity through preprocessing instead of repeatedly hashing large processed tensors on the request path: https://modal.com/blog/boosting-multimodal-inference-performance-by-greater-than-10-with-a-single-python-dictionary

Modifications

  • Add an internal loaded-payload wrapper that carries both decoded multimodal data and a stable payload hash (a sketch follows this list).
  • Compute payload hashes for image, video, and audio inputs while raw bytes are available.
  • Include modality and relevant preprocessing metadata in hashes where needed so cache keys remain distinct.
  • Set MultimodalDataItem.hash before set_pad_value() runs.
  • Keep hash_feature as the fallback for callers without a precomputed hash.
  • Add focused unit tests for precomputed hash behavior, fallback behavior, dict round-trip behavior, and image payload hash stability.
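
A minimal sketch of such a wrapper (the class name _LoadedMultimodalPayload appears in the review summary below; the field names here are assumptions):

    from dataclasses import dataclass
    from typing import Any

    @dataclass(frozen=True)
    class _LoadedMultimodalPayload:
        # Decoded image/video/audio object handed to feature preprocessing.
        data: Any
        # Stable hash of the raw payload bytes, or None if it could not be computed.
        payload_hash: int | None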

Accuracy Tests

No model output changes are expected. This PR only changes how the pad value used for cache-key differentiation is derived, while preserving the existing feature-hash fallback.

Local validation:

  • uv run --project python python -m unittest sgl_jax.test.multimodal.test_multimodal_pad_value_hash
  • python3 -m py_compile python/sgl_jax/srt/multimodal/manager/multimodal_tokenizer.py python/sgl_jax/test/multimodal/test_multimodal_pad_value_hash.py
  • ruff check python/sgl_jax/srt/multimodal/manager/multimodal_tokenizer.py python/sgl_jax/test/multimodal/test_multimodal_pad_value_hash.py
  • pre-commit run --all-files --show-diff-on-failure

TPU validation:

  • Focused unittest passed on v6e-4 before the profiling confirmation run.

Benchmarking and Profiling

Profiling confirmation on v6e-4 with Qwen/Qwen2.5-VL-7B-Instruct:

Metric                                  Baseline (feature-tensor hash)  Optimized (payload hash)
Measured requests                       64/64 HTTP 200                  64/64 HTTP 200
hash_feature calls                      65                              0
hash_feature mean                       3316.494us                      N/A
hash_feature total                      215.572ms                       N/A
MultimodalDataItem.set_pad_value calls  65                              65
MultimodalDataItem.set_pad_value mean   3412.899us                      31.320us
MultimodalDataItem.set_pad_value total  221.838ms                       2.036ms
Request latency median                  0.3174s                         0.3115s
Request latency p95                     0.3247s                         0.3180s

The optimized run had one request-level outlier at 8.32436s; median and p95 remained stable.

Checklist

  • Please use English, otherwise it will be closed.
  • The purpose of the PR, or link existing issues this PR will resolve.
  • The test plan, such as providing test command.
  • (Optional) The necessary documentation update. N/A.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes multimodal request processing by shifting from feature-tensor hashing to stable payload hashing. By calculating hashes while raw input bytes are available, the system significantly reduces host-side request latency and avoids redundant, computationally expensive operations on large feature tensors. The existing feature-hash mechanism is preserved as a fallback, ensuring backward compatibility while providing a measurable performance boost.

Highlights

  • Performance Optimization: Implemented precomputed payload hashing for multimodal inputs to avoid expensive feature tensor hashing during request processing.
  • Payload Handling: Introduced a new internal wrapper, _LoadedMultimodalPayload, to carry both the processed data and its stable content hash.
  • Testing: Added comprehensive unit tests to verify hash stability, fallback behavior, and correct round-trip handling of multimodal data.

    if os.path.exists(source):
        return Image.open(source).convert("RGB")
    with Image.open(source) as image:
Collaborator

The code decodes the media from the path first, then reopens the same path to compute the hash. If the file changes between those two reads, the feature and the hash may come from different contents, producing an incorrect radix cache key.

Please read the file bytes once and use the same bytes for both decoding and hashing.
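
A minimal sketch of the suggested fix, assuming a hypothetical load_image_with_hash helper: read the bytes once, then decode and hash the same in-memory buffer so the feature and the cache key cannot diverge if the file changes on disk:

    import hashlib
    import io

    from PIL import Image

    def load_image_with_hash(path: str) -> tuple[Image.Image, int]:
        with open(path, "rb") as f:
            raw = f.read()
        # Decode and hash the exact same bytes.
        image = Image.open(io.BytesIO(raw)).convert("RGB")
        digest = hashlib.sha256(raw).digest()
        return image, int.from_bytes(digest[:8], "big", signed=True)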

Author

Thanks, fixed in commit 9b2f20dd. Local file inputs now create a temporary snapshot while streaming the source file in chunks. The hash is computed over the exact bytes copied to that snapshot, and image/video/audio/Qwen-video decoding reads from the same snapshot.
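
A minimal sketch of that snapshot approach (hypothetical helper; the chunk size and temp-file lifecycle are assumptions, and the caller is responsible for deleting the snapshot):

    import hashlib
    import tempfile

    def snapshot_with_hash(path: str, chunk_size: int = 1 << 20) -> tuple[str, int]:
        hasher = hashlib.sha256()
        with open(path, "rb") as src, tempfile.NamedTemporaryFile(delete=False) as dst:
            # Hash exactly the bytes that are copied into the snapshot, so later
            # decoding from the snapshot matches the hash by construction.
            while chunk := src.read(chunk_size):
                hasher.update(chunk)
                dst.write(chunk)
            snapshot_path = dst.name
        return snapshot_path, int.from_bytes(hasher.digest()[:8], "big", signed=True)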

Collaborator

ok


    def _combine_mm_hashes(self, hashes: list[int | None], modality: str) -> int | None:
        valid_hashes = [value for value in hashes if value is not None]
        if not valid_hashes:
Collaborator

_combine_mm_hashes() silently drops None hashes. If any payload is processed successfully but its hash is unavailable, the combined hash may only cover part of the multimodal input instead of falling back to feature hashing.

I think this should return None whenever any input hash is None, so set_pad_value() can use the existing feature-hash fallback.

Author

_combine_mm_hashes() now returns None if any input hash is unavailable, so MultimodalDataItem.set_pad_value() falls back to the existing feature hashing path instead of building a partial payload hash. A focused unit test covers this case.
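
A minimal sketch of that behavior (a free-function analogue of the method above; the deterministic combiner is illustrative, not the PR's exact mixing, and it assumes 64-bit signed per-item hashes):

    import hashlib

    def combine_mm_hashes(hashes: list[int | None], modality: str) -> int | None:
        # Any missing per-item hash invalidates the combined payload hash,
        # routing set_pad_value() to the existing feature-hash fallback.
        if any(value is None for value in hashes):
            return None
        hasher = hashlib.sha256(modality.encode("utf-8"))
        for value in hashes:
            hasher.update(value.to_bytes(8, "big", signed=True))
        return int.from_bytes(hasher.digest()[:8], "big", signed=True)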

@pengchengneo
Collaborator

LGTM

@pathfinder-pf
Collaborator

@hashkanna please fix ci error

hashkanna force-pushed the kanna/mm-pad-value-content-hash branch from 94cf2b7 to 770d860 on May 14, 2026 at 10:53.
@hashkanna
Author

@hashkanna please fix ci error

@pathfinder-pf The failing job is the known flake in test_bench_serving_dense_tp_4 that #1077 just fixed. Rebasing onto main to pick up the fix.
