From 243f65dadd4d19bf48926a8fa9c989c73b623216 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Feb 2025 15:28:51 +0000
Subject: [PATCH] Enable XPU int matmul

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/cpu_xpu_common.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 8c1f30f10..ede1be69c 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -194,8 +194,10 @@ def int8_linear_matmul_impl(
     A_reshaped = A.reshape(m, k)

-    # torch._int_mm is available on CPU since torch 2.4
-    if _torch_version_prereq(2, 4) and A.device.type == "cpu":
+    # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6
+    if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or (
+        A.device.type == "xpu" and _torch_version_prereq(2, 6)
+    ):
         C = torch._int_mm(A_reshaped, B.T).to(dtype)
     else:
         C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
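
For context, below is a minimal standalone sketch of the dispatch this patch implements. The `_torch_version_prereq` shown here is a hypothetical stand-in for the helper defined elsewhere in cpu_xpu_common.py, and the wrapper name `int8_mm` is invented for this illustration; only `torch._int_mm` (a private PyTorch int8-by-int8-to-int32 matmul kernel, CPU-capable since torch 2.4 and XPU-capable, i.e. on Intel GPUs, since torch 2.6) comes from the patch itself.

import torch

def _torch_version_prereq(major: int, minor: int) -> bool:
    # Hypothetical stand-in: parse version strings like "2.6.0" or
    # "2.6.0+cpu" and compare the leading (major, minor) pair.
    parts = torch.__version__.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)

def int8_mm(A: torch.Tensor, B: torch.Tensor, dtype=torch.int32) -> torch.Tensor:
    # A: [m, k] int8, B: [n, k] int8, same layout convention as the patch.
    # Use the fused torch._int_mm kernel where the backend supports it:
    # CPU since torch 2.4, XPU since torch 2.6.
    if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or (
        A.device.type == "xpu" and _torch_version_prereq(2, 6)
    ):
        return torch._int_mm(A, B.T).to(dtype)
    # Fallback for older torch or other devices: promote to float32,
    # do a regular matmul, and cast back. Slower, kept for compatibility.
    return torch.matmul(A.float(), B.t().float()).to(dtype)

# Usage:
A = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
C = int8_mm(A, B)  # shape [32, 32], dtype int32

Note the shape of the change: rather than bolting another clause onto the old `_torch_version_prereq(2, 4) and device == "cpu"` test, the patch groups each device type with its own minimum torch version, so further backends gaining `_int_mm` support can be added as one more parenthesized clause. Any unsupported combination silently takes the float32 fallback path instead of raising.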