Fix trace-bmm-fp8 test: B should be K-major for subword types #3184
xrq-phys wants to merge 1 commit into flashinfer-ai:main
Conversation
/bot run
🧹 Nitpick comments (1)
tests/trace/test_reference_correctness.py (1)
2170-2172: Correct K-major preprocessing for `bmm_bf16`. The `transpose(1, 2).contiguous().transpose(1, 2)` idiom produces a `(B, K, N)` view with strides `(N*K, 1, K)` (K-stride = 1), which matches the column-major layout `bmm_bf16` documents for `B`. Logical values are preserved, so passing the original `b` to the reference remains correct. Optional: a one-line comment would help future readers understand why the seemingly-noop pattern is necessary, i.e., sub-32-bit BMMs require K-major `B`, but only the `cutlass` backend enforces it.

📝 Optional clarifying comment

```diff
  a = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
  b = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")
+ # bmm_bf16 requires B in K-major (column-major) layout; round-trip through
+ # contiguous() to get strides (N*K, 1, K) without changing logical values.
  b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_reference_correctness.py` around lines 2170-2172: the pre-processing `b_kmaj = b.transpose(1, 2).contiguous().transpose(1, 2)` is a no-op for logical values but is used to get the K-major strides required only by the cutlass backend; keep the cutlass call as-is (`api = flashinfer.bmm_fp8(a, b_kmaj, backend="cutlass")` style), pass the original `b` (not `b_kmaj`) to the reference path, and add a one-line comment near `b_kmaj` explaining that the transpose/contiguous/transpose only enforces K-major memory layout for cutlass and that logical values are unchanged, so the reference uses the original `b`.
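The stride behavior this comment describes can be reproduced with a small NumPy sketch (NumPy stands in for PyTorch here: `np.ascontiguousarray` plays the role of `.contiguous()`, and NumPy reports strides in bytes rather than elements; shapes are illustrative):

```python
import numpy as np

# Stand-in for the test's (B, K, N) operand; sizes are illustrative only.
Bsz, K, N = 2, 4, 8
b = np.arange(Bsz * K * N, dtype=np.float32).reshape(Bsz, K, N)

# Round-trip through a contiguous transpose: logical values are unchanged,
# but the result becomes K-major (unit stride along the K axis).
b_kmaj = np.ascontiguousarray(b.transpose(0, 2, 1)).transpose(0, 2, 1)

assert (b == b_kmaj).all()  # same logical values
elem = b_kmaj.itemsize
# Element strides are (N*K, 1, K), matching the layout described above.
assert tuple(s // elem for s in b_kmaj.strides) == (N * K, 1, K)
```

The same round-trip in PyTorch (`b.transpose(1, 2).contiguous().transpose(1, 2)`) yields the element strides `(N*K, 1, K)` directly.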
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c60b9edf-6ad1-467f-958f-dcc573d8be88
📒 Files selected for processing (1)
tests/trace/test_reference_correctness.py
Code Review
This pull request modifies the `bmm_bf16` and `bmm_fp8` reference correctness tests to use K-major layout tensors when calling the FlashInfer API. Feedback suggests also using these K-major tensors in the reference implementation calls, so that both the kernel and the reference are tested against the same memory representation.
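The layout requirement at issue can be written out as a stride predicate. This is a minimal sketch of the kind of K-major check described above; `is_k_major` is a hypothetical helper, not the actual check in the cutlass backend:

```python
def is_k_major(shape, strides):
    """True if a (B, K, N) operand has unit stride along K.

    K-major (column-major per batch) means element strides of
    (N*K, 1, K) for a (B, K, N) tensor.
    """
    _, k, _ = shape
    _, stride_k, stride_n = strides
    return stride_k == 1 and stride_n == k

# A freshly allocated row-major (B, K, N) = (2, 4, 8) tensor
# has element strides (K*N, N, 1) = (32, 8, 1): not K-major.
assert not is_k_major((2, 4, 8), (32, 8, 1))

# After transpose(1, 2).contiguous().transpose(1, 2) the strides
# become (N*K, 1, K) = (32, 1, 4): K-major, values unchanged.
assert is_k_major((2, 4, 8), (32, 1, 4))
```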
|
@saltyminty could you approve / merge this PR? #2711 (SageAttn) and presumably other CI runs are blocked by this failure. CC @YangXu1990uiuc for visibility. Thanks!
|
/bot help
FlashInfer CI Bot
Available Commands:
Authorization: Only whitelisted users can trigger CI. Contact a maintainer for access.
How It Works:
Note: Any whitelisted user can trigger CI for any PR, not just their own.
|
CI looks good (failures are node allocation timeouts)
Force-pushed 1c1a4bd to 162eca5 (Compare)
|
@saltyminty can we skip CI here? Or do we have to wait until nodes are back?
|
We can skip internal CI since this change should be safe, but the pre-merge checks need to pass before the merge button appears.
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
Force-pushed 162eca5 to cebc7a3 (Compare)
📌 Description
Closes #3188
Issue: An upstream change has introduced a failing CI test case:
tests/trace/test_reference_correctness.py::test_bmm_fp8_reference_correctness
Cause: `flashinfer.bmm_bf16` and `flashinfer.bmm_fp8` (any sub-32-bit dtypes) expect K-major inputs. The `cutlass` backend checks for this, but the default fp8 backend doesn't, causing wrong results.
🔍 Related Issues
Current CI runs
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run pre-commit run --all-files and fixed any reported issues.
🧪 Tests
All tests are passing (unittest, etc.).
Reviewer Notes