We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2354bdd commit 249a3cdCopy full SHA for 249a3cd
bitsandbytes/backends/cpu_xpu_common.py
@@ -194,8 +194,10 @@ def int8_linear_matmul_impl(
194
195
A_reshaped = A.reshape(m, k)
196
197
- # torch._int_mm is available on CPU since torch 2.4
198
- if _torch_version_prereq(2, 4) and A.device.type == "cpu":
+ # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6
+ if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or (
199
+ A.device.type == "xpu" and _torch_version_prereq(2, 6)
200
+ ):
201
C = torch._int_mm(A_reshaped, B.T).to(dtype)
202
else:
203
C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
0 commit comments