diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 794b0aad9..b8dc5a5e1 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -20,15 +20,15 @@ from .optim import adam # This is a signal for integrations with transformers/diffusers. -# Eventually, we will remove this and check based on release version. +# Eventually we may remove this but it is currently required for compatibility. features = {"multi-backend"} supported_torch_devices = { - "cuda", "cpu", - # "mps", - # "xpu", - # "hpu", - # "npu", + "cuda", # NVIDIA/AMD GPU + "xpu", # Intel GPU + "hpu", # Gaudi + "npu", # Ascend NPU + "mps", # Apple Silicon } if torch.cuda.is_available(): diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7fa846d92..85db6366b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -284,7 +284,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor dtype=torch.float16, ) - if state.threshold > 0.0 and subA is not None: + if state.threshold > 0.0 and subA is not None and subA.numel() > 0: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index cf3dd3342..c9341230f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) for i in range(gap): values.append(0) values.sort() - code = torch.Tensor(values) + code = torch.tensor(values) code /= code.max() return code diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index dfa688abb..ea5451502 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -306,9 +306,15 @@ def _quantize(self, device): self.bnb_quantized = True return self + def cpu(self): + return self.to(device="cpu") + def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + return self.to(device="xpu" if device is None else device, non_blocking=non_blocking) + @overload def to( self: T, @@ -326,7 +332,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... 
def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if device is not None and device.type != "meta" and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: diff --git a/pyproject.toml b/pyproject.toml index 528feac2a..af4c8c240 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,12 @@ include = ["bitsandbytes*"] [tool.setuptools.dynamic] version = {attr = "bitsandbytes.__version__"} +[tool.coverage.report] +exclude_also = [ + # exclude backward() functions from coverage, as they are invoked from C++ + 'def backward\(ctx' +] + [tool.pytest.ini_options] addopts = "-rP -m 'not slow and not benchmark and not deprecated'" # ; --cov=bitsandbytes diff --git a/tests/helpers.py b/tests/helpers.py index 9e85eba93..fbc4af071 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,5 +1,7 @@ +import functools from io import BytesIO from itertools import product +import os import random from typing import Any @@ -13,6 +15,38 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) +@functools.cache +def get_available_devices(): + if "BNB_TEST_DEVICE" in os.environ: + # If the environment variable is set, use it directly. + return [os.environ["BNB_TEST_DEVICE"]] + + devices = ["cpu"] + + if hasattr(torch, "accelerator"): + # PyTorch 2.6+ - determine accelerator using agnostic API. + if torch.accelerator.is_available(): + devices += [str(torch.accelerator.current_accelerator())] + else: + if torch.cuda.is_available(): + devices += ["cuda"] + + if torch.backends.mps.is_available(): + devices += ["mps"] + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + devices += ["xpu"] + + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_backend_module = getattr(torch, custom_backend_name, None) + custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None) + + if custom_backend_is_available_fn and custom_backend_module.is_available(): + devices += [custom_backend_name] + + return devices + + def torch_save_to_buffer(obj): buffer = BytesIO() torch.save(obj, buffer) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 347a93131..7c43cab80 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -6,12 +6,14 @@ BOOLEAN_TRIPLES, TRUE_FALSE, describe_dtype, + get_available_devices, id_formatter, ) TRANSPOSE_VALS = [(False, True), (False, False)] +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @@ -27,10 +29,16 @@ @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias): +def test_matmullt( + device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias +): + if device != "cuda" and funcs[1] == bnb.research.switchback_bnb: + # TODO: Deprecate/remove? 
+ pytest.skip("switchback_bnb only works on CUDA.") + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") + outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device) if has_bias == False: req_grad = list(req_grad) req_grad[2] = False @@ -38,21 +46,21 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype) target = torch.randn( size=(dim2, dim4), - device="cuda", + device=device, requires_grad=req_grad[1], dtype=dtype, ) bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec if has_fp16_weights: if any(req_grad): out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() + if device == "cuda": + torch.cuda.synchronize() loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad @@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec torch.testing.assert_close(gradBias1, gradBias2) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3")) @@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) def test_matmul_4bit( + device, dim1, dim2, dim3, @@ -159,6 +170,9 @@ def test_matmul_4bit( compress_statistics, quant_type, ): + if device == "cpu" and quant_type == "fp4": + pytest.skip("Only nf4 is supported on CPU") + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: @@ -168,13 +182,13 @@ def test_matmul_4bit( for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) - target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype) bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) + bias = 
torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) @@ -204,7 +218,8 @@ def test_matmul_4bit( # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() + if device == "cuda": + torch.cuda.synchronize() loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad diff --git a/tests/test_functional.py b/tests/test_functional.py index 77a49b1fd..5b9038288 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -13,6 +13,7 @@ BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, + get_available_devices, get_test_dims, id_formatter, ) @@ -87,15 +88,26 @@ def reset(self): class Test8BitBlockwiseQuantizeFunctional: + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) - def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed): + def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): + if device == "cpu": + # This test is slow on CPU, so avoid atypical use cases. + if nested: + pytest.skip("Not a typical use case.") + if blocksize != 256: + pytest.skip("Only blocksize 256 is the typical one supported on CPU.") + + if dtype != torch.float32: + pytest.xfail(f"CPU implementation currently only supports float32, got {dtype}") + diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) A2 = F.dequantize_blockwise(C, S) diff = torch.abs(A1 - A2).float() @@ -113,7 +125,7 @@ def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed): diffs = [] code = F.create_dynamic_map(signed=signed) for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.rand(1024, 1024, device=device, dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) A2 = F.dequantize_blockwise(C, S) diff = torch.abs(A1 - A2).float() @@ -154,21 +166,27 @@ def test_blockwise_cpu_large(self): # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) - def test_few_bit_quant(self, bits, method): + def test_few_bit_quant(self, device, bits, method): + if device == "cpu" and bits != 8: + pytest.skip("CPU implementation only supports 8 bits") + abserrs = [] relerrs = [] code = None if method == "linear": - code = F.create_linear_map(True, total_bits=bits).cuda() + code = F.create_linear_map(True, total_bits=bits).to(device) elif method == "fp8": ebits = math.ceil(bits / 2) pbits = bits - ebits - 1 - code = F.create_fp8_map(True, ebits, pbits, bits).cuda() + code = F.create_fp8_map(True, ebits, pbits, bits).to(device) elif method == "dynamic": - code = F.create_dynamic_map(True, bits - 0, bits).cuda() + code = F.create_dynamic_map(True, bits - 0, bits).to(device) elif method == "quantile": + if device != "cuda": + 
pytest.xfail("Quantile map only works on CUDA") values = torch.randn(2048, 2048, device="cuda") code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero @@ -178,7 +196,7 @@ def test_few_bit_quant(self, bits, method): # print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): - values = torch.randn(1, 32, device="cuda") + values = torch.randn(1, 32, device=device) values /= values.abs().max() # values[values.abs() < 1e-6] += 1e-5 @@ -189,8 +207,8 @@ def test_few_bit_quant(self, bits, method): q1.append(idx.item()) v1.append(code[idx].item()) - q1 = torch.Tensor(q1).cuda() - v1 = torch.Tensor(v1).cuda() + q1 = torch.tensor(q1, device=device) + v1 = torch.tensor(v1, device=device) q2, S2 = F.quantize_blockwise(values, code=code) v2 = F.dequantize_blockwise(q2, S2) @@ -206,15 +224,20 @@ def test_few_bit_quant(self, bits, method): else: torch.testing.assert_close(q1, q2) - def test_fp8_quant(self): + @pytest.mark.parametrize("device", get_available_devices()) + def test_fp8_quant(self, device): + # TODO + if device == "cpu": + pytest.skip("CPU implementation segfaults") + for e_bits in range(1, 7): p_bits = 7 - e_bits - code = F.create_fp8_map(True, e_bits, p_bits).cuda() + code = F.create_fp8_map(True, e_bits, p_bits).to(device) abserr = [] relerr = [] for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") + A1 = torch.randn(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) @@ -228,7 +251,7 @@ def test_fp8_quant(self): abserr = [] relerr = [] for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") + A1 = torch.rand(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) @@ -242,7 +265,7 @@ def test_fp8_quant(self): abserr = [] relerr = [] for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") + A1 = torch.randn(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) @@ -329,6 +352,7 @@ def mean(xx): } +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestIGEMMFunctional: @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) @@ -532,36 +556,38 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) - def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb): + def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), 
dtype=torch.int8, device=device) + B = torch.randint(-128, 127, size=(dim4, dim3), dtype=torch.int8, device=device) C1 = torch.matmul(A.float(), B.t().float()) C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) - def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims): + def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): for i in range(k): if dims == 2: - A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() + A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half() elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() - B = torch.randn((dim4, dim3), device="cuda").half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device=device).half() + B = torch.randn((dim4, dim3), device=device).half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -573,19 +599,20 @@ def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) - def test_dequant_mm(self, dim1, dim4, dims, has_bias): + def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): inner = 128 bias = None if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=torch.float16) + bias = torch.randn(dim4, device=device, dtype=torch.float16) for i in range(1): - A = torch.randn(dim1, inner, device="cuda") - B = torch.randn(dim4, inner, device="cuda") + A = torch.randn(dim1, inner, device=device) + B = torch.randn(dim4, inner, device=device) C1 = torch.matmul(A.half(), B.t().half()) if has_bias: C1 += bias @@ -618,6 +645,7 @@ def test_dequant_mm(self, dim1, dim4, dims, has_bias): @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) + @pytest.mark.deprecated def test_colrow_absmax(self, dim1, dim2, dims, threshold): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -654,6 +682,7 @@ def test_colrow_absmax(self, dim1, dim2, dims, threshold): @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) + @pytest.mark.deprecated def test_int8_double_quant(self, dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -686,6 +715,7 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -697,10 +727,10 @@ def test_int8_double_quant(self, dim1, dim2): ) ), ) - def 
test_integrated_int8_linear_matmul(self, dim1, dim4, inner): + def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): for i in range(k): - A = torch.randn(dim1, inner, device="cuda").half() - B = torch.randn(dim4, inner, device="cuda").half() + A = torch.randn(dim1, inner, device=device).half() + B = torch.randn(dim4, inner, device=device).half() out1 = torch.matmul(A.half(), B.t().half()) @@ -724,12 +754,13 @@ def test_integrated_int8_linear_matmul(self, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) - def test_coo_double_quant(self, dim1, dim2): + def test_coo_double_quant(self, device, dim1, dim2): threshold = 2.00 for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, dim2, device=device).half() idx = torch.abs(A) >= threshold CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) @@ -743,12 +774,13 @@ def test_coo_double_quant(self, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) - def test_coo_int8_vectorwise_quant(self, dim1, dim2): + def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): threshold = 3.00 for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, dim2, device=device).half() idx = torch.abs(A) >= threshold CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) @@ -759,6 +791,7 @@ def test_coo_int8_vectorwise_quant(self, dim1, dim2): torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2")) @@ -1025,6 +1058,7 @@ def test_spmm_coo_dequant(self, dim1, dim2, dtype): print("partial matmul", time.time() - t0) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSparseTensorFunctional: def test_coo2csr(self): threshold = 1 @@ -1063,11 +1097,12 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) - def test_4bit_quant(self, dtype, quant_type, blocksize): - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + def test_4bit_quant(self, device, dtype, quant_type, blocksize): + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) @@ -1095,13 +1130,14 @@ def test_4bit_quant(self, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 + @pytest.mark.parametrize("device", get_available_devices()) 
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) - def test_4bit_compressed_stats(self, quant_type, blocksize): + def test_4bit_compressed_stats(self, device, quant_type, blocksize): errs1 = [] errs2 = [] for i in range(10): - A1 = torch.randn(1024, 1024, device="cuda").half() + A1 = torch.randn(1024, 1024, device=device).half() q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) @@ -1127,6 +1163,7 @@ def test_4bit_compressed_stats(self, quant_type, blocksize): # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) @pytest.mark.parametrize("quant_type", ["nf4"]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") @pytest.mark.benchmark def test_bench_4bit_dequant(self, quant_type): blocksize = 256 @@ -1157,6 +1194,7 @@ def test_bench_4bit_dequant(self, quant_type): # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1167,7 +1205,7 @@ def test_bench_4bit_dequant(self, quant_type): ids=describe_dtype, ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) - def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind): + def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): errs1 = [] errs2 = [] errs3 = [] @@ -1180,17 +1218,17 @@ def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, for i in range(100): if kind == "fc1": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + A = torch.randn(1, dim, dtype=dtype, device=device) + B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim) elif kind == "fc2": - A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") - B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + A = torch.randn(1, 4 * dim, dtype=dtype, device=device) + B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim) elif kind == "attn": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + A = torch.randn(1, dim, dtype=dtype, device=device) + B = torch.randn(dim, dim, dtype=dtype, device=device) / math.sqrt(dim) elif kind == "attn_packed": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + A = torch.randn(1, dim, dtype=dtype, device=device) + B = torch.randn(dim * 3, dim, dtype=dtype, device=device) / math.sqrt(dim) qB, state = F.quantize_4bit( B, @@ -1294,18 +1332,19 @@ def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) 
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) - def test_gemv_eye_4bit(self, storage_type, dtype, double_quant): + def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) dims = [dim + (64 - (dim % 64)) for dim in dims] # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") - B = torch.eye(dim, dtype=dtype, device="cuda") + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device) + B = torch.eye(dim, dtype=dtype, device=device) qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) C3 = torch.matmul(A, B.t()) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 2f094be27..669319298 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -7,7 +7,7 @@ import torch import bitsandbytes as bnb -from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer storage = { "uint8": torch.uint8, @@ -17,15 +17,18 @@ } +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) -@pytest.mark.parametrize("bias", TRUE_FALSE) -@pytest.mark.parametrize("compress_statistics", TRUE_FALSE) +@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("save_before_forward", TRUE_FALSE) -def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward): +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) +def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): + if device == "cpu": + pytest.xfail("Dequantization is not yet implemented for CPU") + original_dtype = torch.float16 compute_dtype = None - device = "cuda" layer_shape = (300, 400) linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer @@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora # restoring from state_dict: bias_data2 = sd.pop("bias", None) weight_data2 = sd.pop("weight") - weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2) + weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device) # creating new layer with same params: linear_q2 = bnb.nn.Linear4bit( @@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert size_ratio < target_compression, ratio_error_msg -def test_copy_param(): - tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) - param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +def test_copy_param(device, quant_type, blocksize, compress_statistics): + if device == "cpu": + if 
compress_statistics: + pytest.skip("Currently segfaults on CPU") + if quant_type == "fp4": + pytest.xfail("FP4 not supported on CPU") + + tensor = torch.linspace(1, blocksize, blocksize) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, + blocksize=blocksize, + compress_statistics=compress_statistics, + requires_grad=False, + ).to(device) shallow_copy_param = copy.copy(param) assert param.quant_state is shallow_copy_param.quant_state assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -def test_deepcopy_param(): - tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) - param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): + if device == "cpu": + if compress_statistics: + pytest.skip("Currently segfaults on CPU") + if quant_type == "fp4": + pytest.xfail("FP4 not supported on CPU") + + tensor = torch.linspace(1, blocksize, blocksize) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, + blocksize=blocksize, + compress_statistics=compress_statistics, + requires_grad=False, + ).to(device) dict_keys_before = set(param.__dict__.keys()) copy_param = copy.deepcopy(param) dict_keys_after = set(param.__dict__.keys()) @@ -199,12 +234,27 @@ def test_deepcopy_param(): assert dict_keys_before == dict_keys_copy -def test_params4bit_real_serialization(): - original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) - original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4") +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): + if device == "cpu": + if compress_statistics: + pytest.skip("Currently segfaults on CPU") + if quant_type == "fp4": + pytest.xfail("FP4 not supported on CPU") + + original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) + original_param = bnb.nn.Params4bit( + data=original_tensor, + quant_type=quant_type, + blocksize=blocksize, + compress_statistics=compress_statistics, + ) dict_keys_before = set(original_param.__dict__.keys()) - original_param.cuda(0) # move to CUDA to trigger quantization + original_param.to(device) # change device to trigger quantization serialized_param = pickle.dumps(original_param) deserialized_param = pickle.loads(serialized_param) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 3d3faf0d2..53a566cb9 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -11,6 +11,7 @@ from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, + get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer, @@ -19,7 +20,11 @@ # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py -def test_linear_no_igemmlt(): +@pytest.mark.parametrize("device", get_available_devices()) +def test_linear_no_igemmlt(device): + if device == "cpu": + pytest.xfail("Not yet implemented 
on CPU") + linear = torch.nn.Linear(1024, 3072) x = torch.randn(3, 1024, dtype=torch.half) linear_custom = Linear8bitLt( @@ -29,6 +34,8 @@ def test_linear_no_igemmlt(): has_fp16_weights=False, threshold=6.0, ) + + # TODO: Remove, this is no longer implemented linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( @@ -37,11 +44,11 @@ def test_linear_no_igemmlt(): has_fp16_weights=False, ).to(linear.weight.dtype) linear_custom.bias = linear.bias - linear_custom = linear_custom.cuda() - linear = linear.half().cuda() + linear_custom = linear_custom.to(device) + linear = linear.half().to(device) - x_ref = x.clone().cuda().requires_grad_(True) - x_ours = x.clone().cuda().requires_grad_(True) + x_ref = x.clone().to(device).requires_grad_(True) + x_ours = x.clone().to(device).requires_grad_(True) fx_ref = linear(x_ref).float() grad_proj = torch.randn_like(fx_ref) (fx_ref * grad_proj).mean().backward() @@ -58,18 +65,25 @@ def test_linear_no_igemmlt(): torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) def test_linear_serialization( + device, has_fp16_weights, + threshold, serialize_before_forward, deserialize_before_cuda, save_before_forward, load_before_cuda, ): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + linear = torch.nn.Linear(32, 96) # TODO: Fallback for bad shapes x = torch.randn(4, 32, dtype=torch.half) @@ -80,7 +94,7 @@ def test_linear_serialization( linear.out_features, linear.bias is not None, has_fp16_weights=has_fp16_weights, - threshold=6.0, + threshold=threshold, ) linear_custom.weight = bnb.nn.Int8Params( @@ -89,7 +103,7 @@ def test_linear_serialization( has_fp16_weights=has_fp16_weights, ) linear_custom.bias = linear.bias - linear_custom = linear_custom.cuda() + linear_custom = linear_custom.to(device) if serialize_before_forward: state_dict_8bit = linear_custom.state_dict() @@ -125,7 +139,7 @@ def test_linear_serialization( linear.out_features, linear.bias is not None, has_fp16_weights=has_fp16_weights, - threshold=6.0, + threshold=threshold, ) if deserialize_before_cuda: @@ -135,7 +149,7 @@ def test_linear_serialization( if load_before_cuda: new_linear_custom2 = torch_load_from_buffer(bytes_8bit) - new_linear_custom = new_linear_custom.cuda() + new_linear_custom = new_linear_custom.to(device) if not deserialize_before_cuda: new_linear_custom.load_state_dict(new_state_dict, strict=True) diff --git a/tests/test_modules.py b/tests/test_modules.py index c2583550d..8ef0890ec 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,13 +1,11 @@ import inspect -import math -import einops import pytest import torch from torch import nn import bitsandbytes as bnb -from tests.helpers import id_formatter +from tests.helpers import get_available_devices, id_formatter class MockArgs: @@ -54,266 +52,32 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, 
count=10): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -class LinearFunction(torch.autograd.Function): - @staticmethod - def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round - norm = math.sqrt(math.pi) / math.sqrt(2.0) - # std = torch.abs(x).mean()*norm - std = torch.std(x) - max1 = std * trim_value - x = x / max1 * 127 - x = round_func(x) - x[x > 127] = 127 - x[x < -127] = -127 - x = x / 127 * max1 - - return x - - def quant(x, quant_type, dim=1): - if quant_type == "linear": - max1 = torch.abs(x).max().float() - xq = torch.round(x / max1 * 127).to(torch.int8) - return xq, max1 - elif quant_type == "vector": - max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x / max1 * 127).to(torch.int8) - return xq, max1 - elif quant_type == "min-max": - maxA = torch.amax(x, dim=dim, keepdim=True).float() - minA = torch.amin(x, dim=dim, keepdim=True).float() - scale = (maxA - minA) / 2.0 - xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8) - return xq, (minA.float(), scale.float()) - else: - return None - - def dequant(xq, S1, S2, dtype, quant_type): - if quant_type == "linear": - norm = S1 * S2 / (127 * 127) - # double cast needed to prevent overflows - return (xq.float() * norm).to(dtype) - elif quant_type == "vector": - x = xq.float() - if len(xq.shape) == 2 and len(S1.shape) == 3: - S1 = S1.squeeze(0) - if len(xq.shape) == 2 and len(S2.shape) == 3: - S2 = S2.squeeze(0) - # print(x.shape, S1.shape, S2.shape) - if len(S1.shape) == 2: - x *= S1.t() / 127 - else: - x *= S1 / 127 - x *= S2 / 127 - return x.to(dtype) - else: - return None - - def dequant_min_max(xq, A, B, SA, SB, dtype): - offset = B.float().t().sum(0) * (SA[0] + SA[1]) - x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: - SB = SB.squeeze(0) - if len(xq.shape) == 2 and len(SA.shape) == 3: - SA = SA.squeeze(0) - if len(SB.shape) == 2: - x *= SB.t() / 127 - else: - x *= SB / 127 - x *= SA[1] / 127 - x += offset - return x.to(dtype) - - def get_8bit_linear(x, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round - max1 = torch.abs(x).max() - x = x / max1 * 127 - x = round_func(x) / 127 * max1 - # x = torch.round(x)/128*max1 - return x - - @staticmethod - def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round - max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - max1[max1 == 0] = 1.0 - x = (x * 127) / max1 - x = round_func(x) / 127 * max1 - return x - - @staticmethod - def round_stoachastic(x): - sign = torch.sign(x) - absx = torch.abs(x) - decimal = absx - torch.floor(absx) - rdm = torch.rand_like(decimal) - return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype)) - - @staticmethod - def fake_8bit_storage(w, exponent_bits): - code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device) - absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) - out = bnb.functional.dequantize_blockwise(absmax, C, code) - out = out.half() - w.copy_(out) - return out - - @staticmethod - def fake_8bit_storage_quantile(w, args): - code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) - # C = bnb.functional.quantize_no_absmax(code, w) - # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) - # print(out) - # out = out.half() - code /= torch.max(torch.abs(code)) - absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) - out = 
bnb.functional.dequantize_blockwise(absmax, C, code) - out = out.half() - w.copy_(out) - return out - - @staticmethod - def fake_8bit_storage_stoachstic(w): - rand = torch.rand(1024, device=w.device) - absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand) - out = bnb.functional.dequantize_blockwise(absmax, C) - out = out.half() - w.copy_(out) - return out - - @staticmethod - def fake_8bit_storage_with_max(w, topk=8): - blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256) - max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) - idx = idx[:, :topk] - max_val = max_val[:, :topk] - - mask = torch.zeros_like(blocked_w) - mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val)) - mask = mask.bool() - - # 1. zero out max values - # 2. quantize + dequantize - # 3. write back max values - # 4. copy matrix back to weight - - values = blocked_w[mask] - blocked_w[mask] = 0 - - code = bnb.functional.create_dynamic_map() - code = code.to(w.device) - absmax, C = bnb.functional.quantize_blockwise(blocked_w.data) - bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w) - - blocked_w[mask] = values - - unblocked_w = blocked_w.flatten().view(w.shape) - - w.copy_(unblocked_w) - return unblocked_w - - @staticmethod - def forward(ctx, x, weight, bias=None, args=None): - if args.use_8bit_training != "off": - weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) - x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) - outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) - # if torch.rand(1) < 0.01: - # output32 = torch.matmul(x, weight.t()) - # err = torch.abs(output-output32).float() - # relerr = err/(torch.abs(output32).float()+1e-8) - # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) - else: - # output = torch.matmul(x, weight.t()) - output = torch.einsum("bsi,oi->bso", x, weight) - - ctx.save_for_backward(x, weight, bias) - ctx.args = args - - if bias is not None: - output += bias.unsqueeze(0).expand_as(output) - return output - - @staticmethod - def backward(ctx, grad_output): - x, weight, bias = ctx.saved_tensors - args = ctx.args - stochastic = False - grad_input = grad_weight = grad_bias = None - if bias is not None and ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - # weight and x are already 8bit - # -> transform grad_output to 8-bit - if args.use_8bit_training == "forward+wgrad": - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) - x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) - grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) - - # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) - - grad_input = grad_output.matmul(weight) - elif args.use_8bit_training == "full": - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) - x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) - grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) - bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) - - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) - weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) - grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input 
= LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) - - else: - grad_input = grad_output.matmul(weight) - grad_weight = torch.einsum("bsi,bso->oi", x, grad_output) - - return grad_input, grad_weight, grad_bias, None - - -class Linear8bit(nn.Module): - def __init__(self, input_features, output_features, bias=True, args=None): - super().__init__() - self.input_features = input_features - self.output_features = output_features - self.args = args - - self.weight = nn.Parameter(torch.empty(output_features, input_features)) - if bias: - self.bias = nn.Parameter(torch.empty(output_features)) - else: - self.register_parameter("bias", None) - - torch.nn.init.xavier_uniform_(self.weight) - if self.bias is not None: - torch.nn.init.zeros_(self.bias) - - def forward(self, x): - self.args.training = self.training - - return LinearFunction.apply(x, self.weight, self.bias, self.args) - - +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) -def test_linear8bitlt_inference(threshold): - l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() - assert l1.weight.device.type == "cuda" - assert l1.weight.dtype == torch.float16 +def test_linear8bitlt_inference(device, threshold): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + + l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() + assert l1.weight.device.type == device + assert l1.weight.dtype == torch.int8 l1.eval() for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device).half() o1 = l1(b1) if i == 1: assert l1.state.CB is not None -def test_linear8bitlt_accumulated_gradient(): - l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) - l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) +# TODO: Remove support for training int8 weights +@pytest.mark.parametrize("device", get_available_devices()) +def test_linear8bitlt_accumulated_gradient(device): + if device != "cuda": + pytest.skip("Only supported on CUDA") + + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)]) l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) l1[0].bias.data.copy_(l2[0].bias.data) @@ -325,7 +89,7 @@ def test_linear8bitlt_accumulated_gradient(): acc_steps = 10 for i in range(15): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device).half() o1 = l1(b1) o2 = l2(b1) loss1 = o1.mean() @@ -353,8 +117,12 @@ def test_linear8bitlt_accumulated_gradient(): assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("threshold", [0.0, 2.0]) -def test_linear8bitlt_no_fp16_weights(threshold): +def test_linear8bitlt_no_fp16_weights(device, threshold): + if device == "cpu": + pytest.xfail("Not yet supported on CPU") + l1 = ( bnb.nn.Linear8bitLt( 32, @@ -362,23 +130,23 @@ def test_linear8bitlt_no_fp16_weights(threshold): threshold=threshold, has_fp16_weights=False, ) - .cuda() + .to(device) .half() ) assert l1.weight.dtype == torch.int8 l1.eval() for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, 
dtype=torch.float16) o1 = l1(b1) assert o1.dtype == torch.float16 - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda() + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device) assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 if threshold > 0: @@ -386,12 +154,12 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 if threshold > 0: @@ -399,10 +167,10 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device) for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 if threshold > 0: @@ -420,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): has_fp16_weights=False, ) .half() - .to("cuda") + .to(device) ) for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 if threshold > 0: @@ -433,8 +201,8 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - assert mlp.fc1.weight.device.type == "cuda" - assert mlp.fc2.weight.device.type == "cuda" + assert mlp.fc1.weight.device.type == device + assert mlp.fc2.weight.device.type == device mlp = MLP8bit( 32, @@ -442,11 +210,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): threshold=threshold, has_fp16_weights=False, ) - w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, + w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 if threshold > 0: @@ -456,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - assert mlp.fc1.weight.device.type == "cuda" - assert mlp.fc2.weight.device.type == "cuda" + assert mlp.fc1.weight.device.type == device + assert mlp.fc2.weight.device.type == device - b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) + b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half) o1 = mlp(b1) assert o1.dtype == torch.float16 assert o1.requires_grad @@ -475,33 +243,37 @@ def 
test_linear8bitlt_no_fp16_weights(threshold): assert (idx == 0).sum().item() <= b1.numel() * 0.005 +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( "module", [ lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), - bnb.nn.LinearFP4, + bnb.nn.LinearNF4, ], - ids=["Int8Lt", "FP4"], + ids=["Int8Lt", "NF4"], ) -def test_linear_kbit_fp32_bias(module): +def test_linear_kbit_fp32_bias(device, module): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + # casts model to fp16 -> int8 automatically - l1 = module(32, 64).cuda() + l1 = module(32, 64).to(device) assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias.dtype == torch.float32 for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) # casts bias to fp32 o1 = l1(b1) assert l1.bias.dtype == torch.float16 # casts model to fp16 -> int8 automatically - l1 = module(32, 64, bias=False).cuda() + l1 = module(32, 64, bias=False).to(device) assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias is None for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() + b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = l1(b1) assert l1.bias is None @@ -519,8 +291,12 @@ def test_linear_kbit_fp32_bias(module): } +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) -def test_kbit_backprop(module): +def test_kbit_backprop(device, module): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + b = 16 dim1 = 36 dim2 = 84 @@ -536,16 +312,16 @@ def test_kbit_backprop(module): kbit[1].weight.detach().copy_(ref[1].weight) kbit[0].bias.detach().copy_(ref[0].bias) kbit[1].bias.detach().copy_(ref[1].bias) - ref = ref.half().cuda() - kbit = kbit.half().cuda() - kbit = kbit.half().to("cuda") + ref = ref.half().to(device) + kbit = kbit.half().to(device) + kbit = kbit.half().to(device) errs1 = [] errs2 = [] relerrs1 = [] relerrs2 = [] for i in range(100): - batch = torch.randn(b, dim1).half().cuda() + batch = torch.randn(b, dim1, device=device, dtype=torch.float16) out1 = ref(batch) out2 = kbit(batch) out1.mean().backward() @@ -578,6 +354,7 @@ def test_kbit_backprop(module): assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 +@pytest.mark.deprecated def test_fp8linear(): b = 10 h = 1024 @@ -608,6 +385,7 @@ def test_fp8linear(): assert bgraderr < 0.00002 +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("embedding_dim", [64, 65]) @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) @pytest.mark.parametrize( @@ -621,7 +399,10 @@ def test_fp8linear(): ], ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), ) -def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage): +def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage): + if device == "cpu": + pytest.xfail("Not yet supported on CPU") + num_embeddings = 128 src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to( @@ -641,10 +422,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s e.load_state_dict(emb_base.state_dict()) - emb_base.cuda() - e.cuda() + emb_base.to(device) + e.to(device) - input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, 
device="cuda") + input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device) torch.testing.assert_close( actual=e(input_tokens), @@ -652,6 +433,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s ) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("embedding_dim", [64, 65]) @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) @pytest.mark.parametrize( @@ -665,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s ], ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), ) -def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage): +def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage): + if device == "cpu": + pytest.xfail("Not yet supported on CPU") + is_8bit = embedding_class is bnb.nn.Embedding8bit num_embeddings = 128 @@ -685,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor e.load_state_dict(emb_base.state_dict()) - emb_base.cuda() - e.cuda() + emb_base.to(device) + e.to(device) - input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda") + input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device) torch.testing.assert_close( actual=e(input_tokens), @@ -698,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor ) -def test_4bit_linear_warnings(): +@pytest.mark.parametrize("device", get_available_devices()) +def test_4bit_linear_warnings(device): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + dim1 = 64 with pytest.warns(UserWarning, match=r"inference or training"): - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) - net = net.cuda() - inp = torch.rand(10, dim1).cuda().half() + net = nn.Sequential( + *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)] + ) + net = net.to(device) + inp = torch.rand(10, dim1, device=device, dtype=torch.float16) net(inp) with pytest.warns(UserWarning, match=r"inference."): - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) - net = net.cuda() - inp = torch.rand(1, dim1).cuda().half() + net = nn.Sequential( + *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)] + ) + net = net.to(device) + inp = torch.rand(1, dim1, device=device, dtype=torch.float16) net(inp) with pytest.warns(UserWarning) as record: - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) - net = net.cuda() - inp = torch.rand(10, dim1).cuda().half() + net = nn.Sequential( + *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)] + ) + net = net.to(device) + inp = torch.rand(10, dim1, device=device, dtype=torch.float16) net(inp) - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) - net = net.cuda() - inp = torch.rand(1, dim1).cuda().half() + net = nn.Sequential( + *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)] + ) + net = net.to(device) + inp = torch.rand(1, dim1, device=device, dtype=torch.float16) net(inp) assert len(record) == 2 -def test_4bit_embedding_warnings(): +@pytest.mark.parametrize("device", 
get_available_devices()) +def test_4bit_embedding_warnings(device): + if device == "cpu": + pytest.xfail("Not yet implemented on CPU") + num_embeddings = 128 default_block_size = 64 with pytest.warns(UserWarning, match=r"inference."): - net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1) - net.cuda() - inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda") + net = bnb.nn.Embedding4bit( + num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4" + ) + net.to(device) + inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device) net(inp) -def test_4bit_embedding_weight_fsdp_fix(): +def test_4bit_embedding_weight_fsdp_fix(requires_cuda): num_embeddings = 64 embedding_dim = 32 @@ -754,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix(): assert module.weight.quant_state is not None -def test_4bit_linear_weight_fsdp_fix(): +def test_4bit_linear_weight_fsdp_fix(requires_cuda): inp_size = 64 out_size = 32 diff --git a/tests/test_ops.py b/tests/test_ops.py index 8c9c6a646..9869f51ef 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,11 +4,11 @@ import torch import bitsandbytes -from tests.helpers import TRUE_FALSE, id_formatter +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter class TestLLMInt8Ops: - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -20,7 +20,7 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul_out(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -35,7 +35,7 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): if device == "cpu": pytest.skip("CPU implementation is not available") @@ -64,7 +64,7 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_mm_dequant(self, device): A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -77,7 +77,7 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): @@ -96,7 +96,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class 
TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): @@ -116,7 +116,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): @@ -140,7 +140,7 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -164,7 +164,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -197,7 +197,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) diff --git a/tests/test_optim.py b/tests/test_optim.py index 2bc3752f3..9358a2e9b 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -47,7 +47,6 @@ def rm_path(path): ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) -str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) @@ -88,19 +87,14 @@ def rm_path(path): ) str2optimizers["lion"] = (Lion, bnb.optim.Lion) -str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False)) -str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion) +str2optimizers["lion8bit_blockwise"] = 
(Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["momentum8bit"] = ( - lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), -) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -110,10 +104,6 @@ def rm_path(path): lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["rmsprop8bit"] = ( - lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), -) str2optimizers["rmsprop8bit_blockwise"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True), @@ -128,8 +118,7 @@ def rm_path(path): str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] -str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] -str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] + str2statenames["adam8bit_blockwise"] = [ ("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2"), @@ -142,10 +131,8 @@ def rm_path(path): ("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2"), ] -str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] -str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] + str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] -str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] @@ -180,7 +167,7 @@ def rm_path(path): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) -def test_optimizer32bit(dim1, dim2, gtype, optim_name): +def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1: @@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) -def test_global_config(dim1, dim2, gtype): +def test_global_config(requires_cuda, dim1, dim2, gtype): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 @@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype): optimizer_names_8bit = [ - "adam8bit", - 
"lion8bit", - "momentum8bit", - "rmsprop8bit", + # Non-blockwise optimizers are deprecated. + # "adam8bit", + # "lion8bit", + # "momentum8bit", + # "rmsprop8bit", "adam8bit_blockwise", "lion8bit_blockwise", "momentum8bit_blockwise", @@ -315,7 +303,7 @@ def test_global_config(dim1, dim2, gtype): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) -def test_optimizer8bit(dim1, dim2, gtype, optim_name): +def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): torch.set_printoptions(precision=6) if gtype == torch.bfloat16 and "blockwise" not in optim_name: @@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) -def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): +@pytest.mark.deprecated +def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1