Support both 3D and 4D kv_cache shapes in MLA APIs (#2334)
Conversation
- Modified `_check_trtllm_gen_mla_shape()` to accept both 3D and 4D tensors
- Auto-normalize 3D tensors to 4D format for backward compatibility
- Updated docstrings for `trtllm_batch_decode_with_kv_cache_mla()` and `xqa_batch_decode_with_kv_cache_mla()`
- Both formats now supported: `[num_pages, page_size, head_dim]` and `[num_pages, 1, page_size, head_dim]`

Fixes #2258

Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
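The normalization described above can be sketched as a standalone helper. This is a hypothetical illustration (the real change lives inside `_check_trtllm_gen_mla_shape()` in `flashinfer/mla.py` and operates on torch tensors); NumPy's `expand_dims` stands in for torch's `unsqueeze`:

```python
import numpy as np


def normalize_kv_cache(kv_cache):
    """Normalize a legacy 3D kv_cache to the canonical 4D layout.

    3D: [num_pages, page_size, head_dim_ckv + head_dim_kpe]
    4D: [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
    """
    if kv_cache.ndim == 3:
        # Insert the singleton kv-head axis at position 1.
        kv_cache = np.expand_dims(kv_cache, axis=1)
    elif kv_cache.ndim != 4:
        raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
    return kv_cache


kv_3d = np.zeros((128, 64, 576))  # [num_pages, page_size, head_dim]
print(normalize_kv_cache(kv_3d).shape)  # (128, 1, 64, 576)
```

Because `expand_dims` (like `unsqueeze`) returns a view, the normalization adds no copy overhead for callers still passing the 3D layout.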
📝 Walkthrough

The changes add backward-compatible handling for KV cache tensor dimensions in MLA functions.
🪛 Ruff (0.14.10), `flashinfer/mla.py`:

- Line 84: Avoid specifying long messages outside the exception class (TRY003)
- Line 636: Avoid specifying long messages outside the exception class (TRY003)
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! Here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces enhanced flexibility to the MLA APIs by enabling them to process both 3D and 4D `kv_cache` tensors.
Code Review

Thank you for addressing issue #2258! This PR resolves the mismatch between documentation and implementation for the MLA KV cache tensor format.

🔒 Security & Safety

No security concerns. The validation logic properly checks tensor dimensions before processing.

📊 Recommendation

LGTM with suggestion: the PR is functionally correct and ready to merge. I strongly recommend adding test cases that exercise the 3D input path directly to prevent future regressions.

Review generated by Claude Code
Code Review
This pull request adds support for both 3D and 4D kv_cache shapes in MLA APIs for backward compatibility. It achieves this by normalizing 3D tensors to 4D within the _check_trtllm_gen_mla_shape helper function and updating the call sites to handle the returned normalized tensor. The docstrings have also been updated accordingly. The implementation is correct and well-contained. I've suggested a minor refactoring to improve code clarity. A key area for improvement is in testing; the existing tests do not seem to cover the new 3D input path, which is crucial for verifying the backward compatibility feature. Please consider adding test cases that pass a 3D kv_cache tensor directly to the modified APIs to ensure the normalization logic works as expected.
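The suggested 3D-path test could look like the following shape-level sketch. It uses NumPy as a stand-in for torch, and the local `normalize` helper mirrors the PR's logic rather than importing it; real tests would pass a 3D torch tensor directly into `trtllm_batch_decode_with_kv_cache_mla()` and compare outputs against the 4D path:

```python
import numpy as np


def normalize(kv_cache):
    # Mirrors the PR's 3D -> 4D normalization: insert a singleton axis at dim 1.
    if kv_cache.ndim == 3:
        kv_cache = np.expand_dims(kv_cache, axis=1)
    elif kv_cache.ndim != 4:
        raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
    return kv_cache


def test_3d_and_4d_inputs_agree():
    rng = np.random.default_rng(0)
    kv_3d = rng.standard_normal((4, 8, 16))  # [num_pages, page_size, head_dim]
    kv_4d = kv_3d[:, None, :, :]             # [num_pages, 1, page_size, head_dim]
    # Both input forms should reach the kernel with identical layout and values.
    assert normalize(kv_3d).shape == (4, 1, 8, 16)
    assert np.array_equal(normalize(kv_3d), normalize(kv_4d))


test_3d_and_4d_inputs_agree()
print("ok")
```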
```python
if kv_cache.ndim == 3:
    # [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
    kv_cache = kv_cache.unsqueeze(1)
elif kv_cache.ndim != 4:
    raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
```
This logic is correct. For improved clarity and to separate validation from transformation, you could first check if the dimension is valid and then perform the normalization. This makes the intent of each block of code clearer.
Suggested change:

```python
if kv_cache.ndim not in (3, 4):
    raise ValueError(f"Expected kv_cache.ndim to be 3 or 4, got {kv_cache.ndim}")
if kv_cache.ndim == 3:
    # [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
    kv_cache = kv_cache.unsqueeze(1)
```
Modified MLA APIs to accept both 3D and 4D kv_cache tensor formats for backward compatibility.

Changes:

- `_check_trtllm_gen_mla_shape()` now accepts both 3D and 4D tensors
- Updated docstrings for `trtllm_batch_decode_with_kv_cache_mla()` and `xqa_batch_decode_with_kv_cache_mla()`

Supported formats:

- 3D: `[num_pages, page_size, head_dim_ckv + head_dim_kpe]`
- 4D: `[num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]`

Fixes #2258
Generated with Claude Code