Updates for device agnosticism #1601

Open · wants to merge 14 commits into base: main

12 changes: 6 additions & 6 deletions bitsandbytes/__init__.py
@@ -20,15 +20,15 @@
 from .optim import adam

 # This is a signal for integrations with transformers/diffusers.
-# Eventually, we will remove this and check based on release version.
+# Eventually we may remove this but it is currently required for compatibility.
 features = {"multi-backend"}
 supported_torch_devices = {
-    "cuda",
     "cpu",
-    # "mps",
-    # "xpu",
-    # "hpu",
-    # "npu",
+    "cuda",  # NVIDIA/AMD GPU
+    "xpu",  # Intel GPU
+    "hpu",  # Gaudi
+    "npu",  # Ascend NPU
+    "mps",  # Apple Silicon
 }

 if torch.cuda.is_available():
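
A note on how these attributes are meant to be consumed: integrations can feature-detect instead of hard-coding CUDA. The sketch below is illustrative only (the pick_device helper is not part of bitsandbytes) and assumes PyTorch 2.6+ for the torch.accelerator API.

import torch
import bitsandbytes as bnb

def pick_device() -> str:
    # Illustrative helper: prefer the current accelerator when both PyTorch
    # and bitsandbytes claim to support its device type, else fall back to CPU.
    if "multi-backend" in bnb.features and torch.accelerator.is_available():
        dev = torch.accelerator.current_accelerator().type
        if dev in bnb.supported_torch_devices:
            return dev
    return "cpu"
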
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -284,7 +284,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
                dtype=torch.float16,
            )

-            if state.threshold > 0.0 and subA is not None:
+            if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
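
The added subA.numel() > 0 condition skips the outlier-gradient accumulation when the decomposition slice is empty, e.g. for the zero-sized inputs (dim2=0) parametrized in tests/test_autograd.py below. The motivation is inferred from those test changes; the guarded operation looks roughly like this standalone sketch with stand-in shapes:

import torch

# Stand-in shapes for the int8 backward pass with an empty batch (0 rows in A),
# so subA = A[:, idx] holds no data even though outlier columns exist.
out_features, in_features, n_outliers = 32, 64, 4
grad_output = torch.zeros(0, out_features, dtype=torch.float16)
subA = torch.zeros(0, n_outliers, dtype=torch.float16)
idx = torch.arange(n_outliers)
grad_B = torch.zeros(out_features, in_features, dtype=torch.float16)

# With the numel() check, the accumulation is skipped entirely instead of
# dispatching a zero-sized matmul and index assignment.
if subA is not None and subA.numel() > 0:
    grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
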
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     for i in range(gap):
         values.append(0)
     values.sort()
-    code = torch.Tensor(values)
+    code = torch.tensor(values)
     code /= code.max()

     return code
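
torch.Tensor(values) is the legacy constructor (always float32, tied to the global default tensor type), whereas torch.tensor(values) is the documented factory that infers dtype from the data. For the float code values built here the two are numerically identical, so the change is a drop-in modernization:

import torch

values = [-1.0, -0.5, 0.0, 0.5, 1.0]

legacy = torch.Tensor(values)  # legacy constructor, float32 via the default tensor type
modern = torch.tensor(values)  # recommended factory, dtype inferred (float32 here)

assert torch.equal(legacy, modern)
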
8 changes: 7 additions & 1 deletion bitsandbytes/nn/modules.py
@@ -306,9 +306,15 @@ def _quantize(self, device):
         self.bnb_quantized = True
         return self

+    def cpu(self):
+        return self.to(device="cpu")
+
     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(
         self: T,
@@ -326,7 +332,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device is not None and device.type == "cuda" and not self.bnb_quantized:
+        if device is not None and device.type != "meta" and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
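
With quantization now triggered by a move to any non-meta device (rather than only device.type == "cuda"), the new cpu() and xpu() overrides are thin wrappers over the same to() path as the existing cuda() override. A hedged usage sketch, assuming a bitsandbytes build whose backend supports nf4 on the target device (an Intel XPU in this example):

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4")

# Moving to a real (non-meta) device quantizes the weight on first transfer;
# layer.xpu(), layer.cpu(), layer.cuda() and layer.to(...) all share this path now.
layer = layer.to("xpu")
out = layer(torch.randn(8, 64, dtype=torch.float16, device="xpu"))
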
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
 [tool.setuptools.dynamic]
 version = {attr = "bitsandbytes.__version__"}

+[tool.coverage.report]
+exclude_also = [
+    # exclude backward() functions from coverage, as they are invoked from C++
+    'def backward\(ctx'
+]
+
 [tool.pytest.ini_options]
 addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
 # ; --cov=bitsandbytes
34 changes: 34 additions & 0 deletions tests/helpers.py
@@ -1,5 +1,7 @@
+import functools
 from io import BytesIO
 from itertools import product
+import os
 import random
 from typing import Any

@@ -13,6 +15,38 @@
 BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)


+@functools.cache
+def get_available_devices():
+    if "BNB_TEST_DEVICE" in os.environ:
+        # If the environment variable is set, use it directly.
+        return [os.environ["BNB_TEST_DEVICE"]]
+
+    devices = ["cpu"]
+
+    if hasattr(torch, "accelerator"):
+        # PyTorch 2.6+ - determine accelerator using agnostic API.
+        if torch.accelerator.is_available():
+            devices += [str(torch.accelerator.current_accelerator())]
+    else:
+        if torch.cuda.is_available():
+            devices += ["cuda"]
+
+        if torch.backends.mps.is_available():
+            devices += ["mps"]
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            devices += ["xpu"]
+
+        custom_backend_name = torch._C._get_privateuse1_backend_name()
+        custom_backend_module = getattr(torch, custom_backend_name, None)
+        custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
+
+        if custom_backend_is_available_fn and custom_backend_module.is_available():
+            devices += [custom_backend_name]
+
+    return devices
+
+
 def torch_save_to_buffer(obj):
     buffer = BytesIO()
     torch.save(obj, buffer)
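
Device discovery order in the helper: an explicit BNB_TEST_DEVICE environment variable wins, then the torch.accelerator API on PyTorch 2.6+, otherwise per-backend probes including a PrivateUse1 custom backend (e.g. an out-of-tree NPU). Typical consumption mirrors the parametrization added to tests/test_autograd.py below; the import path is written as it would look from the repository root:

import pytest
import torch

from tests.helpers import get_available_devices

@pytest.mark.parametrize("device", get_available_devices())
def test_tensor_roundtrip(device):
    # Runs once per detected device; BNB_TEST_DEVICE=xpu pytest tests/
    # restricts the whole suite to a single backend.
    x = torch.randn(4, 4, device=device)
    assert x.device.type == torch.device(device).type
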
39 changes: 27 additions & 12 deletions tests/test_autograd.py
@@ -6,12 +6,14 @@
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     id_formatter,
 )

 TRANSPOSE_VALS = [(False, True), (False, False)]


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,32 +29,38 @@
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
+def test_matmullt(
+    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
+):
+    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
+        # TODO: Deprecate/remove?
+        pytest.skip("switchback_bnb only works on CUDA.")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
     if has_bias == False:
         req_grad = list(req_grad)
         req_grad[2] = False

     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
-                device="cuda",
+                device=device,
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
         if has_fp16_weights:
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                 torch.testing.assert_close(gradBias1, gradBias2)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
 def test_matmul_4bit(
+    device,
     dim1,
     dim2,
     dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("Only nf4 is supported on CPU")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -168,13 +182,13 @@ def test_matmul_4bit(
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)

@@ -204,7 +218,8 @@ def test_matmul_4bit(
         # assert err < 0.20
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
-            torch.cuda.synchronize()
+            if device == "cuda":
+                torch.cuda.synchronize()
             loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
             loss_bnb.backward()
             gradA1 = A.grad
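
The repeated "if device == 'cuda': torch.cuda.synchronize()" guards keep non-CUDA runs from calling into the CUDA runtime. A possible device-agnostic variant, not part of this diff and assuming the torch.accelerator synchronization API available since PyTorch 2.6:

import torch

def sync(device: str) -> None:
    # Hypothetical test helper: synchronize whichever accelerator backs
    # `device`; a no-op on CPU and on builds without torch.accelerator.
    if device == "cpu":
        return
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        torch.accelerator.synchronize()
    elif device == "cuda":
        torch.cuda.synchronize()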