
Fix trace-bmm-fp8 test: B should be K-major for subword types#3184

Open
xrq-phys wants to merge 1 commit into flashinfer-ai:main from xrq-phys:fix/trace-bmm-fp8

Conversation


@xrq-phys xrq-phys commented Apr 26, 2026

📌 Description

Closes #3188

Issue: An upstream change introduced a failing CI test case: tests/trace/test_reference_correctness.py::test_bmm_fp8_reference_correctness

Cause: flashinfer.bmm_bf16 and flashinfer.bmm_fp8 (and any other sub-32-bit dtype) expect K-major B inputs. The cutlass backend checks for this, but the default fp8 backend does not, silently producing wrong results.
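The fix can be sketched in plain PyTorch (shapes here are hypothetical; flashinfer itself is not needed to see the stride effect). The transpose→contiguous→transpose round-trip leaves the logical tensor unchanged but rewrites the memory layout so the innermost stride walks along K:

```python
import torch

# Hypothetical shapes for illustration only.
B, M, K, N = 2, 4, 8, 16
b = torch.randn(B, K, N)  # row-major: strides (K*N, N, 1), i.e. N-major

# Round-trip through contiguous() to obtain a K-major (column-major) layout:
# logical values are unchanged, but the K dimension now has stride 1.
b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)

assert torch.equal(b, b_kmaj)            # same logical tensor
assert b_kmaj.stride() == (K * N, 1, K)  # K-stride == 1 -> K-major
```

Because only the strides change, the reference computation can keep consuming the original `b` while the kernel receives `b_kmaj`.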

🔍 Related Issues

Current CI runs

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Improved correctness validation for BF16 and FP8 matrix operations by adjusting test inputs to align with kernel expectations.
    • Kept original reference comparisons unchanged to ensure consistent validation.
    • Preserved existing behavior for skipping when kernels are unavailable and retained the same closeness thresholds.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 26, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 25e3b61a-6e00-4f0b-9a76-ae2be0fed923

📥 Commits

Reviewing files that changed from the base of the PR and between 162eca5 and cebc7a3.

📒 Files selected for processing (1)
  • tests/trace/test_reference_correctness.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/trace/test_reference_correctness.py

📝 Walkthrough

Walkthrough

The BMM correctness tests now create K-major versions of the batched b operands (b_kmaj, b_fp8_kmaj) via transpose→contiguous→transpose before invoking flashinfer.bmm_bf16 / flashinfer.bmm_fp8; reference trace comparisons still use the original b inputs. No public APIs changed.

Changes

Cohort / File(s) Summary
Test Preprocessing
tests/trace/test_reference_correctness.py
Create layout-adjusted tensors for b (b_kmaj, b_fp8_kmaj) using transpose→contiguous→transpose; call flashinfer.bmm_bf16 / flashinfer.bmm_fp8 with these adjusted tensors while keeping reference traces computed from original b/b_fp8. Exception/skip logic and closeness checks unchanged.
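A minimal sketch of the test-side adjustment described above, using torch.bmm as a stand-in for both paths (flashinfer.bmm_bf16 and the real test shapes are assumptions here):

```python
import torch

B, M, K, N = 2, 4, 8, 16  # hypothetical shapes
a = torch.randn(B, M, K)
b = torch.randn(B, K, N)

# K-major copy goes to the kernel; the reference still reads the original b.
b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)

ref = torch.bmm(a, b)       # reference path: original layout
out = torch.bmm(a, b_kmaj)  # kernel-path stand-in, e.g. flashinfer.bmm_bf16(a, b_kmaj)

# Logical values are identical, so the existing closeness thresholds still apply.
assert torch.allclose(ref, out)
```

Since `b_kmaj` holds the same values as `b`, keeping the reference on the original tensor is correct; only the kernel call needs the layout-adjusted copy.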

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

op: gemm

Suggested reviewers

  • saltyminty
  • bkryu
  • aleozlx
  • sricketts
  • yongwww
  • yzh119
  • cyx-6

Poem

🐰 I hop through tensors, neat and spry,
I flip and hold them, nice and dry,
I make them solid, row by row,
Kernels dance where numbers flow,
A snack of bytes — then off I fly.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
  • Title check — ✅ Passed: Title clearly and concisely summarizes the main change: fixing the K-major layout requirement for the B tensor in BMM tests for subword types (FP8, BF16).
  • Description check — ✅ Passed: Description includes issue closure, root-cause analysis, and checked pre-commit/test verification items, meeting template requirements.
  • Linked Issues check — ✅ Passed: Code changes enforce K-major layout for the B tensor in BMM tests, directly addressing issue #3188's requirement to fix the failing test by ensuring proper tensor layout for subword dtypes.
  • Out of Scope Changes check — ✅ Passed: All modifications are confined to test input preprocessing that enforces K-major layout, staying within the scope of fixing the BMM correctness test.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, which exceeds the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@xrq-phys
Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !605 has been created, and the CI pipeline #49554921 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/trace/test_reference_correctness.py (1)

2170-2172: Correct K-major preprocessing for bmm_bf16.

The transpose(1,2).contiguous().transpose(1,2) idiom produces a (B, K, N) view with stride (N*K, 1, K) (K-stride = 1), which matches the column-major layout bmm_bf16 documents for B. Logical values are preserved, so passing original b to the reference remains correct.

Optional: a one-line comment would help future readers understand why the seemingly-noop pattern is necessary — i.e., sub-32-bit BMMs require K-major B, but only the cutlass backend enforces it.

📝 Optional clarifying comment
     a = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
     b = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")
+    # bmm_bf16 requires B in K-major (column-major) layout; round-trip through
+    # contiguous() to get strides (N*K, 1, K) without changing logical values.
     b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_reference_correctness.py` around lines 2170 - 2172, The
pre-processing using b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2) is
a no-op for logical values but was used to get K-major strides required only by
the cutlass backend; update the test to pass the original b (not b_kmaj) to the
reference path and keep the cutlass call as-is (api = flashinfer.bmm_bf16(a,
b_kmaj, backend="cutlass")), and add a one-line comment near b_kmaj explaining
that the transpose/contiguous/transpose is only to enforce K-major memory layout
for cutlass and that logical values are unchanged so the reference uses the
original b.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c60b9edf-6ad1-467f-958f-dcc573d8be88

📥 Commits

Reviewing files that changed from the base of the PR and between 5e1318c and 1c1a4bd.

📒 Files selected for processing (1)
  • tests/trace/test_reference_correctness.py

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request modifies the bmm_bf16 and bmm_fp8 reference correctness tests to utilize K-major layout tensors when calling the FlashInfer API. Feedback suggests also using these K-major tensors in the reference implementation calls to maintain consistency and ensure that both the kernel and the reference are tested against the same memory representation.

Comment thread tests/trace/test_reference_correctness.py
Comment thread tests/trace/test_reference_correctness.py
@xrq-phys
Author

@saltyminty could you approve / merge this PR?

#2711 SageAttn (presumably other CI runs also) is blocked by this failure.

CC @YangXu1990uiuc for vis.

Thanks!

@xrq-phys
Author

/bot help

@flashinfer-bot
Collaborator

FlashInfer CI Bot

Available Commands:

  • /bot run - Mirror this PR to GitLab and run CI pipeline
  • /bot status - Check current pipeline status
  • /bot stop - Cancel running pipeline
  • /bot help - Show this help message

Authorization:

Only whitelisted users can trigger CI. Contact a maintainer for access.

How It Works:

  1. Authorized user comments /bot run on a PR
  2. Bot mirrors PR to internal GitLab
  3. GitLab CI pipeline runs automatically
  4. Results are posted back to this PR

Note: Any whitelisted user can trigger CI for any PR, not just their own.

@saltyminty
Collaborator

CI looks good (failures are node allocation timeouts)

@saltyminty saltyminty self-assigned this Apr 27, 2026
@saltyminty saltyminty enabled auto-merge (squash) April 27, 2026 22:56
@xrq-phys
Author

@saltyminty can we skip CI here? Or do we have to wait until nodes are back?

@saltyminty saltyminty disabled auto-merge April 28, 2026 05:32
@saltyminty
Collaborator

We can skip internal CI since this change should be safe, but need the pre-merge checks to pass before the merge button appears

Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

[Bug] test_bmm_fp8_reference_correctness fails with cos_sim=-0.0019 < 0.99
