Skip to content

Commit 5ea4afe

Browse files
CPU: workaround avx512 4bit dequantize accuracy issue for large blocksize (#1828)
1 parent 4d19869 commit 5ea4afe

File tree

1 file changed

+9
-1
lines changed
  • bitsandbytes/backends/cpu

1 file changed

+9
-1
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
logger = logging.getLogger(__name__)
1414

15+
_has_avx512 = torch.backends.cpu.get_cpu_capability() == "AVX512"
16+
1517
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
1618
# However, we can overflow if we use this without AVX512_VNNI support.
1719
# This is fixed in torch 2.6+, so we set this as the minimum to be safe.
@@ -134,8 +136,14 @@ def _(
134136
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
135137
)
136138

139+
# Fall back, as the AVX512 implementation has accuracy issues with fp16/fp32 and blocksize >= 2048
140+
# Note: this is not a common use case.
141+
avx512_fallback = _has_avx512 and blocksize >= 2048 and dtype != torch.bfloat16
142+
137143
# Odd shape is not supported by this kernel; fallback to generic implementation
138-
if shape[-1] % 2 != 0:
144+
shape_fallback = shape[-1] % 2 != 0
145+
146+
if avx512_fallback or shape_fallback:
139147
from ..default.ops import _dequantize_4bit_impl
140148

141149
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)

0 commit comments

Comments
 (0)