
Commit bd2b033

Added the device version checks (#2307)
## 📌 Description

Added a device compute-capability check so the cuDNN FP8 prefill test only runs on supported hardware.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Tests**
  * Improved test suite with a refined hardware check: an FP8-related test now requires a specific GPU compute capability so it only runs on compatible hardware, reducing false skips and improving reliability.
1 parent 866773b commit bd2b033

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

tests/attention/test_cudnn_prefill.py

```diff
@@ -4,6 +4,8 @@
 import flashinfer
 import cudnn
 
+from flashinfer.utils import get_compute_capability
+
 
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("s_qo", [8, 17, 700])
@@ -214,6 +216,13 @@ def test_cudnn_prefill_fp8(
     torch.manual_seed(seed)
     device = "cuda:0"
 
+    major, _ = get_compute_capability(torch.device(device))
+
+    if major != 10:
+        pytest.skip(
+            f"cuDNN FP8 prefill is not supported on compute capability {major}, skipping test"
+        )
+
     actual_seq_lens_q = torch.randint(
         1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
     )
```
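For context, the added guard follows the common pytest pattern of probing the GPU before a hardware-specific test body runs. The sketch below reproduces that pattern outside the FlashInfer test suite, using the stock PyTorch call `torch.cuda.get_device_capability` in place of `flashinfer.utils.get_compute_capability`; the test name and wording are illustrative, not part of this commit.

```python
# Minimal sketch of the compute-capability gate shown in the diff above.
# torch.cuda.get_device_capability returns a (major, minor) tuple, e.g.
# (10, 0) on compute-capability-10.x GPUs, analogous to what the
# flashinfer.utils.get_compute_capability helper returns in the test.
import pytest
import torch


def test_fp8_path_runs_only_on_cc10():  # hypothetical test name
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")

    major, _minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
    if major != 10:
        # Same shape as the check added to test_cudnn_prefill_fp8:
        # skip rather than fail on GPUs without this FP8 support.
        pytest.skip(f"requires compute capability 10.x, got {major}.x")

    # ... the FP8-specific body of the test would follow here ...
```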
