
Support both 3D and 4D kv_cache shapes in MLA APIs#2334

Merged
yzh119 merged 1 commit into main from claude/issue-2258-20260112-0227 on Jan 13, 2026

Conversation

Collaborator

@yzh119 yzh119 commented Jan 12, 2026

Modified MLA APIs to accept both 3D and 4D kv_cache tensor formats for backward compatibility.

Changes:

  • Modified _check_trtllm_gen_mla_shape() to accept both 3D and 4D tensors
  • Auto-normalize 3D tensors to 4D format internally
  • Updated docstrings for trtllm_batch_decode_with_kv_cache_mla() and xqa_batch_decode_with_kv_cache_mla()

Supported formats:

  • 3D: [num_pages, page_size, head_dim_ckv + head_dim_kpe]
  • 4D: [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
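The normalization described above can be sketched on plain shape tuples. This is illustrative only: the real helper, _check_trtllm_gen_mla_shape(), operates on torch tensors and inserts the singleton axis with unsqueeze(1); the function name and example dimensions below are made up.

```python
def normalize_kv_cache_shape(shape: tuple) -> tuple:
    # Illustrative stand-in for the tensor-level normalization:
    # a 3D shape gains a singleton axis at position 1, matching
    # what kv_cache.unsqueeze(1) does to the real tensor.
    if len(shape) == 3:
        num_pages, page_size, head_dim = shape
        return (num_pages, 1, page_size, head_dim)
    if len(shape) == 4:
        return shape
    raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {len(shape)}")

# Example dimensions are hypothetical.
print(normalize_kv_cache_shape((128, 32, 576)))     # (128, 1, 32, 576)
print(normalize_kv_cache_shape((128, 1, 32, 576)))  # unchanged: (128, 1, 32, 576)
```

Either input shape yields the same 4D layout, which is what lets callers pass whichever format their code already uses.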

Fixes #2258

Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes
    • Improved backward compatibility for cache dimension handling with automatic format normalization.
    • Enhanced error messages for cache validation.



Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
Contributor

coderabbitai Bot commented Jan 12, 2026

📝 Walkthrough

The changes add backward-compatible handling for KV cache tensor dimensions in MLA functions. The _check_trtllm_gen_mla_shape function now promotes 3D KV cache tensors to 4D via unsqueeze and returns the normalized tensor. Calling functions in batch decode operations are updated to use the returned, normalized KV cache and extract block_size before validation.

Changes

Cohort: MLA KV Cache Normalization
File(s): flashinfer/mla.py
Summary: Added 3D-to-4D KV cache tensor shape normalization in _check_trtllm_gen_mla_shape; updated trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla (both 3D/4D paths) to normalize the KV cache via this helper and assign the returned tensor. Moved block_size extraction before validation in the affected functions.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

  • PR #2163: Directly related as this PR modifies the same functions that were moved/introduced in #2163, implementing backward-compatible KV cache shape handling at the call sites.

Suggested reviewers

  • cyx-6
  • nvmbreughe

Poem

🐰 Tensors hop through shapes with grace,
From 3D to 4D, finding their place,
Backward-compat blooms so bright,
KV caches normalized just right!

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4)
  • Title check: The title 'Support both 3D and 4D kv_cache shapes in MLA APIs' is concise and directly describes the main change.
  • Description check: The PR description covers the main changes, supported formats, and references the fixed issue, though it lacks explicit references to test updates and pre-commit checklist items from the template.
  • Linked Issues check: The PR addresses issue #2258 by modifying _check_trtllm_gen_mla_shape() to accept both 3D and 4D kv_cache formats, auto-normalizing 3D to 4D, and updating docstrings to document both supported formats.
  • Out of Scope Changes check: All changes in flashinfer/mla.py are directly related to resolving the documented inconsistency between the API docstring and the internal implementation regarding kv_cache tensor shapes.



📜 Recent review details


📥 Commits

Reviewing files that changed from the base of the PR and between 2062dec and 4f4476a.

📒 Files selected for processing (1)
  • flashinfer/mla.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/mla.py
🪛 Ruff (0.14.10)
flashinfer/mla.py

84-84: Avoid specifying long messages outside the exception class

(TRY003)


636-636: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Deploy Docs
  • GitHub Check: claude-review
🔇 Additional comments (5)
flashinfer/mla.py (5)

78-84: LGTM! Clean backward-compatible handling of 3D/4D kv_cache.

The dimension promotion logic correctly transforms [num_pages, page_size, head_dim] to [num_pages, 1, page_size, head_dim] using unsqueeze(1). The error message is appropriately descriptive for debugging.


121-122: Return value correctly propagates the normalized tensor.

The function now returns the potentially-modified kv_cache, which callers must capture. Verified that both trtllm_batch_decode_with_kv_cache_mla (line 639) and xqa_batch_decode_with_kv_cache_mla (line 764) properly reassign the return value.


544-544: Docstring correctly documents both supported formats.

This addresses the core issue (#2258) by documenting that both 3D and 4D kv_cache formats are accepted, with automatic normalization to 4D internally.


631-648: Correct use of negative indexing for format-agnostic block_size extraction.

Using kv_cache.size(-2) correctly retrieves page_size from both 3D [num_pages, page_size, dim] and 4D [num_pages, 1, page_size, dim] formats. The reassignment of kv_cache from the validation function ensures subsequent operations use the normalized 4D tensor.
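The equivalence claimed above is easy to check on shape tuples (dimensions here are hypothetical): indexing from the end selects page_size in both layouts, which is why the extraction is format-agnostic.

```python
shape_3d = (128, 32, 576)     # [num_pages, page_size, head_dim]
shape_4d = (128, 1, 32, 576)  # [num_pages, 1, page_size, head_dim]

# Negative indexing skips over the optional singleton head axis,
# so both layouts report the same page_size.
print(shape_3d[-2], shape_4d[-2])  # 32 32
```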


749-773: Consistent handling with trtllm-gen backend.

The XQA backend mirrors the same pattern: extracting block_size using negative indexing and normalizing kv_cache via _check_trtllm_gen_mla_shape. This ensures both backends handle 3D/4D inputs uniformly.



@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces enhanced flexibility to the MLA APIs by enabling them to process kv_cache tensors in either a 3D or 4D shape. This change is primarily aimed at maintaining backward compatibility with existing implementations while standardizing the internal representation of kv_cache to a 4D format. The update ensures that functions can seamlessly handle different input dimensions without requiring callers to explicitly reshape their tensors.

Highlights

  • Flexible kv_cache Support: MLA APIs now accept kv_cache tensors in both 3D and 4D formats, ensuring backward compatibility.
  • Automatic Normalization: 3D kv_cache tensors are automatically converted to the 4D format internally for consistent processing.
  • Updated Documentation: Docstrings for trtllm_batch_decode_with_kv_cache_mla() and xqa_batch_decode_with_kv_cache_mla() have been updated to reflect the new supported kv_cache shapes.



claude Bot commented Jan 12, 2026

Code Review

Thank you for addressing issue #2258! This PR successfully resolves the mismatch between documentation and implementation for the MLA KV cache tensor format. Here's my analysis:

✅ Strengths

  1. Backward Compatibility: The approach maintains backward compatibility by accepting both 3D and 4D formats, which is ideal for existing users who may have code using either format.

  2. Clean Implementation: The normalization happens in a single validation function (_check_trtllm_gen_mla_shape()), returning the normalized tensor. This centralizes the logic and makes it easy to maintain.

  3. Comprehensive Documentation: Updated docstrings for both trtllm_batch_decode_with_kv_cache_mla() and xqa_batch_decode_with_kv_cache_mla() clearly document the dual format support.

  4. Dimension Extraction Logic: The code correctly extracts block_size using .size(-2) before normalization (flashinfer/mla.py:632, 748), which works for both 3D and 4D inputs.

  5. Return Value: The function now returns the normalized kv_cache, and both call sites properly capture it (flashinfer/mla.py:639, 764).

🔍 Potential Issues

  1. Test Coverage Gap: The existing tests still use .unsqueeze(1) to convert 3D to 4D (test_trtllm_gen_mla.py:312, test_xqa_mla_batch_decode.py:107). Consider adding tests that:

    • Pass 3D tensors directly without .unsqueeze(1) to verify the new path works
    • Verify that both 3D and 4D inputs produce identical outputs

    Example test addition:

    # Test with 3D kv_cache (new backward compatibility path)
    output_3d = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache,  # 3D: [num_pages, page_size, head_dim]
        ...
    )
    
    # Test with 4D kv_cache (existing path)
    output_4d = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache.unsqueeze(1),  # 4D: [num_pages, 1, page_size, head_dim]
        ...
    )
    
    torch.testing.assert_close(output_3d, output_4d)
  2. Minor: Inline Comment Clarity: The comment on line 79 says "Support both 3D and 4D kv_cache for backward compatibility". Consider clarifying which format was the original intent vs. which exists for compatibility. Based on the issue, it appears 3D was the documented format but 4D was required by implementation. The comment could be:

    # Support both 3D (documented format) and 4D (previously required) for backward compatibility
  3. Performance Consideration (Minor): The .unsqueeze(1) operation creates a view with minimal overhead, but it's worth noting that this happens on every call with 3D input. This is likely negligible, but if performance profiling reveals issues, consider documenting that 4D format avoids this overhead.

📋 Code Quality

  • Style: Follows FlashInfer conventions ✅
  • Error Messages: Clear and informative ✅
  • Framework Separation: Properly uses PyTorch operations in the appropriate layer (csrc/binding layer, not in include/) ✅
  • Documentation: Docstrings updated appropriately ✅

🔒 Security & Safety

No security concerns. The validation logic properly checks tensor dimensions before processing.

📊 Recommendation

LGTM with suggestion: The PR is functionally correct and ready to merge. I strongly recommend adding test cases that exercise the 3D input path directly to prevent future regressions.


Review generated by Claude Code

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for both 3D and 4D kv_cache shapes in MLA APIs for backward compatibility. It achieves this by normalizing 3D tensors to 4D within the _check_trtllm_gen_mla_shape helper function and updating the call sites to handle the returned normalized tensor. The docstrings have also been updated accordingly. The implementation is correct and well-contained. I've suggested a minor refactoring to improve code clarity. A key area for improvement is in testing; the existing tests do not seem to cover the new 3D input path, which is crucial for verifying the backward compatibility feature. Please consider adding test cases that pass a 3D kv_cache tensor directly to the modified APIs to ensure the normalization logic works as expected.

Comment thread: flashinfer/mla.py, lines +80 to +84

    if kv_cache.ndim == 3:
        # [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
        kv_cache = kv_cache.unsqueeze(1)
    elif kv_cache.ndim != 4:
        raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
Severity: medium

This logic is correct. For improved clarity and to separate validation from transformation, you could first check if the dimension is valid and then perform the normalization. This makes the intent of each block of code clearer.

Suggested change:

    # Before:
    if kv_cache.ndim == 3:
        # [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
        kv_cache = kv_cache.unsqueeze(1)
    elif kv_cache.ndim != 4:
        raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")

    # After:
    if kv_cache.ndim not in (3, 4):
        raise ValueError(f"Expected kv_cache.ndim to be 3 or 4, got {kv_cache.ndim}")
    if kv_cache.ndim == 3:
        # [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
        kv_cache = kv_cache.unsqueeze(1)

@yzh119 yzh119 merged commit c4a3172 into main Jan 13, 2026
9 checks passed
@yzh119 yzh119 deleted the claude/issue-2258-20260112-0227 branch January 13, 2026 07:05


Development

Successfully merging this pull request may close these issues.

Incorrect trtllm_batch_decode_with_kv_cache_mla API doc string

2 participants