[fmhav2] skip fp8 tests and add warning #3050
Conversation
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough: Adds a runtime warning for FP8 (e4m3) inputs in the prefill path and updates the test suite to disable FP8 parameterizations.
Sequence Diagram(s): omitted — changes are a warning plus test skip/parameterization updates and do not introduce a new multi-component sequential flow.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Code Review
This pull request introduces a warning for FP8 (e4m3) kernels known to hang on SM90 architectures and updates the test suite to skip FP8-related test cases. Feedback was provided regarding the placement of the SM90 hang warning, suggesting it should be moved after the SM120 compatibility check to avoid misleading users on Blackwell hardware.
```python
    query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
)
if is_e4m3:
    logging.warning("The FP8 (e4m3) kernels are currently known to hang on SM90.")
```
The warning about FP8 kernels hanging on SM90 is issued before the check for SM120 (Blackwell) support. If a user is on an SM120 device, they will see this warning before receiving a ValueError stating that FP8 is not yet supported on their architecture. This is confusing as the hang is specific to SM90. It would be better to move the warning after the SM120 check so it only appears for SM90 users.
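To make the suggestion concrete, here is a minimal sketch of the reordered checks; the helper name, the capability probe, and the ValueError message are assumptions based on the review context, not the actual implementation:

```python
import logging

import torch

def _check_fp8_support(query: torch.Tensor) -> None:
    """Hypothetical helper: run the SM120 check before the SM90 warning."""
    is_e4m3 = (
        query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
    )
    if not is_e4m3:
        return

    major, minor = torch.cuda.get_device_capability(query.device)
    sm_version = major * 10 + minor

    if sm_version >= 120:
        # Blackwell users see only the unsupported-architecture error,
        # never the (irrelevant) SM90 hang warning.
        raise ValueError("FP8 (e4m3) is not yet supported on this architecture.")

    if sm_version == 90:
        # The hang is specific to SM90, so warn only there.
        logging.warning("The FP8 (e4m3) kernels are currently known to hang on SM90.")
```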
🧹 Nitpick comments (1)
tests/attention/test_fmha_v2_prefill.py (1)
785-788: Prefer explicit skipped params over commented-out cases. Using `pytest.param(..., marks=pytest.mark.skip(...))` keeps FP8 cases visible in test reports while still avoiding hangs.

Suggested refactor:
```diff
 @pytest.mark.parametrize(
     ("dtype", "o_dtype"),
     [
         (torch.float16, torch.float16),
         (torch.bfloat16, torch.bfloat16),
-        # todo(jimmyzho) skip all fp8 tests due to unmitigated hangs
-        # (torch.float8_e4m3fn, torch.float8_e4m3fn),
-        # (torch.float8_e4m3fn, torch.bfloat16),
-        # (torch.float8_e4m3fn, torch.float16),
+        pytest.param(
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+            marks=pytest.mark.skip(reason="Known FP8 e4m3 hangs (tracked)"),
+        ),
+        pytest.param(
+            torch.float8_e4m3fn,
+            torch.bfloat16,
+            marks=pytest.mark.skip(reason="Known FP8 e4m3 hangs (tracked)"),
+        ),
+        pytest.param(
+            torch.float8_e4m3fn,
+            torch.float16,
+            marks=pytest.mark.skip(reason="Known FP8 e4m3 hangs (tracked)"),
+        ),
     ],
 )
@@
 @pytest.mark.parametrize(
     ("dtype", "o_dtype"),
     [
         (torch.float16, torch.float16),
         (torch.bfloat16, torch.bfloat16),
-        # todo(jimmyzho) skip all fp8 tests due to unmitigated hangs
-        # (torch.float8_e4m3fn, torch.float16),
+        pytest.param(
+            torch.float8_e4m3fn,
+            torch.float16,
+            marks=pytest.mark.skip(reason="Known FP8 e4m3 hangs (tracked)"),
+        ),
     ],
 )
```

Also applies to: 861-863
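As a self-contained illustration of this pattern (the test name and skip reason below are placeholders, not from the PR):

```python
import pytest
import torch

@pytest.mark.parametrize(
    ("dtype", "o_dtype"),
    [
        (torch.float16, torch.float16),
        # Skipped params still appear in reports as "s" instead of
        # silently disappearing like commented-out entries.
        pytest.param(
            torch.float8_e4m3fn,
            torch.float8_e4m3fn,
            marks=pytest.mark.skip(reason="known FP8 e4m3 hang"),
        ),
    ],
)
def test_dtype_combo(dtype, o_dtype):
    # Placeholder body; the real tests exercise the FMHA v2 prefill kernels.
    assert torch.zeros(2, dtype=dtype).dtype == dtype
```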
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/attention/test_fmha_v2_prefill.py` around lines 785-788, replace the commented-out FP8 param tuples with explicit pytest skip params so they remain visible in test reports; for each commented tuple (including the similar block at lines ~861-863), add them back as `pytest.param(torch.float8_e4m3fn, torch.float8_e4m3fn, marks=pytest.mark.skip(reason="skipping FP8 tests due to hangs"))` (and likewise for the other two combinations) so the cases are present but skipped, preserving the original tuple values and adding a clear skip reason.
📒 Files selected for processing (2): flashinfer/prefill.py, tests/attention/test_fmha_v2_prefill.py
/bot run

/bot run

/bot run
📌 Description
This PR re-enables the FMHA v2 prefill test suite while properly isolating the known FP8 hang issue.
Previously, the entire `test_fmha_v2_prefill.py` test file was skipped via a blanket `pytestmark = pytest.mark.skip(...)`, which meant no FMHA v2 prefill tests ran at all — including non-FP8 configurations (float16, bfloat16) that work correctly.
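For illustration, the removed blanket skip corresponds to a module-level marker like this (the reason string is an assumption):

```python
import pytest

# A module-level pytestmark applies the mark to every test in the file,
# so the whole suite was skipped at collection time (now removed).
pytestmark = pytest.mark.skip(reason="FP8 kernels hang on SM90")
```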
Changes:
- Removed the file-level `pytestmark` skip — non-FP8 tests (float16/bfloat16) now run again in CI.
- Commented out FP8 dtype parametrize entries (`float8_e4m3fn`) in all test functions instead of skipping them at runtime. This avoids test collection overhead for known-broken configurations and makes it clear which combinations are disabled.
- Removed now-redundant runtime skips — the per-case `pytest.skip()` calls for FP8 sliding window bugs, FP8→FP8 output hangs, and sliding window hangs are no longer needed since those dtype combinations are no longer parametrized.
- Added a `logging.warning()` in `trtllm_fmha_v2_prefill()` when FP8 e4m3 inputs are detected, alerting users that these kernels are known to hang on SM90 (a sketch follows this list).
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and relevant tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes

- Added a runtime warning when FP8 (e4m3) inputs are detected, since these kernels are known to hang on SM90.

Tests

- Re-enabled the FMHA v2 prefill test suite; known-broken FP8 (e4m3) cases are now disabled at parametrization rather than via a file-level skip.