|
13 | 13 | import torch |
14 | 14 | from parameterized import parameterized |
15 | 15 |
|
16 | | -from torchao.utils import is_sm_at_least_90 |
| 16 | +from torchao.utils import is_sm_at_least_90, torch_version_at_least |
17 | 17 |
|
18 | 18 | logging.basicConfig(level=logging.INFO) |
19 | 19 |
|
@@ -96,5 +96,88 @@ def test_int_scaled_mm(self, device, dtype): |
96 | 96 | torch.testing.assert_allclose(out32_1, out32_2) |
97 | 97 |
|
98 | 98 |
|
class TestIntScaledMatmulCPUPaths(unittest.TestCase):
    """
    Tests for the CPU-specific paths inside _int_scaled_matmul_cpu.

    Because the u8s8 VNNI branch is gated on runtime CPU feature detection,
    CI machines are unlikely to exercise it naturally. We monkeypatch the
    two helper functions so each branch can be tested on any machine.
    """

    def _make_inputs(self, m=64, k=32, n=16, dtype=torch.bfloat16):
        """Build random int8 operands ``a`` (m x k), ``b`` (k x n) and per-row scales."""
        a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
        b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
        scales = torch.randn(m, 1, dtype=dtype)
        return a, b, scales

    def _reference(self, a, b, scales):
        """Reference result: exact int matmul, cast to the scales' dtype, then scaled."""
        from torchao.kernel.intmm import safe_int_mm

        return safe_int_mm(a, b).to(scales.dtype) * scales

    def _run_with_patched_flags(self, amx: bool, vnni: bool) -> None:
        """
        Run _int_scaled_matmul_cpu with the CPU-feature helpers forced to the
        given values and compare against the reference implementation.

        The module-level detection functions are swapped for lambdas returning
        the requested booleans and restored in a ``finally`` block so a failing
        test cannot leak patched state into other tests.
        """
        import torchao.kernel.intmm as intmm_mod

        a, b, scales = self._make_inputs()
        expected = self._reference(a, b, scales)

        orig_amx = intmm_mod._cpu_is_amx_tile_supported
        orig_vnni = intmm_mod._cpu_is_vnni_supported
        try:
            intmm_mod._cpu_is_amx_tile_supported = lambda: amx
            intmm_mod._cpu_is_vnni_supported = lambda: vnni
            result = intmm_mod._int_scaled_matmul_cpu(a, b, scales)
        finally:
            intmm_mod._cpu_is_amx_tile_supported = orig_amx
            intmm_mod._cpu_is_vnni_supported = orig_vnni

        torch.testing.assert_close(result, expected)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_vnni_path_via_monkeypatch(self):
        """Force the u8s8 VNNI branch (no AMX, VNNI present -> compensation path)."""
        self._run_with_patched_flags(amx=False, vnni=True)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_amx_path_via_monkeypatch(self):
        """Force the s8s8 AMX branch (AMX present -> direct path, no compensation)."""
        self._run_with_patched_flags(amx=True, vnni=False)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_no_simd_path_via_monkeypatch(self):
        """Force the no-AMX/no-VNNI branch (s8s8 reference path)."""
        self._run_with_patched_flags(amx=False, vnni=False)
| 181 | + |
99 | 182 | if __name__ == "__main__": |
100 | 183 | unittest.main() |
0 commit comments