
Commit 0cd87aa

jiqing-feng authored
enable quant storage (#1563)
* enable quant storage

Signed-off-by: jiqing-feng <[email protected]>

* fix to numpy

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
1 parent 83c147d commit 0cd87aa

4 files changed (+16 −5 lines changed)

bitsandbytes/backends/cpu.py (+1 −2)

@@ -137,8 +137,7 @@ def quantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_cpu([A, absmax, out])
-        assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage"
-        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
+        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage)

     def dequantize_4bit(
         self,
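With the assertion removed, the CPU backend simply forwards quant_storage to quantize_4bit_impl. A minimal usage sketch, not from the commit: it assumes the multi-backend build dispatches CPU tensors to backends/cpu.py and that bitsandbytes.functional.quantize_4bit exposes the quant_storage keyword; torch.float16 is an illustrative storage choice.

```python
# Hedged sketch: quantize on CPU with a non-uint8 quant_storage, which this
# change enables (the backend previously asserted quant_storage == torch.uint8).
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024)  # CPU tensor

packed, state = F.quantize_4bit(
    A, blocksize=64, quant_type="nf4", quant_storage=torch.float16
)
print(packed.dtype)    # torch.float16 storage instead of torch.uint8

restored = F.dequantize_4bit(packed, state)
print(restored.shape)  # torch.Size([1024, 1024])
```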

bitsandbytes/backends/cpu_xpu_common.py (+13 −1)

@@ -296,6 +296,7 @@ def quantize_4bit_impl(
     blocksize=64,
     compress_statistics=False,
     quant_type="nf4",
+    quant_storage=torch.uint8,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of 4-bit values.
@@ -314,6 +315,8 @@ def quantize_4bit_impl(
         The blocksize used in quantization.
     quant_type : str
         The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
+    quant_storage: torch.dtype
+        We can use bytes to convert storage type.

     Returns
     -------
@@ -401,6 +404,10 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )

+    if quant_storage != torch.uint8:
+        bytes_value = out.cpu().numpy().tobytes()
+        out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device)
+
     return out.reshape(-1, 1), state

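The conversion reinterprets the packed uint8 buffer as the requested storage dtype without changing the underlying bytes. A standalone sketch of the same round trip (the variable names are illustrative; float16 is used because numpy can represent it):

```python
# Hedged sketch: reinterpret a packed uint8 buffer as a wider storage dtype by
# round-tripping through raw bytes, mirroring the approach in quantize_4bit_impl.
import torch

packed_u8 = torch.randint(0, 256, (1024,), dtype=torch.uint8)  # stand-in for packed nf4 data

# Same bytes, half as many elements (2 bytes per float16 element).
bytes_value = packed_u8.numpy().tobytes()
packed_f16 = torch.frombuffer(bytes_value, dtype=torch.float16)
assert packed_f16.numel() * packed_f16.element_size() == packed_u8.numel()

# Converting back to uint8 recovers the original byte stream exactly,
# which is what the dequantize path relies on before unpacking.
restored_u8 = torch.frombuffer(packed_f16.numpy().tobytes(), dtype=torch.uint8)
assert torch.equal(restored_u8, packed_u8)
```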

@@ -418,7 +425,8 @@ def dequant_8bit(A, offset, quant_state):
     return absmax


-@_maybe_torch_compile
+# Compile will fail in torch.frombuffer
+# @_maybe_torch_compile
 def dequantize_4bit_impl(
     A: Tensor,
     quant_state=None,
@@ -453,6 +461,10 @@ def dequantize_4bit_impl(
453461
"""
454462
transpose = True if A.shape[0] == 1 else False
455463
A = A.reshape(-1)
464+
device = A.device
465+
if A.dtype != torch.uint8:
466+
bytes_value = A.cpu().numpy().tobytes()
467+
A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)
456468

457469
if quant_state is None:
458470
assert absmax is not None and out is not None
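Because torch.frombuffer always produces a CPU tensor, the dequantize path records the incoming device first and moves the reinterpreted tensor back afterwards. A minimal helper capturing that pattern; the function name is hypothetical, not from the commit:

```python
# Hedged sketch: normalize packed storage back to uint8 before unpacking,
# preserving the original device. "as_uint8_storage" is an illustrative name.
import torch

def as_uint8_storage(A: torch.Tensor) -> torch.Tensor:
    device = A.device
    if A.dtype == torch.uint8:
        return A
    # torch.frombuffer reads host memory, so stage the bytes through the CPU;
    # this assumes the storage dtype has a numpy equivalent (e.g. float16).
    bytes_value = A.cpu().numpy().tobytes()
    return torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)

# Example: a float16-storage tensor is reinterpreted back to its raw bytes.
packed_f16 = torch.zeros(8, dtype=torch.float16)
print(as_uint8_storage(packed_f16).shape)  # torch.Size([16])
```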

bitsandbytes/backends/xpu.py (+1 −2)

@@ -138,8 +138,7 @@ def quantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_xpu([A, absmax, out])
-        assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
-        output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
+        output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage)
         return output

     def dequantize_4bit(

bitsandbytes/nn/modules.py (+1)

@@ -498,6 +498,7 @@ def set_ipex_linear(self, x: torch.Tensor):
         if (
             (x.device.type in ("cpu", "xpu"))
             and not getattr(self.weight.quant_state, "ipex", False)
+            and self.weight.data.dtype == torch.uint8
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
             and not self.training
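The extra condition keeps the IPEX fast path limited to weights whose packed storage is still uint8; weights quantized with a wider quant_storage fall back to the default path. A hedged sketch of the guard in isolation; the helper name and SimpleNamespace stand-ins are illustrative, not from the module:

```python
# Hedged sketch: the IPEX linear path is only taken when the packed weight
# is stored as uint8 (in addition to the other conditions in set_ipex_linear).
import torch
from types import SimpleNamespace

def should_use_ipex(weight, x: torch.Tensor, training: bool) -> bool:
    qs = weight.quant_state
    return (
        x.device.type in ("cpu", "xpu")
        and not getattr(qs, "ipex", False)
        and weight.data.dtype == torch.uint8          # new check from this commit
        and qs.shape[1] % qs.blocksize == 0
        and qs.quant_type == "nf4"
        and not training
    )

# Example: a weight reinterpreted as float16 storage is excluded from the IPEX path.
qs = SimpleNamespace(ipex=False, shape=(4096, 4096), blocksize=64, quant_type="nf4")
w_u8 = SimpleNamespace(data=torch.empty(8, dtype=torch.uint8), quant_state=qs)
w_f16 = SimpleNamespace(data=torch.empty(4, dtype=torch.float16), quant_state=qs)
x = torch.randn(2, 4096)
print(should_use_ipex(w_u8, x, training=False))   # True
print(should_use_ipex(w_f16, x, training=False))  # False
```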

0 commit comments
