
Commit 2ffa367

Merge pull request #898 from jianan-gu/upstream_device_abstraction
Initial working draft of a device abstraction that enables different hardware backends.
2 parents fd9d072 + 9ff6c63 commit 2ffa367

7 files changed: +903 -675 lines changed

bitsandbytes/__init__.py (+7 -1)

@@ -12,9 +12,15 @@
     matmul_cublas,
     mm_cublas,
 )
+from .cextension import lib
 from .nn import modules
-from .optim import adam
 
+if lib and lib.compiled_with_cuda:
+    from .backends import register_backend
+    from .backends.cuda import CUDABackend
+    from .optim import adam
+
+    register_backend("cuda", CUDABackend())
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,

bitsandbytes/backends/__init__.py (+15 -0)

@@ -0,0 +1,15 @@
+from typing import Dict
+
+from bitsandbytes.backends.base import Backend
+
+backends: Dict[str, Backend] = {}
+
+
+def register_backend(backend_name: str, backend_instance: Backend):
+    backends[backend_name.lower()] = backend_instance
+
+
+def ensure_backend_is_available(device_type: str):
+    """Check if a backend is available for the given device type."""
+    if device_type.lower() not in backends:
+        raise NotImplementedError(f"Device backend for {device_type} is currently not supported.")

bitsandbytes/backends/base.py (+133 -0)

@@ -0,0 +1,133 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+from bitsandbytes.utils import QuantState
+
+
+class Backend(ABC):
+    """Base class for device backends that will implement their own 8-bit and 4-bit functions."""
+
+    @abstractmethod
+    def double_quant(
+        self,
+        A,
+        col_stats=None,
+        row_stats=None,
+        out_col=None,
+        out_row=None,
+        threshold=0.0,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def transform(
+        self,
+        A,
+        to_order,
+        from_order="row",
+        out=None,
+        transpose=False,
+        state=None,
+        ld=None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
+        raise NotImplementedError
+
+    @abstractmethod
+    def mm_dequant(
+        self,
+        A,
+        quant_state,
+        row_stats,
+        col_stats,
+        out=None,
+        new_row_stats=None,
+        new_col_stats=None,
+        bias=None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def extract_outliers(self, A, SA, idx):
+        raise NotImplementedError
+
+    @abstractmethod
+    def quantize_4bit(
+        self,
+        A: torch.Tensor,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+    ) -> Tuple[torch.Tensor, QuantState]:
+        """
+        Quantize tensor A in blocks of 4-bit values.
+
+        Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
+
+        Parameters
+        ----------
+        A : torch.Tensor
+            The input tensor.
+        absmax : torch.Tensor
+            The absmax values.
+        out : torch.Tensor
+            The output tensor.
+        blocksize : int
+            The blocksize used in quantization.
+        quant_type : str
+            The 4-bit quantization data type {fp4, nf4}.
+
+        Returns
+        -------
+        torch.Tensor:
+            Tensor with packed 4-bit values.
+        tuple(torch.Tensor, torch.Size, torch.dtype, int):
+            The quantization state to undo the quantization.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def dequantize_4bit(
+        self,
+        A: torch.Tensor,
+        quant_state: Optional[QuantState] = None,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize: int = 64,
+        quant_type="fp4",
+    ) -> torch.Tensor:
+        """
+        Dequantizes FP4 blockwise quantized values.
+
+        Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
+
+        Parameters
+        ----------
+        A : torch.Tensor
+            The input tensor (packed 4-bit values).
+        quant_state : QuantState
+            Object with quantization stats, including absmax values, original tensor shape and original dtype.
+        absmax : torch.Tensor
+            The absmax values.
+        out : torch.Tensor
+            Dequantized output tensor.
+        blocksize : int
+            The blocksize used in quantization.
+        quant_type : str
+            The 4-bit quantization data type {fp4, nf4}.
+
+        Returns
+        -------
+        torch.Tensor:
+            Dequantized tensor.
+        """
+        raise NotImplementedError
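To support a new device, a backend subclasses Backend and overrides every abstract method; the class is only instantiable once all of them are implemented. A minimal, hypothetical skeleton (NpuBackend and the "npu" key are illustrative names, not part of this commit; a real implementation would call device-specific kernels in each body):

import torch

from bitsandbytes.backends import register_backend
from bitsandbytes.backends.base import Backend


class NpuBackend(Backend):
    # Stub bodies mark where device-specific kernels would be invoked.
    def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
        raise NotImplementedError("int8 row-/col-wise quantization kernel goes here")

    def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
        raise NotImplementedError("memory-layout transform kernel goes here")

    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
        raise NotImplementedError("int8 matmul kernel goes here")

    def mm_dequant(self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
        raise NotImplementedError("int32-to-fp16 dequantization kernel goes here")

    def extract_outliers(self, A, SA, idx):
        raise NotImplementedError("outlier column extraction goes here")

    def quantize_4bit(self, A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8):
        raise NotImplementedError("blockwise 4-bit quantization kernel goes here")

    def dequantize_4bit(self, A, quant_state=None, absmax=None, out=None, blocksize=64, quant_type="fp4"):
        raise NotImplementedError("blockwise 4-bit dequantization kernel goes here")


# Instantiation succeeds because all abstract methods are overridden.
register_backend("npu", NpuBackend())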
