Skip to content

Commit 89373b8

Browse files
authored
fix xpu woq linear dtype (#1506)
* fix xpu dtypoe Signed-off-by: jiqing-feng <[email protected]> * fix nf4 dtype Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]>
1 parent fdbbfb6 commit 89373b8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

bitsandbytes/backends/cpu_xpu_common.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ def gemm_4bit_impl(
552552
GEMM output tensor.
553553
"""
554554
if getattr(state, "ipex", False):
555+
# compute_dtype: 1 indicates fp16, 2 indicates bf16
556+
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
555557
output = torch.ops.torch_ipex.woq_linear(
556558
A,
557559
B,
@@ -562,7 +564,7 @@ def gemm_4bit_impl(
562564
None,
563565
None,
564566
state.blocksize,
565-
ipex_cpu.quantization.WoqLowpMode.BF16,
567+
compute_dtype,
566568
1,
567569
state.compensation,
568570
)

0 commit comments

Comments
 (0)