
Commit bbee99a

Merge branch 'device-agnostic' of https://github.com/TimDettmers/bitsandbytes into device-agnostic
2 parents: b2b2a25 + 71ba853

File tree

4 files changed: +126 −38 lines


Diff for: bitsandbytes/_ops.py (+26 lines)

@@ -15,6 +15,32 @@
 register_fake = torch.library.impl_abstract
 register_kernel = torch.library.impl
 
+# Int8 mixed precision matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_mixed_scaled_mm",
+    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_mixed_scaled_mm")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    shapeC = (*CA.shape[:-1], CB.shape[0])
+
+    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)
+
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
+    subA = A.new_empty(outlier_cols, dtype=torch.int64)
+
+    return out, subA
+
 
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
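
Note: the register_fake entry above is what lets torch.compile trace the new op without launching a kernel; tracing calls the fake to learn output shapes and dtypes, and new_dynamic_size() marks the length of subA as data-dependent. A minimal usage sketch, not part of this commit (the wrapper name and tensor arguments are illustrative):

import torch
import bitsandbytes  # importing registers the bitsandbytes::* ops

@torch.compile(fullgraph=True)
def int8_forward(A, CA, CB, SCA, SCB, outlier_cols, bias):
    # Dispatches to the fake during tracing, and to the CUDA kernel at runtime.
    out, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
        A, CA, CB, SCA, SCB, outlier_cols, bias
    )
    return out, subA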

Diff for: bitsandbytes/autograd/_functions.py (+16 −25 lines)

@@ -210,37 +210,28 @@ def forward(
         # 2. Quantize B
         state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-        # Handle sparse decomposition. In some instances, we may have not found any
-        # outlier columns at all. In that case, we'll skip this part completely.
-        if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
+        # Handle sparse decomposition
+        if state.threshold > 0.0:
             state.idx = outlier_cols
 
-            # Zero out the outliers in the transposed 8bit inputs.
-            if CAt is not None:
-                CAt[:, state.idx] = 0
-
-            # Extract the input outliers in original precision
-            subA = A[:, state.idx].contiguous()
+            # Mixed Int8 Matmul + Dequant + Bias
+            output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
+                A,
+                CA,
+                state.CB,
+                SCA,
+                state.SCB,
+                outlier_cols,
+                bias,
+            )
 
-            # Extract the corresponding weights
-            if state.has_fp16_weights:
-                state.subB = B[:, state.idx].t()
-            else:
-                # To dequantize our weights associated with the input outliers,
-                # we want to divide by 127. It's however more performant to multiply
-                # by the reciprocal.
-                outliers = state.CB[:, state.idx]
-                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
         else:
+            # Int8 Matmul + Dequant + Bias
+            output = torch.ops.bitsandbytes.int8_scaled_mm.default(
+                CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
+            )
             subA = None
 
-        # 3. Int8 Matmul + Dequant + Bias
-        output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
-
-        # 4. Mixed-precision decomposition matmul
-        if subA is not None and state.subB is not None:
-            output = output.addmm(subA, state.subB)
-
         # 5. Save state
         ctx.state = state
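
Note: the numerics behind the fused call are the LLM.int8() decomposition: an int8 matmul covers the non-outlier columns, and the few outlier columns are multiplied in the original precision and added back. A self-contained sketch under simplifying assumptions (synthetic data; the real quant kernel also excludes outliers when computing the row scales):

import torch

A = torch.randn(4, 8)                # activations
B = torch.randn(16, 8)               # weights, (out_features, in_features)
outlier_cols = torch.tensor([2, 5])  # columns with large-magnitude activations

# Row-wise absmax quantization to int8 (the int8_vectorwise_quant step).
SCA = A.abs().amax(dim=1, keepdim=True)
SCB = B.abs().amax(dim=1, keepdim=True)
CA = torch.round(A / SCA * 127).to(torch.int8)
CB = torch.round(B / SCB * 127).to(torch.int8)
CA[:, outlier_cols] = 0  # outliers are excluded from the int8 path

# Int8 matmul + dequant (the int8_scaled_mm step).
output = (CA.float() @ CB.float().t()) * (SCA * SCB.t()) / (127 * 127)

# Outlier path: inputs kept in original precision, weight columns dequantized.
subA = A[:, outlier_cols]
subB = (CB[:, outlier_cols].float() * SCB / 127).t()
output = output.addmm(subA, subB)  # ~= A @ B.t()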

Diff for: bitsandbytes/backends/cuda/ops.py (+42 lines)

@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A

@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):
 
     if outliers.any():
         outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+    else:
+        # Needed for torch.compile support.
+        outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
 
     with _cuda_device_of(A):
         lib.cint8_vector_quant(
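
Note: both else branches added here serve the same purpose: under torch.compile, each branch should yield a tensor rather than None so the traced graph sees one consistent output type. The pattern in isolation (the function name is illustrative, not from the commit):

import torch

def find_outlier_cols(A: torch.Tensor, threshold: float) -> torch.Tensor:
    outliers = A.abs() >= threshold
    if outliers.any():
        return torch.argwhere(outliers.any(dim=0)).view(-1)
    # An empty int64 tensor, not None, keeps the return type stable.
    return torch.empty(0, device=A.device, dtype=torch.int64)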

Diff for: bitsandbytes/cuda_specs.py (+42 −13 lines)

@@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]:
 
 
 @lru_cache(None)
-def get_cuda_version_tuple() -> tuple[int, int]:
-    if torch.version.cuda:
-        return tuple(map(int, torch.version.cuda.split(".")[0:2]))
-    elif torch.version.hip:
-        return tuple(map(int, torch.version.hip.split(".")[0:2]))
+def get_cuda_version_tuple() -> Optional[tuple[int, int]]:
+    """Get CUDA/HIP version as a tuple of (major, minor)."""
+    try:
+        if torch.version.cuda:
+            version_str = torch.version.cuda
+        elif torch.version.hip:
+            version_str = torch.version.hip
+        else:
+            return None
 
-    return None
+        parts = version_str.split(".")
+        if len(parts) >= 2:
+            return tuple(map(int, parts[:2]))
+        return None
+    except (AttributeError, ValueError, IndexError):
+        return None
 
 
-def get_cuda_version_string() -> str:
-    major, minor = get_cuda_version_tuple()
+def get_cuda_version_string() -> Optional[str]:
+    """Get CUDA/HIP version as a string."""
+    version_tuple = get_cuda_version_tuple()
+    if version_tuple is None:
+        return None
+    major, minor = version_tuple
     return f"{major * 10 + minor}"
 
 
 def get_cuda_specs() -> Optional[CUDASpecs]:
+    """Get CUDA/HIP specifications."""
     if not torch.cuda.is_available():
        return None
 
-    return CUDASpecs(
-        highest_compute_capability=(get_compute_capabilities()[-1]),
-        cuda_version_string=(get_cuda_version_string()),
-        cuda_version_tuple=get_cuda_version_tuple(),
-    )
+    try:
+        compute_capabilities = get_compute_capabilities()
+        if not compute_capabilities:
+            return None
+
+        version_tuple = get_cuda_version_tuple()
+        if version_tuple is None:
+            return None
+
+        version_string = get_cuda_version_string()
+        if version_string is None:
+            return None
+
+        return CUDASpecs(
+            highest_compute_capability=compute_capabilities[-1],
+            cuda_version_string=version_string,
+            cuda_version_tuple=version_tuple,
+        )
+    except Exception:
+        return None
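
Note: with the hardened versions, every failure mode collapses to None, so callers can probe for a usable CUDA/HIP runtime without their own try/except. A hedged usage sketch (field names are taken from the CUDASpecs constructor above):

from bitsandbytes.cuda_specs import get_cuda_specs

specs = get_cuda_specs()
if specs is None:
    # Covers: no GPU, missing version info, or any error while probing.
    print("No usable CUDA/HIP runtime detected.")
else:
    print(f"Highest compute capability: {specs.highest_compute_capability}")
    print(f"CUDA/HIP version: {specs.cuda_version_string}")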
