Skip to content

Commit 550cf70

Browse files
committed
ipex dep
1 parent 1192982 commit 550cf70

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

bitsandbytes/backends/xpu.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
)
1616
try:
1717
import intel_extension_for_pytorch as ipex
18+
ipex_xpu = ipex if ipex._C._has_xpu() else None
1819
except BaseException:
1920
ipex_xpu = None
2021

2122
Tensor = torch.Tensor
2223

23-
str2optimizer8bit_blockwise = {
24-
"adam": (
25-
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
26-
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
27-
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
28-
),
29-
}
24+
25+
str2optimizer8bit_blockwise = {}
26+
if ipex_xpu is not None:
27+
str2optimizer8bit_blockwise = {
28+
"adam": (
29+
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
30+
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
31+
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
32+
),
33+
}
3034

3135

3236
def assert_on_xpu(tensors):
@@ -246,6 +250,8 @@ def optimizer_update_8bit_blockwise(
246250
skip_zeros=False,
247251
) -> None:
248252
optim_func = None
253+
if ipex_xpu is None:
254+
raise RuntimeError("Please install intel_extension_for_ipex for 8bit optimizer backend on XPU device.")
249255

250256
assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
251257

0 commit comments

Comments
 (0)