diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index d478c64cf..87ffc7360 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -194,8 +194,10 @@ def int8_linear_matmul_impl( A_reshaped = A.reshape(m, k) - # torch._int_mm is available on CPU since torch 2.4 - if _torch_version_prereq(2, 4) and A.device.type == "cpu": + # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6 + if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or ( + A.device.type == "xpu" and _torch_version_prereq(2, 6) + ): C = torch._int_mm(A_reshaped, B.T).to(dtype) else: C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)