Add batch invariant grouped gemm kernel for inference #3783

Open
santhnm2 wants to merge 3 commits into NVIDIA:main from santhnm2:batch_invariant_grouped_gemm

Conversation

@santhnm2 (Contributor) commented Mar 10, 2026

What does this PR do ?

Adds a batch invariant grouped gemm kernel for bf16 inference implemented in Triton.
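
For context, "batch invariant" means the kernel produces bitwise-identical results for a given row regardless of how many other rows share the batch. Standard GEMM libraries do not guarantee this, because they may switch tiling or reduction strategy (e.g. split-K) as the problem shape changes. A minimal illustration of the non-invariance such a kernel is designed to avoid (an illustration only, not code from this PR):

```python
import torch

# With ordinary kernels, the same logical row can yield different bits
# depending on batch size, because the library may pick a different
# tiling / reduction order for different problem shapes.
a = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

out_full = a @ b      # row 0 computed as part of a 4096-row batch
out_row = a[:1] @ b   # row 0 computed as a 1-row batch

print(torch.equal(out_full[:1], out_row))  # may be False with non-invariant kernels
```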

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or tag @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark the PR as ready once merge conflicts are resolved and CI is passing.
The Final Review may be declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>

copy-pr-bot bot commented Mar 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@santhnm2 santhnm2 marked this pull request as ready for review March 10, 2026 20:09
@santhnm2 santhnm2 requested review from a team as code owners March 10, 2026 20:09
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 10, 2026 20:09
@santhnm2 (Contributor Author) commented:

/claude review

@claude bot left a comment


Light review — code looks good overall. The kernel logic, stride handling for both trans_b modes, and the persistent-style tile scheduling are correct. The fallback stub for is_batch_invariant_mode_enabled() when the import fails is a clean pattern. Test coverage is thorough with correctness, batch invariance, determinism, and EP/TP integration tests.
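
As a rough illustration of what persistent-style tile scheduling means here (a sketch under assumed semantics, not the PR's actual code): a fixed pool of programs walks a deterministic tile list in strided order, and each tile reduces over K serially, so the summation order for any output element never depends on how many rows are in the batch.

```python
# Sketch of persistent-style tile scheduling (assumed, not the PR's code).
def schedule_tiles(batch_sizes, N, BLOCK_M, BLOCK_N, num_programs):
    """Enumerate (group, tile_m, tile_n) tiles in a fixed, deterministic order."""
    tiles = []
    tiles_n = (N + BLOCK_N - 1) // BLOCK_N
    for group, m in enumerate(batch_sizes):
        tiles_m = (m + BLOCK_M - 1) // BLOCK_M
        for tm in range(tiles_m):
            for tn in range(tiles_n):
                tiles.append((group, tm, tn))
    # Program p takes tiles p, p + num_programs, p + 2*num_programs, ...
    # Each tile's K reduction runs serially, so results stay batch invariant.
    return [tiles[p::num_programs] for p in range(num_programs)]
```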

One minor observation: HAVE_BATCH_INVARIANT is defined but never referenced — the dispatch relies entirely on the is_batch_invariant_mode_enabled() fallback stub. Unlike HAVE_FLASHINFER (which is used for assertions and guards elsewhere), HAVE_BATCH_INVARIANT is dead code. Consider either removing it or adding a guard (e.g., an assertion in _triton_batch_invariant_forward) for consistency with the HAVE_FLASHINFER pattern.
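
A minimal sketch of the suggested guard (the import path and surrounding details here are assumptions for illustration, not the PR's actual identifiers):

```python
try:
    # Hypothetical import path; the real module may differ.
    from megatron.core.inference.batch_invariant import is_batch_invariant_mode_enabled

    HAVE_BATCH_INVARIANT = True
except ImportError:
    HAVE_BATCH_INVARIANT = False

    def is_batch_invariant_mode_enabled() -> bool:
        # Fallback stub: without the module, the mode is never enabled.
        return False


def _triton_batch_invariant_forward(x, weight, batch_sizes):
    # Suggested guard so HAVE_BATCH_INVARIANT is no longer dead code,
    # mirroring the HAVE_FLASHINFER assertion pattern used elsewhere.
    assert HAVE_BATCH_INVARIANT, "Batch-invariant grouped GEMM is unavailable"
    ...
```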

LGTM otherwise.


```python
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64  # fixed tile sizes, chosen independently of batch size

bs_cpu = batch_sizes.cpu()  # blocking device-to-host copy of the per-group batch sizes
```
Contributor commented:

This will not work with cuda-graphs. Should we disable them in transformer config?

santhnm2 (Contributor Author) replied:

Yeah, will add an assertion that it doesn't work with CUDA graphs.
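
A sketch of what such an assertion could look like (the config field names here are assumptions, not the merged code):

```python
# Hypothetical validation in the transformer config (field names assumed):
if config.moe_use_batch_invariant_grouped_gemm:
    assert not config.enable_cuda_graph, (
        "Batch-invariant grouped GEMM reads batch_sizes on the host "
        "(batch_sizes.cpu()), which cannot be captured in a CUDA graph."
    )
```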
