Skip to content

Commit 78b6984

Browse files
committed
test(gpu): add regression test for family wildcard matching
This adds a unit test to verify the `device_matches_constraint` logic. It ensures that families with a trailing 'X' (e.g., "gfx110X") are correctly recognized as wildcards that match specific models (e.g., "gfx1103"). Co-authored-by: opencode:Gemma-4-12B-it-GGUF
1 parent 50d9fdd commit 78b6984

1 file changed

Lines changed: 62 additions & 0 deletions

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python3
2+
"""
3+
CPU-runnable unit tests for device family matching logic (device_matches_constraint).
4+
5+
These tests replicate the C++ logic for matching a device family against a set of
6+
allowed families, including support for wildcard 'X' at the end of a family name.
7+
"""
8+
9+
import unittest
10+
11+
# ---------------------------------------------------------------------------
12+
# Python replica of system_info.cpp::device_matches_constraint()
13+
# ---------------------------------------------------------------------------
14+
15+
def device_matches_constraint(device_family: str, allowed_families: set) -> bool:
16+
if not allowed_families:
17+
return True # Empty = all families allowed
18+
19+
if device_family in allowed_families:
20+
return True
21+
22+
for af in allowed_families:
23+
if len(af) > 1 and af.endswith('X'):
24+
prefix = af[:-1]
25+
if device_family.startswith(prefix):
26+
return True
27+
28+
return False
29+
30+
# ---------------------------------------------------------------------------
31+
# Tests
32+
# ---------------------------------------------------------------------------
33+
34+
class TestDeviceFamilyMatching(unittest.TestCase):
35+
def test_wildcard_matching(self):
36+
# gfx1103 should match gfx110X
37+
self.assertTrue(device_matches_constraint("gfx1103", {"gfx110X"}))
38+
# gfx1201 should match gfx120X
39+
self.assertTrue(device_matches_constraint("gfx1201", {"gfx120X"}))
40+
41+
def test_exact_matching(self):
42+
# gfx1151 should match gfx1151
43+
self.assertTrue(device_matches_constraint("gfx1151", {"gfx1151"}))
44+
# gfx1152 should match gfx1152
45+
self.assertTrue(device_matches_constraint("gfx1152", {"gfx1152"}))
46+
47+
def test_non_matching(self):
48+
# gfx1151 should NOT match gfx110X
49+
self.assertFalse(device_matches_constraint("gfx1151", {"gfx110X"}))
50+
51+
def test_empty_allowed_families(self):
52+
# Empty allowed_families should match everything
53+
self.assertTrue(device_matches_constraint("gfx1151", set()))
54+
55+
def test_multiple_allowed_families(self):
56+
# Should match if any in the set match
57+
self.assertTrue(device_matches_constraint("gfx1103", {"gfx103X", "gfx110X"}))
58+
self.assertTrue(device_matches_constraint("gfx1201", {"gfx110X", "gfx120X"}))
59+
self.assertFalse(device_matches_constraint("gfx1151", {"gfx103X", "gfx110X"}))
60+
61+
if __name__ == "__main__":
62+
unittest.main()

0 commit comments

Comments
 (0)