From 35f056dfb19c51001381781efddecffe8b007a46 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Mon, 19 Feb 2024 10:11:00 +0000 Subject: [PATCH 01/18] remove dead code --- bitsandbytes/cextension.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 858365f02..b5360f48b 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -32,8 +32,3 @@ "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False print(str(ex)) - - -# print the setup details after checking for errors so we do not print twice -#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - #setup.print_log_stack() From c46fcf69db6987da18e318dacc725844c7691ea0 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:26:15 +0000 Subject: [PATCH 02/18] upgrade pre-commit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edcbc9b6b..c8ccfe8df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,6 @@ repos: args: - --fix=lf - repo: https://github.com/crate-ci/typos - rev: v1.17.2 + rev: v1.18.2 hooks: - id: typos From 8a29bc58e04d2b93ae1431c0e87d721911921ef8 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:38:30 +0000 Subject: [PATCH 03/18] reshuffle settings to make ruff happy --- pyproject.toml | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f74750720..e5574b152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,16 +8,10 @@ src = [ "tests", "benchmarking" ] -select = [ - "B", # bugbear: security warnings - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "ISC", # implicit string concatenation - "UP", # alert you when better syntax is available in your python version - "RUF", # the ruff developer's own rules -] + target-version = "py38" + +[tool.ruff.lint] ignore = [ "B007", # Loop control variable not used within the loop body (TODO: enable) "B028", # Warning without stacklevel (TODO: enable) @@ -29,8 +23,17 @@ ignore = [ "RUF012", # Mutable class attribute annotations ] ignore-init-module-imports = true # allow to expose in __init__.py via imports +select = [ + "B", # bugbear: security warnings + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "ISC", # implicit string concatenation + "UP", # alert you when better syntax is available in your python version + "RUF", # the ruff developer's own rules +] -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "**/__init__.py" = ["F401"] # allow unused imports in __init__.py "{benchmarking,tests}/**/*.py" = [ "B007", @@ -42,7 +45,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports "UP030", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true From 8415b6e6178ed92f132db589fac6a58c54bf539a Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:41:25 +0000 Subject: [PATCH 04/18] ignore ISC001 to avoid formatting conflicts --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git 
a/pyproject.toml b/pyproject.toml index e5574b152..8e95f06da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ ignore = [ "E731", # Do not use lambda "F841", # Local assigned but not used (TODO: enable, these are likely bugs) "RUF012", # Mutable class attribute annotations + "ISC001", # String concatenation warning: may cause conflicts when used with the formatter ] ignore-init-module-imports = true # allow to expose in __init__.py via imports select = [ From a162d40520ff0da376f340dd11f6e9c983f1fbf1 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:41:52 +0000 Subject: [PATCH 05/18] whitespace for alignment --- pyproject.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e95f06da..9a2072af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,13 +13,13 @@ target-version = "py38" [tool.ruff.lint] ignore = [ - "B007", # Loop control variable not used within the loop body (TODO: enable) - "B028", # Warning without stacklevel (TODO: enable) - "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. - "E701", # Multiple statements on one line (TODO: enable) - "E712", # Allow using if x == False, as it's not always equivalent to if x. - "E731", # Do not use lambda - "F841", # Local assigned but not used (TODO: enable, these are likely bugs) + "B007", # Loop control variable not used within the loop body (TODO: enable) + "B028", # Warning without stacklevel (TODO: enable) + "E501", # Suppress line-too-long warnings: trust yapf's judgement on this one. + "E701", # Multiple statements on one line (TODO: enable) + "E712", # Allow using if x == False, as it's not always equivalent to if x. + "E731", # Do not use lambda + "F841", # Local assigned but not used (TODO: enable, these are likely bugs) "RUF012", # Mutable class attribute annotations "ISC001", # String concatenation warning: may cause conflicts when used with the formatter ] From 7ffe552173367c1ee41de5c83a4907b9fae97694 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:43:11 +0000 Subject: [PATCH 06/18] ruff format cextension.py --- bitsandbytes/cextension.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index b5360f48b..7dd7ccfda 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -14,21 +14,27 @@ if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() - raise RuntimeError(''' + raise RuntimeError( + """ CUDA Setup failed despite GPU being available. Please run the following command to get more information: python -m bitsandbytes Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH.
If you suspect a bug, please take the information from python -m bitsandbytes - and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues""" + ) + _ = ( + lib.cadam32bit_grad_fp32 + ) # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError as ex: - warn("The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") + warn( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable." + ) COMPILED_WITH_CUDA = False print(str(ex)) From 9b45576b329d7bd55282451335383405d93e8d6c Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:05:35 +0000 Subject: [PATCH 07/18] ignore formatting commits --- .git-blame-ignore-revs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index f7dd01bdf..646c4665e 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -6,3 +6,9 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 # Remove f-prefix from strings that don't use formatting 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6 + +# format bitsandbytes/cextension.py +04f691ef3061e6659aa0a741ca97d00d031618c4 + +# whitespace in pyproject.toml +f7b791863083429ba79dc00f925a041beab63297 From c2a8594d73d894e6ed8ab59e2cc1c037d855a3e4 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:32:57 +0000 Subject: [PATCH 08/18] relocate cuda_setup module into device_setup/cuda --- bitsandbytes/__init__.py | 2 +- bitsandbytes/__main__.py | 2 +- bitsandbytes/cextension.py | 2 +- bitsandbytes/{cuda_setup => device_setup/cuda}/__init__.py | 0 bitsandbytes/{cuda_setup => device_setup/cuda}/env_vars.py | 0 bitsandbytes/{cuda_setup => device_setup/cuda}/main.py | 4 ++-- tests/test_cuda_setup_evaluator.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename bitsandbytes/{cuda_setup => device_setup/cuda}/__init__.py (100%) rename bitsandbytes/{cuda_setup => device_setup/cuda}/env_vars.py (100%) rename bitsandbytes/{cuda_setup => device_setup/cuda}/main.py (99%) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index e54e933d9..3f175319a 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, research, utils +from . import research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 61b42e78f..d8ba54500 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -59,7 +59,7 @@ def main(): generate_bug_report_information() from .
import COMPILED_WITH_CUDA - from .cuda_setup.main import get_compute_capabilities + from .device_setup.cuda.main import get_compute_capabilities print_header("OTHER") print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7dd7ccfda..171572ab9 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -3,7 +3,7 @@ import torch -from bitsandbytes.cuda_setup.main import CUDASetup +from bitsandbytes.device_setup.cuda.main import CUDASetup setup = CUDASetup.get_instance() if setup.initialized != True: diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/device_setup/cuda/__init__.py similarity index 100% rename from bitsandbytes/cuda_setup/__init__.py rename to bitsandbytes/device_setup/cuda/__init__.py diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/device_setup/cuda/env_vars.py similarity index 100% rename from bitsandbytes/cuda_setup/env_vars.py rename to bitsandbytes/device_setup/cuda/env_vars.py diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/device_setup/cuda/main.py similarity index 99% rename from bitsandbytes/cuda_setup/main.py rename to bitsandbytes/device_setup/cuda/main.py index cd0d94cd7..a37cbf36a 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/device_setup/cuda/main.py @@ -142,7 +142,7 @@ def run_cuda_setup(self): self.binary_name = binary_name self.manual_override() - package_dir = Path(__file__).parent.parent + package_dir = Path(__file__).parent.parent.parent binary_path = package_dir / self.binary_name try: @@ -278,7 +278,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: 1. active conda env 2. LD_LIBRARY_PATH 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - are known to be unrelated (see `bnb.device_setup.cuda.env_vars.to_be_ignored`) - don't contain the path separator `/` If multiple libraries are found in part 3, we optimistically try one, diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 189aa75b5..e3620bf41 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -17,5 +17,5 @@ def test_manual_override(requires_cuda): os.environ['BNB_CUDA_VERSION']='122' #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] import bitsandbytes as bnb - loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name + loaded_lib = bnb.device_setup.cuda.main.CUDASetup.get_instance().binary_name #assert loaded_lib == 'libbitsandbytes_cuda122.so' From 3a57942122a56cc9a14b7ad0462b011a6d6a06cb Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:13:33 +0000 Subject: [PATCH 09/18] ruff format functional.py --- bitsandbytes/functional.py | 1030 +++++++++++++++++++++++++++--------- 1 file changed, 768 insertions(+), 262 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f0de962e1..9cb7abc39 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -21,6 +21,7 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) + name2qmap = {} if COMPILED_WITH_CUDA: @@ -127,7 +128,6 @@ def prefetch_all(self, to_cpu=False): prefetch_tensor(t, to_cpu) - class CUBLAS_Context: _instance = None @@ -169,6 +169,7 @@ def get_instance(cls): cls._instance.initialize() return cls._instance + dtype2bytes = {} dtype2bytes[torch.float32] = 4 
dtype2bytes[torch.float16] = 2 @@ -176,10 +177,11 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -FIRST_CUDA_DEVICE = torch.device('cuda', index=0) +FIRST_CUDA_DEVICE = torch.device("cuda", index=0) + def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): - num_bytes = dtype2bytes[dtype]*prod(shape) + num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -188,31 +190,35 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): out.page_deviceid = device.index return out + def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' + assert A.is_paged, "Only paged tensors can be prefetched!" if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype]*A.numel() + num_bytes = dtype2bytes[A.dtype] * A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) + func = getattr(lib, f"c{func_name}_fp32", None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) + func = getattr(lib, f"c{func_name}_uint8", None) cvalue = ct.c_uint8(value) - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + if func is None: + raise NotImplementedError(f"Function not implemented: {func_name}") - is_managed = getattr(A, 'is_managed', False) + is_managed = getattr(A, "is_managed", False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: prefetch_tensor(B) + if B is not None: + prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: @@ -222,28 +228,36 @@ def elementwise_func(func_name, A, B, value, prefetch=True): # operation occurred. So we synchronize. 
torch.cuda.synchronize() -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + +def fill(A, value, device=None, prefetch=True): + elementwise_func("fill", A, None, value) + + +def arange(A, device=None): + elementwise_func("arange", A, None, 0) + + +def _mul(A, B, device=None): + elementwise_func("_mul", A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) + sign = -1.0 if signed else 0.0 total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) + total_values = 2**total_bits if not signed else 2**total_bits - 1 values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel()//2 # noqa: E741 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + l = values.numel() // 2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -258,11 +272,11 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -275,38 +289,39 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign + assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate( + range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1) + ): evalues.append(2**val) - values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): for bit_pattern in lst: - value = (1 if evalue != 0 else 0) + value = 1 if evalue != 0 else 0 for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) + value += pval * (2 ** -(i + 1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value * 2**-(bias) else: # normals - value = value*2**-(evalue-bias-1) + value = value * 2 ** -(evalue - bias - 1) values.append(value) if signed: values.append(-value) - assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -320,7 +335,6 @@ def create_fp8_map(signed=True, 
exponent_bits=5, precision_bits=2, total_bits=8) return code - def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantization map. @@ -345,7 +359,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1 + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -371,8 +389,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) + def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() q.append(0) @@ -383,11 +402,13 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q/q.abs().max() + q = q / q.abs().max() return q + def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' + if not torch.cuda.is_available(): + return "col_turing" major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -396,20 +417,24 @@ def get_special_format_str(): return "col_turing" - def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}" + ) if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}" + ) return on_gpu @@ -534,8 +559,13 @@ def nvidia_transform( return out, new_state -def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' +def estimate_quantiles( + A: Tensor, + out: Optional[torch.Tensor] = None, + offset: float = 1 / 512, + num_quantiles=256, +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -562,26 +592,37 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl ------- torch.Tensor: The 256 quantiles in float32 datatype.
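    Example
    -------
    An illustrative sketch, not from the original docstring (assumes a CUDA
    tensor with at least 256 elements and the default num_quantiles=256):

    >>> A = torch.randn(4096, device="cuda")
    >>> quantiles = estimate_quantiles(A)
    >>> quantiles.shape
    torch.Size([256])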
- ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): + """ + if A.numel() < 256: + raise NotImplementedError( + f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values." + ) + if num_quantiles > 256: + raise NotImplementedError( + f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}" + ) + if num_quantiles < 256 and offset == 1 / (512): # override default arguments - offset = 1/(2*num_quantiles) + offset = 1 / (2 * num_quantiles) - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp32( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp16( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) else: raise NotImplementedError(f"Not supported data type {A.dtype}") post_call(device) if num_quantiles < 256: - step = round(256/num_quantiles) + step = round(256 / num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -590,12 +631,35 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl class QuantState: """container for quantization state components to work with Params4bit and similar classes""" - valid_quant_types = ('fp4', 'nf4') - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', - 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] - def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): self.absmax = absmax self.shape = shape self.code = code @@ -614,13 +678,27 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] else: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + 
None, + self.quant_type, + ] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": """ unpacks components of state_dict into QuantState where necessary, convert into strings, torch.dtype, ints, etc. @@ -631,38 +709,48 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState """ # unpacking tensor with non-tensor components - qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and 'quant_type' not in qs_dict: - raise ValueError("Expected packed or unpacked quant_state items, found neither") + qs_key = [ + k + for k, v in qs_dict.items() + if "quant_state" in k and isinstance(v, torch.Tensor) + ] + if not len(qs_key) and "quant_type" not in qs_dict: + raise ValueError( + "Expected packed or unpacked quant_state items, found neither" + ) elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}." + ) # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: first_qs_key = qs_key[0] qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if 'nested_absmax' in qs_dict: - offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) state2 = cls( - absmax=qs_dict['nested_absmax'].to(device), - blocksize=qs_dict['nested_blocksize'], - code=qs_dict['nested_quant_map'].to(device), - dtype=getattr(torch, qs_dict['nested_dtype']), + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict['quant_type'], - absmax=qs_dict['absmax'].to(device), - blocksize=qs_dict['blocksize'], - code=qs_dict['quant_map'].to(device), - dtype=getattr(torch, qs_dict['dtype']), - shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) + if qs_dict["shape"] is not None + else None, offset=offset, state2=state2, ) @@ -674,28 +762,36 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - 'quant_type': self.quant_type, - 'absmax': self.absmax, - 'blocksize': self.blocksize, - 'quant_map': self.code, - 'dtype': str(self.dtype).strip('torch.'), - 'shape': tuple(self.shape), + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), } if self.nested: - qs_dict.update({ - 'nested_absmax': 
self.state2.absmax, - 'nested_blocksize': self.state2.blocksize, - 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - 'nested_dtype': str(self.state2.dtype).strip('torch.'), - 'nested_offset': self.offset.item(), - }) + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + } + ) if not packed: return qs_dict # packed format allows serialization of non-tensor components, critical for saving in safetensors format - qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} - non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} - qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + qs_packed_dict = { + k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor) + } + non_tensor_dict = { + k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor) + } + qs_packed_dict[ + "quant_state." + "bitsandbytes__" + self.quant_type + ] = pack_dict_to_tensor(non_tensor_dict) return qs_packed_dict def to(self, device): @@ -756,7 +852,6 @@ def quantize_blockwise( The quantization state to undo the quantization. """ - if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -771,33 +866,72 @@ def quantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != 'cpu': + if A.device.type != "cpu": assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = 
QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) else: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + quant_state = QuantState( + absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype + ) return out, quant_state @@ -809,7 +943,7 @@ def dequantize_blockwise( code: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False + nested=False, ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -843,43 +977,80 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + quant_state = QuantState( + absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32 + ) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != 'cpu': + if A.device.type != "cpu": device = pre_call(A.device) code = quant_state.code.to(A.device) if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + raise ValueError( + f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_bf16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) 
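    # A minimal round-trip sketch of the two blockwise functions above, kept
    # as a comment (assumes a CUDA device; the shape and default blocksize
    # are illustrative, not prescribed):
    #
    #   x = torch.randn(64, 64, device="cuda")
    #   q, state = quantize_blockwise(x)
    #   x_hat = dequantize_blockwise(q, state)  # approximates x up to blockwise error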
return out + def get_4bit_type(typename, device=None, blocksize=64): - if device is None: device = 'cuda' + if device is None: + device = "cuda" data = None - if typename == 'nf4': - ''' Implements the NF4 data type. + if typename == "nf4": + """ Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. @@ -888,12 +1059,26 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - ''' - data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, - -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, - 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0] - elif typename == 'fp4': + """ + data = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + elif typename == "fp4": # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -903,21 +1088,55 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b110 = 2 # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) - data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] - elif typename == 'int4': + data = [ + 0, + 0.0625, + 8.0, + 12.0, + 4.0, + 6.0, + 2.0, + 3.0, + -0, + -0.0625, + -8.0, + -12.0, + -4.0, + -6.0, + -2.0, + -3.0, + ] + elif typename == "int4": data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == 'af4': + elif typename == "af4": # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, - -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, - 0.42563882, 0.55496234, 0.72424863, 1.][::-1] + data = [ + -1.0, + -0.69441008, + -0.51243739, + -0.3736951, + -0.25607552, + -0.14982478, + -0.04934812, + 0.0, + 0.04273164, + 0.12934483, + 0.21961274, + 0.31675666, + 0.42563882, + 0.55496234, + 0.72424863, + 1.0, + ][::-1] else: - raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError( + "4-bit AbnormalFloats currently only support blocksize 64." 
+ ) if data is None: - raise NotImplementedError(f'Typename {typename} not supported') + raise NotImplementedError(f"Typename {typename} not supported") data = Tensor(data) data /= data.abs().max() @@ -926,11 +1145,30 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) +def quantize_fp4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit( + A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage + ) + -def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) +def quantize_nf4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit( + A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage + ) def quantize_4bit( @@ -939,7 +1177,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type='fp4', + quant_type="fp4", quant_storage=torch.uint8, ) -> Tuple[Tensor, QuantState]: """ @@ -967,10 +1205,14 @@ def quantize_4bit( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError( + f"Device type not supported for FP4 quantization: {A.device.type}" + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented." 
+ ) n = A.numel() input_shape = A.shape @@ -980,10 +1222,9 @@ def quantize_4bit( blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -991,22 +1232,66 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) post_call(A.device) code = get_4bit_type(quant_type, device=A.device) @@ -1016,19 +1301,57 @@ def quantize_4bit( absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) return out, state -def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, 
absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') +def dequantize_fp4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") + -def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: +def dequantize_nf4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + + +def dequantize_4bit( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type="fp4", +) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1056,23 +1379,33 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Dequantized tensor. """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented." 
+ ) if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) else: absmax = quant_state.absmax - if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1082,27 +1415,73 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out def quantize( @@ -1117,7 +1496,8 @@ def quantize( code = code.to(A.device) absmax = torch.abs(A).max() - if 
absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1143,8 +1523,10 @@ def dequantize( return out * state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - ''' +def quantize_no_absmax( + A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None +) -> Tensor: + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1163,17 +1545,20 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ------- torch.Tensor: Quantized 8-bit tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - ''' +def dequantize_no_absmax( + A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None +) -> Tensor: + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1192,9 +1577,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = ------- torch.Tensor: 32-bit output tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1261,16 +1647,17 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1290,7 +1677,8 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) + ct.c_int32(g.numel()), + ) post_call(prev_device) @@ -1371,7 +1759,9 @@ def optimizer_update_8bit( param_norm = torch.norm(p.data.float()) prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) + is_on_gpu( + [g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2] + ) if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit[optimizer_name][0]( get_ptr(p), @@ -1446,7 +1836,6 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1454,8 +1843,11 @@ def optimizer_update_8bit_blockwise( 
optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( @@ -1487,6 +1879,7 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) + def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 ): @@ -1548,10 +1941,19 @@ def histogram_scatter_add_2d( maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) is_on_gpu([histogram, index1, index2, source]) - lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + lib.chistogram_scatter_add_2d( + get_ptr(histogram), + get_ptr(index1), + get_ptr(index2), + get_ptr(source), + maxdim1, + n, + ) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: raise TypeError( f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" @@ -1639,21 +2041,26 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout + def gemv_4bit( A: Tensor, B: Tensor, out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None + state=None, ): prev_device = pre_call(A.device) - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError( + "state cannot None. gem_4bit( ) requires the state from quantize_4bit( )" + ) if A.numel() != A.shape[-1]: - raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. 
[1, 1, 2048]' + ) Bshape = state.shape bout = Bshape[0] @@ -1664,7 +2071,9 @@ def gemv_4bit( if out is None: if len(A.shape) == 3: - out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) + out = torch.empty( + size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device + ) else: out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) @@ -1673,7 +2082,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1]+1)//2 + ldb = (A.shape[-1] + 1) // 2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1684,21 +2093,61 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") post_call(prev_device) return out + def igemm( A: Tensor, B: Tensor, @@ -1783,8 +2232,20 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out @@ -1865,9 +2326,24 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + 
ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out @@ -1876,20 +2352,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) + return torch.empty( + tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16 + ) if dimsA == 2 and out is None: out, Sout = get_transform_buffer( @@ -1940,7 +2418,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: has_error = lib.cigemmlt_turing_32( ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc @@ -1960,11 +2438,15 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") + raise NotImplementedError( + "igemmlt not available (probably built with NO_CUBLASLT)" + ) if has_error: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') + print( + f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}" + ) + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) @@ -1979,10 +2461,11 @@ def mm_dequant( out=None, new_row_stats=None, new_col_stats=None, - bias=None + bias=None, ): assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1990,13 +2473,9 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) assert ( new_row_stats.shape[0] == row_stats.shape[0] ), f"{new_row_stats.shape} vs {row_stats.shape}" @@ -2016,7 +2495,17 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + 
ptrNewColStats, + ptrBias, + numRows, + numCols, + ) post_call(prev_device) return out @@ -2037,13 +2526,13 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros( @@ -2059,7 +2548,9 @@ def get_colrow_absmax( prev_device = pre_call(A.device) is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + lib.cget_col_row_stats( + ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols + ) post_call(prev_device) if threshold > 0.0: @@ -2122,9 +2613,7 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) return CSRSparseTensor( @@ -2138,14 +2627,10 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -2170,9 +2655,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2251,12 +2734,20 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1], transpose + ) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2267,7 +2758,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim2 = 
ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2288,7 +2779,9 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError( + f"Transform function not implemented: From {from_order} to {to_order}" + ) post_call(prev_device) @@ -2297,9 +2790,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2326,7 +2817,21 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out @@ -2553,9 +3058,7 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2565,7 +3068,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2573,7 +3076,10 @@ def extract_outliers(A, SA, idx): return out + def pipeline_test(A, batch_size): out = torch.zeros_like(A) - lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + lib.cpipeline_test( + get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size) + ) return out From 83dc3637a110418c9ad90073dd4528175c1e5d36 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:15:17 +0000 Subject: [PATCH 10/18] git blame ignore prev formatting commit --- .git-blame-ignore-revs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 646c4665e..72c797d30 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -12,3 +12,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 # whitespace in pyproject.toml f7b791863083429ba79dc00f925a041beab63297 + +# format bitsandbytes/functional.py +64ad928224ab1134dff416feee5e7ca663331bc0 From bd9fb62edb9bd39d72c223f79d4c7ed92871bfd5 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 00:42:27 +0000 Subject: [PATCH 11/18] preliminary (simplified) multi-backend, scaffolding + relocate first 4 API funcs --- bitsandbytes/__init__.py | 1 + bitsandbytes/backends/__init__.py | 5 + 
bitsandbytes/backends/_base.py | 143 +++++++++++++++++ bitsandbytes/backends/_helpers.py | 54 +++++++ bitsandbytes/backends/amd.py | 0 bitsandbytes/backends/apple.py | 0 bitsandbytes/backends/intel.py | 0 bitsandbytes/backends/nvidia.py | 254 +++++++++++++++++++++++++++++ bitsandbytes/functional.py | 259 ++++++------------------------ 9 files changed, 505 insertions(+), 211 deletions(-) create mode 100644 bitsandbytes/backends/__init__.py create mode 100644 bitsandbytes/backends/_base.py create mode 100644 bitsandbytes/backends/_helpers.py create mode 100644 bitsandbytes/backends/amd.py create mode 100644 bitsandbytes/backends/apple.py create mode 100644 bitsandbytes/backends/intel.py create mode 100644 bitsandbytes/backends/nvidia.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 3f175319a..a8948d807 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -12,6 +12,7 @@ matmul_cublas, mm_cublas, ) +from .backends import _backend as backend from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py new file mode 100644 index 000000000..669c0f536 --- /dev/null +++ b/bitsandbytes/backends/__init__.py @@ -0,0 +1,5 @@ +from ..cextension import lib +from ._base import COOSparseTensor +from .nvidia import CudaBackend + +_backend = CudaBackend(lib) diff --git a/bitsandbytes/backends/_base.py b/bitsandbytes/backends/_base.py new file mode 100644 index 000000000..d31f721e4 --- /dev/null +++ b/bitsandbytes/backends/_base.py @@ -0,0 +1,143 @@ +import torch + + +class COOSparseTensor: + def __init__(self, rows, cols, nnz, rowidx, colidx, values): + assert rowidx.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colidx.numel() == nnz + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowidx = rowidx + self.colidx = colidx + self.values = values + + +class BackendInterface: + _instance = None + + def __new__(cls, lib=None): + if cls._instance is None: + if lib is None: + raise ValueError( + "A 'lib' binary must be provided during the first initialization of BackendInterface." + ) + cls._instance = super().__new__(cls) + cls._instance.lib = ( + lib # Set the binary name during the first and only instantiation + ) + else: + if lib is not None: + raise ValueError( + "The BackendInterface singleton has already been initialized with a 'lib' value. Re-initialization with a new 'lib' value is not allowed." + ) + return cls._instance + + def check_matmul( + self, + A, + B, + out=None, + transposed_A=False, + transposed_B=False, + expected_type=torch.int8, + ): + """ + Checks if the matrix multiplication between A and B can be performed, considering their shapes, + whether they are transposed, and their data types. It also determines the shape of the output tensor. + + Parameters: + - A (torch.Tensor): The first matrix in the multiplication. + - B (torch.Tensor): The second matrix in the multiplication. + - out (torch.Tensor, optional): The output tensor to store the result of the multiplication. Default is None. + - transposed_A (bool, optional): Indicates if matrix A is transposed. Default is False. + - transposed_B (bool, optional): Indicates if matrix B is transposed. Default is False. + - expected_type (torch.dtype, optional): The expected data type of matrices A and B. Default is torch.int8. 
+ + Returns: + - tuple: The shape of the output tensor resulting from the matrix multiplication. + + Raises: + - TypeError: If the data types of A or B do not match the expected type. + - ValueError: If the dimensions of A and B are not compatible for matrix multiplication. + """ + raise NotImplementedError + + # 8-bit matmul interface + def coo_zeros(self, rows, cols, nnz, device, dtype=torch.half): + rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + values = torch.zeros((nnz,), dtype=dtype, device=device) + + return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) + + def get_colrow_absmax( + self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 + ): + raise NotImplementedError + + def double_quant( + self, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, + ): + raise NotImplementedError + + def extract_outliers(self, *args, **kwargs): + raise NotImplementedError + + def igemmlt(self, *args, **kwargs): + raise NotImplementedError + + def mm_dequant(self, *args, **kwargs): + raise NotImplementedError + + # k-bit quantization interface + def create_quant_map(self, interface, quant_name): + """ + Below functions should be abstracted into a general method + "create_quant_map(interface, "quant_name")", so we can call e.g. + create_quant_map(..., quant_name='normal'): + - 'create_dynamic_map' + - 'create_fp8_map' + - 'create_linear_map' + - 'create_normal_map' + - 'create_quantile_map' + """ + raise NotImplementedError + + def estimate_quantiles(self, *args, **kwargs): + raise NotImplementedError + + def dequantize_blockwise(self, *args, **kwargs): + raise NotImplementedError + + def quantize_blockwise(self, *args, **kwargs): + raise NotImplementedError + + # 4-bit matmul interface + def dequantize_4bit(self, *args, **kwargs): + raise NotImplementedError + + def quantize_4bit(self, *args, **kwargs): + raise NotImplementedError + + def gemv_4bit(self, *args, **kwargs): + raise NotImplementedError + + # 8-bit optimizer interface + def optimizer_update_32bit(self, *args, **kwargs): + """This is needed for tests""" + raise NotImplementedError("Subclasses must implement 'optimizer_update_32bit'.") + + def optimizer_update_8bit_blockwise(self, *args, **kwargs): + raise NotImplementedError diff --git a/bitsandbytes/backends/_helpers.py b/bitsandbytes/backends/_helpers.py new file mode 100644 index 000000000..adfc2a1c2 --- /dev/null +++ b/bitsandbytes/backends/_helpers.py @@ -0,0 +1,54 @@ +import ctypes +from typing import Optional + +import torch + + +def pre_call(device): + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + return prev_device + + +def post_call(prev_device): + torch.cuda.set_device(prev_device) + + +def get_ptr(A: Optional[torch.Tensor]) -> Optional[ctypes.c_void_p]: + """ + Get the ctypes pointer from a PyTorch Tensor. + + Parameters + ---------- + A : torch.tensor + The PyTorch tensor. 
+ + Returns + ------- + ctypes.c_void_p + """ + if A is None: + return None + else: + return ctypes.c_void_p(A.data.data_ptr()) + + +def is_on_gpu(tensors): + on_gpu = True + gpu_ids = set() + for t in tensors: + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}" + ) + if len(gpu_ids) > 1: + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}" + ) + return on_gpu diff --git a/bitsandbytes/backends/amd.py b/bitsandbytes/backends/amd.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/apple.py b/bitsandbytes/backends/apple.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/intel.py b/bitsandbytes/backends/intel.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/nvidia.py b/bitsandbytes/backends/nvidia.py new file mode 100644 index 000000000..632cecc11 --- /dev/null +++ b/bitsandbytes/backends/nvidia.py @@ -0,0 +1,254 @@ +import ctypes + +import torch + +from ._base import BackendInterface +from ._helpers import get_ptr, is_on_gpu, post_call, pre_call + + +class CudaBackend(BackendInterface): + def check_matmul( + self, A, B, out, transposed_A, transposed_B, expected_type=torch.int8 + ): + if not torch.cuda.is_initialized(): + torch.cuda.init() + if A.dtype != expected_type or B.dtype != expected_type: + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) + + sA = A.shape + sB = B.shape + tA = transposed_A + tB = transposed_B + + correct = True + + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False + + if out is not None: + sout = out.shape + # special case common in backprop + if not correct and len(sA) == 3 and len(sB) == 3: + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): + correct = True + else: + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA 
and tB:
+                    sout = (sA[0], sA[1], sB[0])
+            elif len(sA) == 3 and len(sB) == 3:
+                if not tA and not tB:
+                    sout = (sA[0], sA[1], sB[2])
+                elif tA and tB:
+                    sout = (sA[0], sA[2], sB[1])
+                elif tA and not tB:
+                    sout = (sA[0], sA[2], sB[2])
+                elif not tA and tB:
+                    sout = (sA[0], sA[1], sB[1])
+
+        if not correct:
+            raise ValueError(
+                f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
+            )
+
+        return sout
+
+    def get_colrow_absmax(
+        self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
+    ):
+        assert A.dtype == torch.float16
+        device = A.device
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        col_tiles = (cols + 255) // 256
+        tiled_rows = ((rows + 15) // 16) * 16
+        if row_stats is None:
+            row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
+                -50000.0
+            )
+        if col_stats is None:
+            col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
+                -50000.0
+            )
+
+        if nnz_block_ptr is None and threshold > 0.0:
+            nnz_block_ptr = torch.zeros(
+                ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
+            )
+
+        ptrA = get_ptr(A)
+        ptrRowStats = get_ptr(row_stats)
+        ptrColStats = get_ptr(col_stats)
+        ptrNnzrows = get_ptr(nnz_block_ptr)
+        rows = ctypes.c_int32(rows)
+        cols = ctypes.c_int32(cols)
+
+        prev_device = pre_call(A.device)
+        is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
+        self.lib.cget_col_row_stats(
+            ptrA,
+            ptrRowStats,
+            ptrColStats,
+            ptrNnzrows,
+            ctypes.c_float(threshold),
+            rows,
+            cols,
+        )
+        post_call(prev_device)
+
+        if threshold > 0.0:
+            nnz_block_ptr.cumsum_(0)
+
+        return row_stats, col_stats, nnz_block_ptr
+
+    def double_quant(
+        self,
+        A,
+        col_stats=None,
+        row_stats=None,
+        out_col=None,
+        out_row=None,
+        threshold=0.0,
+    ):
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "cuda"
+        prev_device = pre_call(A.device)
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        if row_stats is None or col_stats is None:
+            row_stats, col_stats, nnz_row_ptr = self.get_colrow_absmax(
+                A, threshold=threshold
+            )
+
+        if out_col is None:
+            out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+        if out_row is None:
+            out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+
+        coo_tensor = None
+        ptrA = get_ptr(A)
+        ptrColStats = get_ptr(col_stats)
+        ptrRowStats = get_ptr(row_stats)
+        ptrOutCol = get_ptr(out_col)
+        ptrOutRow = get_ptr(out_row)
+
+        is_on_gpu([A, col_stats, row_stats, out_col, out_row])
+        if threshold > 0.0:
+            nnz = nnz_row_ptr[-1].item()
+            if nnz > 0:
+                coo_tensor = self.coo_zeros(
+                    A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
+                )
+                ptrRowIdx = get_ptr(coo_tensor.rowidx)
+                ptrColIdx = get_ptr(coo_tensor.colidx)
+                ptrVal = get_ptr(coo_tensor.values)
+                ptrRowPtr = get_ptr(nnz_row_ptr)
+
+                self.lib.cdouble_rowcol_quant(
+                    ptrA,
+                    ptrRowStats,
+                    ptrColStats,
+                    ptrOutCol,
+                    ptrOutRow,
+                    ptrRowIdx,
+                    ptrColIdx,
+                    ptrVal,
+                    ptrRowPtr,
+                    ctypes.c_float(threshold),
+                    ctypes.c_int32(rows),
+                    ctypes.c_int32(cols),
+                )
+                val, idx = torch.sort(coo_tensor.rowidx)
+                coo_tensor.rowidx = val
+                coo_tensor.colidx = coo_tensor.colidx[idx]
+                coo_tensor.values = coo_tensor.values[idx]
+            else:
+                self.lib.cdouble_rowcol_quant(
+                    ptrA,
+                    ptrRowStats,
+                    ptrColStats,
+                    ptrOutCol,
+                    ptrOutRow,
+                    None,
+                    None,
+                    None,
+                    None,
+                    ctypes.c_float(0.0),
+                    ctypes.c_int32(rows),
+                    ctypes.c_int32(cols),
+                )
+        
else: + self.lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ctypes.c_float(threshold), + ctypes.c_int32(rows), + ctypes.c_int32(cols), + ) + post_call(prev_device) + + return out_row, out_col, row_stats, col_stats, coo_tensor diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9cb7abc39..4a1cd3062 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,10 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct -from functools import reduce # Required in Python 3 +from functools import ( + reduce, # Required in Python 3 + wraps, +) import itertools import operator from typing import Any, Dict, Optional, Tuple +import warnings import numpy as np import torch @@ -14,6 +18,8 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from .backends import _backend as backend +from .backends._helpers import get_ptr, is_on_gpu, post_call, pre_call from .cextension import COMPILED_WITH_CUDA, lib @@ -22,6 +28,38 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) +def deprecated(_func=None, *, new_func_name=None): + """ + A decorator to mark functions as deprecated. It issues a warning when the decorated function is called, + advising to use a specified new function instead. + + Parameters: + - _func (callable, optional): The function to be deprecated. This is for internal use when the decorator is applied without parentheses. + - new_func_name (str, optional): The name of the new function to use instead of the deprecated one. Defaults to 'bitsandbytes.backend.'. + + Usage: + @deprecated + def old_function(): + ... + + @deprecated(new_func_name='module.new_function') + def another_old_function(): + ... + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + replacement = new_func_name or f"bitsandbytes.backend.{func.__name__}" + warning_message = f"'{func.__name__}' is deprecated and will be removed in a future version. Use '{replacement}' instead." + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator if _func is None else decorator(_func) + + name2qmap = {} if COMPILED_WITH_CUDA: @@ -417,56 +455,6 @@ def get_special_format_str(): return "col_turing" -def is_on_gpu(tensors): - on_gpu = True - gpu_ids = set() - for t in tensors: - if t is None: - continue # NULL pointers are fine - is_paged = getattr(t, "is_paged", False) - on_gpu &= t.device.type == "cuda" or is_paged - if not is_paged: - gpu_ids.add(t.device.index) - if not on_gpu: - raise TypeError( - f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}" - ) - if len(gpu_ids) > 1: - raise TypeError( - f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}" - ) - return on_gpu - - -def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: - """ - Get the ctypes pointer from a PyTorch Tensor. - - Parameters - ---------- - A : torch.tensor - The PyTorch tensor. 
- - Returns - ------- - ctypes.c_void_p - """ - if A is None: - return None - else: - return ct.c_void_p(A.data.data_ptr()) - - -def pre_call(device): - prev_device = torch.cuda.current_device() - torch.cuda.set_device(device) - return prev_device - - -def post_call(prev_device): - torch.cuda.set_device(prev_device) - - def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): @@ -2511,69 +2499,9 @@ def mm_dequant( return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): - assert A.dtype == torch.float16 - device = A.device - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - col_tiles = (cols + 255) // 256 - tiled_rows = ((rows + 15) // 16) * 16 - if row_stats is None: - row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) - if col_stats is None: - col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) - - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) - - ptrA = get_ptr(A) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNnzrows = get_ptr(nnz_block_ptr) - rows = ct.c_int32(rows) - cols = ct.c_int32(cols) - - prev_device = pre_call(A.device) - is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats( - ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols - ) - post_call(prev_device) - - if threshold > 0.0: - nnz_block_ptr.cumsum_(0) - - return row_stats, col_stats, nnz_block_ptr - - -class COOSparseTensor: - def __init__(self, rows, cols, nnz, rowidx, colidx, values): - assert rowidx.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colidx.numel() == nnz - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowidx = rowidx - self.colidx = colidx - self.values = values +@deprecated +def get_colrow_absmax(*args, **kwargs): + return backend.get_colrow_absmax(*args, **kwargs) class CSRSparseTensor: @@ -2633,105 +2561,14 @@ def coo2csc(cooA): return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) -def coo_zeros(rows, cols, nnz, device, dtype=torch.half): - rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - values = torch.zeros((nnz,), dtype=dtype, device=device) - return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) +@deprecated +def coo_zeros(*args, **kwargs): + return backend.coo_zeros(*args, **kwargs) -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): - device = A.device - assert A.dtype == torch.half - assert device.type == "cuda" - prev_device = pre_call(A.device) - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = 
get_ptr(A) - ptrColStats = get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - post_call(prev_device) - - return out_row, out_col, row_stats, col_stats, coo_tensor +@deprecated +def double_quant(*args, **kwargs): + return backend.double_quant(*args, **kwargs) def transform( From 01327aa0119fa503ea16322dd72f69f202e4502e Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 00:43:31 +0000 Subject: [PATCH 12/18] ruff format bitsandbytes/functional.py --- bitsandbytes/functional.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 4a1cd3062..0420bed4c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -795,14 +795,22 @@ def __eq__(self, other): return False return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) and - self.shape == other.shape and - torch.allclose(self.code, other.code, atol=1e-6) and - self.dtype == other.dtype and - self.blocksize == other.blocksize and - self.quant_type == other.quant_type and - (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and - (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2) + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) ) From 2d05dc5e275f8019bf90630609a5356270693be3 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 00:44:38 +0000 Subject: [PATCH 13/18] git blame ignore prev formatting commit --- .git-blame-ignore-revs | 1 + 1 file changed, 1 insertion(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 72c797d30..fa8ca65c0 100644 --- a/.git-blame-ignore-revs +++ 
b/.git-blame-ignore-revs @@ -15,3 +15,4 @@ f7b791863083429ba79dc00f925a041beab63297 # format bitsandbytes/functional.py 64ad928224ab1134dff416feee5e7ca663331bc0 +01327aa0119fa503ea16322dd72f69f202e4502e From 8fe3cb39c562e8f55ae3428c2ef03f39158e7de0 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 01:09:41 +0000 Subject: [PATCH 14/18] backends/nvidia: add notes on methods to implement/deprecate --- bitsandbytes/backends/nvidia.py | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/bitsandbytes/backends/nvidia.py b/bitsandbytes/backends/nvidia.py index 632cecc11..f4c14cfab 100644 --- a/bitsandbytes/backends/nvidia.py +++ b/bitsandbytes/backends/nvidia.py @@ -252,3 +252,42 @@ def double_quant( post_call(prev_device) return out_row, out_col, row_stats, col_stats, coo_tensor + + """ + # CUDA specific interface (do not include in general interface): + 'CUBLAS_Context' + 'Cusparse_Context' + 'GlobalPageManager' + '_mul' + 'arange' + 'dtype2bytes' + 'elementwise_func' + 'fill' + 'get_paged' + 'get_4bit_type' + 'get_ptr' + 'get_special_format_str' + 'get_transform_buffer' + 'get_transform_func' + 'is_on_gpu' + 'nvidia_transform' + 'transform' + + ## Deprecate these: + 'optimizer_update_8bit' + 'dequant_min_max' + 'dequantize' + 'dequantize_no_absmax' + 'igemm' + 'quantize' + 'spmm_coo' + 'spmm_coo_very_sparse' + 'vectorwise_dequant' + 'vectorwise_mm_dequant' + 'vectorwise_quant' + 'CSCSparseTensor' + 'CSRSparseTensor' + 'coo2csc' + 'coo2csr' + 'histogram_scatter_add_2d' + """ From 84f1ab6215e80066e1c9cf1019ae7b5570740083 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 01:13:35 +0000 Subject: [PATCH 15/18] make simplifying assumption explicit --- bitsandbytes/backends/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 669c0f536..743693d8d 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -3,3 +3,9 @@ from .nvidia import CudaBackend _backend = CudaBackend(lib) +# TODO: this should actually be done in `cextension.py` and potentially with .get_instance() +# for now this is just a simplifying assumption +# +# Notes from Tim: +# backend = CUDABackend.get_instance() +# -> CUDASetup -> lib -> backend.clib = lib From d8e13b721cf027b04ecfd9b25da2fc4bf253d500 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 10:04:23 +0000 Subject: [PATCH 16/18] add missing __init__.py in device_setup --- bitsandbytes/device_setup/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 bitsandbytes/device_setup/__init__.py diff --git a/bitsandbytes/device_setup/__init__.py b/bitsandbytes/device_setup/__init__.py new file mode 100644 index 000000000..e69de29bb From 4370e6158c7812fdd1e7f0b34c84cdbb32995b67 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 23 Feb 2024 10:15:13 +0000 Subject: [PATCH 17/18] trying to get docs CI to pass --- bitsandbytes/backends/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 743693d8d..5da58f25e 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -2,7 +2,7 @@ from ._base import 
COOSparseTensor from .nvidia import CudaBackend -_backend = CudaBackend(lib) +_backend = CudaBackend(lib) if lib else None # TODO: this should actually be done in `cextension.py` and potentially with .get_instance() # for now this is just a simplifying assumption # From 044147cc44d9f41fe205fef1b644b50d48ed2219 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Tue, 12 Mar 2024 19:38:28 +0000 Subject: [PATCH 18/18] revert formatting for clear review diff --- bitsandbytes/cextension.py | 16 +- bitsandbytes/functional.py | 1018 +++++++++--------------------------- 2 files changed, 259 insertions(+), 775 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 171572ab9..1636d06b0 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -14,27 +14,21 @@ if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() - raise RuntimeError( - """ + raise RuntimeError(''' CUDA Setup failed despite GPU being available. Please run the following command to get more information: python -m bitsandbytes Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes - and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues""" - ) - _ = ( - lib.cadam32bit_grad_fp32 - ) # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') + _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError as ex: - warn( - "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable." - ) + warn("The installed version of bitsandbytes was compiled without GPU support. 
" + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False print(str(ex)) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0420bed4c..0841d11be 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -166,6 +166,7 @@ def prefetch_all(self, to_cpu=False): prefetch_tensor(t, to_cpu) + class CUBLAS_Context: _instance = None @@ -207,7 +208,6 @@ def get_instance(cls): cls._instance.initialize() return cls._instance - dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -215,11 +215,10 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -FIRST_CUDA_DEVICE = torch.device("cuda", index=0) - +FIRST_CUDA_DEVICE = torch.device('cuda', index=0) def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): - num_bytes = dtype2bytes[dtype] * prod(shape) + num_bytes = dtype2bytes[dtype]*prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -228,35 +227,31 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): out.page_deviceid = device.index return out - def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, "Only paged tensors can be prefetched!" + assert A.is_paged, 'Only paged tensors can be prefetched!' if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype] * A.numel() + num_bytes = dtype2bytes[A.dtype]*A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) - def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f"c{func_name}_fp32", None) + func = getattr(lib, f'c{func_name}_fp32', None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f"c{func_name}_uint8", None) + func = getattr(lib, f'c{func_name}_uint8', None) cvalue = ct.c_uint8(value) - if func is None: - raise NotImplementedError(f"Function not implemented: {func_name}") + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') - is_managed = getattr(A, "is_managed", False) + is_managed = getattr(A, 'is_managed', False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: - prefetch_tensor(B) + if B is not None: prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: @@ -266,36 +261,28 @@ def elementwise_func(func_name, A, B, value, prefetch=True): # operation occurred. So we synchronize. 
torch.cuda.synchronize() - -def fill(A, value, device=None, prefetch=True): - elementwise_func("fill", A, None, value) - - -def arange(A, device=None): - elementwise_func("arange", A, None, 0) - - -def _mul(A, B, device=None): - elementwise_func("_mul", A, B, 0) +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = -1.0 if signed else 0.0 + sign = (-1.0 if signed else 0.0) total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = 2**total_bits if not signed else 2**total_bits - 1 + total_values = (2**total_bits if not signed else 2**total_bits-1) values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel() // 2 # noqa: E741 - return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + l = values.numel()//2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -310,11 +297,11 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -327,39 +314,38 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values - def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e + p == total_bits - has_sign + assert e+p == total_bits-has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate( - range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1) - ): + for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) + values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - # for ev in evalues: - bias = 2 ** (exponent_bits - 1) - for evalue in range(2 ** (exponent_bits)): + #for ev in evalues: + bias = 2**(exponent_bits-1) + for evalue in range(2**(exponent_bits)): for bit_pattern in lst: - value = 1 if evalue != 0 else 0 + value = (1 if evalue != 0 else 0) for i, pval in enumerate(list(bit_pattern)): - value += pval * (2 ** -(i + 1)) + value += pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value * 2**-(bias) + value = value*2**-(bias) else: # normals - value = value * 2 ** -(evalue - bias - 1) + value = value*2**-(evalue-bias-1) values.append(value) if signed: values.append(-value) + assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -373,6 +359,7 @@ def create_fp8_map(signed=True, 
exponent_bits=5, precision_bits=2, total_bits=8) return code + def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -397,11 +384,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int( - 2 ** (i + non_sign_bits - max_exponent_bits) + 1 - if signed - else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1 - ) + fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -427,9 +410,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) - def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) + q = estimate_quantiles(A, num_quantiles=2**total_bits-1) q = q.tolist() q.append(0) @@ -440,13 +422,11 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q / q.abs().max() + q = q/q.abs().max() return q - def get_special_format_str(): - if not torch.cuda.is_available(): - return "col_turing" + if not torch.cuda.is_available(): return 'col_turing' major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -547,13 +527,8 @@ def nvidia_transform( return out, new_state -def estimate_quantiles( - A: Tensor, - out: Optional[torch.Tensor] = None, - offset: float = 1 / 512, - num_quantiles=256, -) -> Tensor: - """ +def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: + ''' Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -580,37 +555,26 @@ def estimate_quantiles( ------- torch.Tensor: The 256 quantiles in float32 datatype. - """ - if A.numel() < 256: - raise NotImplementedError( - f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values." 
- ) - if num_quantiles > 256: - raise NotImplementedError( - f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}" - ) - if num_quantiles < 256 and offset == 1 / (512): + ''' + if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') + if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") + if num_quantiles < 256 and offset == 1/(512): # override default arguments - offset = 1 / (2 * num_quantiles) + offset = 1/(2*num_quantiles) - if out is None: - out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) else: raise NotImplementedError(f"Not supported data type {A.dtype}") post_call(device) if num_quantiles < 256: - step = round(256 / num_quantiles) + step = round(256/num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -619,35 +583,12 @@ def estimate_quantiles( class QuantState: """container for quantization state components to work with Params4bit and similar classes""" - - valid_quant_types = ("fp4", "nf4") + valid_quant_types = ('fp4', 'nf4') valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = [ - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "quant_state", - "quant_type", - "blocksize", - "dtype", - "shape", - "nested_blocksize", - "nested_dtype", - "nested_offset", - ] - - def __init__( - self, - absmax, - shape=None, - code=None, - blocksize=None, - quant_type=None, - dtype=None, - offset=None, - state2=None, - ): + valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', + 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] + + def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): self.absmax = absmax self.shape = shape self.code = code @@ -666,27 +607,13 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [ - self.absmax, - self.shape, - self.dtype, - self.blocksize, - [self.offset, self.state2], - self.quant_type, - ] + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] else: - list_repr = [ - self.absmax, - self.shape, - self.dtype, - self.blocksize, - None, - self.quant_type, - ] + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': """ unpacks components of state_dict into QuantState 
where necessary, convert into strings, torch.dtype, ints, etc. @@ -697,48 +624,38 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState """ # unpacking tensor with non-tensor components - qs_key = [ - k - for k, v in qs_dict.items() - if "quant_state" in k and isinstance(v, torch.Tensor) - ] - if not len(qs_key) and "quant_type" not in qs_dict: - raise ValueError( - "Expected packed or unpacked quant_state items, found neither" - ) + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and 'quant_type' not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError( - f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}." - ) + raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: first_qs_key = qs_key[0] qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if "nested_absmax" in qs_dict: - offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + if 'nested_absmax' in qs_dict: + offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) state2 = cls( - absmax=qs_dict["nested_absmax"].to(device), - blocksize=qs_dict["nested_blocksize"], - code=qs_dict["nested_quant_map"].to(device), - dtype=getattr(torch, qs_dict["nested_dtype"]), + absmax=qs_dict['nested_absmax'].to(device), + blocksize=qs_dict['nested_blocksize'], + code=qs_dict['nested_quant_map'].to(device), + dtype=getattr(torch, qs_dict['nested_dtype']), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict["quant_type"], - absmax=qs_dict["absmax"].to(device), - blocksize=qs_dict["blocksize"], - code=qs_dict["quant_map"].to(device), - dtype=getattr(torch, qs_dict["dtype"]), - shape=torch.Size(qs_dict["shape"]) - if qs_dict["shape"] is not None - else None, + quant_type=qs_dict['quant_type'], + absmax=qs_dict['absmax'].to(device), + blocksize=qs_dict['blocksize'], + code=qs_dict['quant_map'].to(device), + dtype=getattr(torch, qs_dict['dtype']), + shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, offset=offset, state2=state2, ) @@ -750,36 +667,28 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - "quant_type": self.quant_type, - "absmax": self.absmax, - "blocksize": self.blocksize, - "quant_map": self.code, - "dtype": str(self.dtype).strip("torch."), - "shape": tuple(self.shape), + 'quant_type': self.quant_type, + 'absmax': self.absmax, + 'blocksize': self.blocksize, + 'quant_map': self.code, + 'dtype': str(self.dtype).strip('torch.'), + 'shape': tuple(self.shape), } if self.nested: - qs_dict.update( - { - "nested_absmax": self.state2.absmax, - "nested_blocksize": self.state2.blocksize, - "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - "nested_dtype": str(self.state2.dtype).strip("torch."), - "nested_offset": self.offset.item(), - } - ) 
+ qs_dict.update({ + 'nested_absmax': self.state2.absmax, + 'nested_blocksize': self.state2.blocksize, + 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + 'nested_dtype': str(self.state2.dtype).strip('torch.'), + 'nested_offset': self.offset.item(), + }) if not packed: return qs_dict # packed format allows serialization of non-tensor components, critical for saving in safetensors format - qs_packed_dict = { - k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor) - } - non_tensor_dict = { - k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor) - } - qs_packed_dict[ - "quant_state." + "bitsandbytes__" + self.quant_type - ] = pack_dict_to_tensor(non_tensor_dict) + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) return qs_packed_dict def to(self, device): @@ -795,22 +704,14 @@ def __eq__(self, other): return False return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) - and self.shape == other.shape - and torch.allclose(self.code, other.code, atol=1e-6) - and self.dtype == other.dtype - and self.blocksize == other.blocksize - and self.quant_type == other.quant_type - and ( - self.offset == other.offset - if self.offset is not None and other.offset is not None - else self.offset is other.offset - ) - and ( - self.state2 == other.state2 - if self.state2 is not None and other.state2 is not None - else self.state2 is other.state2 - ) + torch.allclose(self.absmax, other.absmax, atol=1e-6) and + self.shape == other.shape and + torch.allclose(self.code, other.code, atol=1e-6) and + self.dtype == other.dtype and + self.blocksize == other.blocksize and + self.quant_type == other.quant_type and + (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and + (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2) ) @@ -848,6 +749,7 @@ def quantize_blockwise( The quantization state to undo the quantization. 
""" + if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -862,72 +764,33 @@ def quantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != "cpu": + if A.device.type != 'cpu': assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) + lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) + quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) else: - quant_state = QuantState( - absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype - ) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) return out, quant_state @@ -939,7 +802,7 @@ def dequantize_blockwise( code: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False, + nested=False ) -> Tensor: """ Dequantizes blockwise quantized values. 
@@ -973,80 +836,43 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState( - absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32 - ) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != "cpu": + if A.device.type != 'cpu': device = pre_call(A.device) code = quant_state.code.to(A.device) if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" - ) + raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(quant_state.absmax), - get_ptr(out), - ct.c_longlong(quant_state.blocksize), - ct.c_longlong(A.numel()), - ) + lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) return out - def get_4bit_type(typename, device=None, blocksize=64): - if device is None: - device = "cuda" + if device is None: device = 'cuda' data = None - if typename == "nf4": - """ Implements the NF4 data type. + if typename == 'nf4': + ''' Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. 
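Since the equal-area construction is easier to grasp with the codebook in hand, here is a rough pure-PyTorch sketch of what NF4 quantization amounts to: a per-block absmax scale followed by nearest-level lookup against the 16-entry table that get_4bit_type builds below. Illustrative only; the CUDA kernels operate on packed 4-bit storage, not on index tensors like this.

    import torch

    # The 16 NF4 levels (equal-area bins of N(0, 1), normalized to [-1, 1]).
    nf4 = torch.tensor([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
        0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
        0.7229568362236023, 1.0,
    ])

    x = torch.randn(64)                    # one block of weights
    absmax = x.abs().max()                 # per-block scale, as in quantize_4bit
    codes = (x[:, None] / absmax - nf4[None, :]).abs().argmin(dim=1)
    x_hat = nf4[codes] * absmax            # dequantized approximation
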
@@ -1055,26 +881,12 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - """ - data = [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, - ] - elif typename == "fp4": + ''' + data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, + -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, + 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, + 0.7229568362236023, 1.0] + elif typename == 'fp4': # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -1084,55 +896,21 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b110 = 2 # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) - data = [ - 0, - 0.0625, - 8.0, - 12.0, - 4.0, - 6.0, - 2.0, - 3.0, - -0, - -0.0625, - -8.0, - -12.0, - -4.0, - -6.0, - -2.0, - -3.0, - ] - elif typename == "int4": + data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] + elif typename == 'int4': data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == "af4": + elif typename == 'af4': # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [ - -1.0, - -0.69441008, - -0.51243739, - -0.3736951, - -0.25607552, - -0.14982478, - -0.04934812, - 0.0, - 0.04273164, - 0.12934483, - 0.21961274, - 0.31675666, - 0.42563882, - 0.55496234, - 0.72424863, - 1.0, - ][::-1] + data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, + -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, + 0.42563882, 0.55496234, 0.72424863, 1.][::-1] else: - raise NotImplementedError( - "4-bit AbnormalFloats currently only support blocksize 64." 
- ) + raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') if data is None: - raise NotImplementedError(f"Typename {typename} not supported") + raise NotImplementedError(f'Typename {typename} not supported') data = Tensor(data) data /= data.abs().max() @@ -1141,30 +919,11 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4( - A: Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_storage=torch.uint8, -): - return quantize_4bit( - A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage - ) - +def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) -def quantize_nf4( - A: Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_storage=torch.uint8, -): - return quantize_4bit( - A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage - ) +def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) def quantize_4bit( @@ -1173,7 +932,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type="fp4", + quant_type='fp4', quant_storage=torch.uint8, ) -> Tuple[Tensor, QuantState]: """ @@ -1201,14 +960,10 @@ def quantize_4bit( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if A.device.type != "cuda": - raise NotImplementedError( - f"Device type not supported for FP4 quantization: {A.device.type}" - ) - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError( - f"4-bit quantization data type {quant_type} is not implemented." 
- ) + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') n = A.numel() input_shape = A.shape @@ -1218,9 +973,10 @@ def quantize_4bit( blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -1228,66 +984,22 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) elif A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) code = get_4bit_type(quant_type, device=A.device) @@ -1297,57 +1009,19 @@ def quantize_4bit( absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState( - absmax=qabsmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - offset=offset, - state2=state2, - ) + state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) else: - state = QuantState( - absmax=absmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - 
code=code, - quant_type=quant_type, - ) + state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) return out, state +def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_fp4( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, -) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") - - -def dequantize_nf4( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, -) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") - +def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type="fp4", -) -> Tensor: +def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1375,33 +1049,23 @@ def dequantize_4bit( Dequantized tensor. """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" - ) - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError( - f"4-bit quantization data type {quant_type} is not implemented." - ) + raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) + quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) else: absmax = quant_state.absmax + if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1411,73 +1075,27 @@ def dequantize_4bit( device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = True if A.shape[0] == 1 else False - if is_transposed: - return out.t() - else: - return out + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() 
+ else: return out def quantize( @@ -1492,8 +1110,7 @@ def quantize( code = code.to(A.device) absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1519,10 +1136,8 @@ def dequantize( return out * state[0] -def quantize_no_absmax( - A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None -) -> Tensor: - """ +def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + ''' Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1541,20 +1156,17 @@ def quantize_no_absmax( ------- torch.Tensor: Quantized 8-bit tensor. - """ + ''' prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) return out -def dequantize_no_absmax( - A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None -) -> Tensor: - """ +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + ''' Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1573,10 +1185,9 @@ def dequantize_no_absmax( ------- torch.Tensor: 32-bit output tensor. - """ + ''' prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) + if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1643,17 +1254,16 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) + optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1673,8 +1283,7 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + ct.c_int32(g.numel())) post_call(prev_device) @@ -1755,9 +1364,7 @@ def optimizer_update_8bit( param_norm = torch.norm(p.data.float()) prev_device = pre_call(g.device) - is_on_gpu( - [g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2] - ) + is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit[optimizer_name][0]( get_ptr(p), @@ -1832,6 +1439,7 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: + optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1839,11 +1447,8 @@ def 
optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( @@ -1875,7 +1480,6 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) - def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 ): @@ -1937,19 +1541,10 @@ def histogram_scatter_add_2d( maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) is_on_gpu([histogram, index1, index2, source]) - lib.chistogram_scatter_add_2d( - get_ptr(histogram), - get_ptr(index1), - get_ptr(index2), - get_ptr(source), - maxdim1, - n, - ) - + lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): - torch.cuda.init() + if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: raise TypeError( f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" @@ -2037,26 +1632,21 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout - def gemv_4bit( A: Tensor, B: Tensor, out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None, + state=None ): prev_device = pre_call(A.device) - # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError( - "state cannot None. gem_4bit( ) requires the state from quantize_4bit( )" - ) + raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') if A.numel() != A.shape[-1]: - raise ValueError( - 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]' - ) + raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. 
[1, 1, 2048]') Bshape = state.shape bout = Bshape[0] @@ -2067,9 +1657,7 @@ def gemv_4bit( if out is None: if len(A.shape) == 3: - out = torch.empty( - size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device - ) + out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) else: out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) @@ -2078,7 +1666,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1] + 1) // 2 + ldb = (A.shape[-1]+1)//2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -2089,61 +1677,21 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') post_call(prev_device) return out - def igemm( A: Tensor, B: Tensor, @@ -2228,20 +1776,8 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ) + lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) return out @@ -2322,24 +1858,9 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ct.c_long(strideA), - ct.c_long(strideB), - ct.c_long(strideC), - ct.c_uint32(num_batch), - ) + lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), + ct.c_long(strideA), ct.c_long(strideB), 
ct.c_long(strideC), ct.c_uint32(num_batch)) return out @@ -2348,22 +1869,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" + assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) elif shapeA[1] == 0 and dimsA == 3: - return torch.empty( - tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16 - ) + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: out, Sout = get_transform_buffer( @@ -2414,7 +1933,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == "col_turing": + if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32( ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc @@ -2434,15 +1953,11 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError( - "igemmlt not available (probably built with NO_CUBLASLT)" - ) + raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") if has_error: - print( - f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}" - ) - raise Exception("cublasLt ran into an error!") + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') + raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) @@ -2457,11 +1972,10 @@ def mm_dequant( out=None, new_row_stats=None, new_col_stats=None, - bias=None, + bias=None ): assert A.dtype == torch.int32 - if bias is not None: - assert bias.dtype == torch.float16 + if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -2469,9 +1983,13 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) assert ( new_row_stats.shape[0] == row_stats.shape[0] ), f"{new_row_stats.shape} vs {row_stats.shape}" @@ -2491,17 +2009,7 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16( - ptrA, - ptrRowStats, - ptrColStats, - ptrOut, - ptrNewRowStats, - ptrNewColStats, - ptrBias, - numRows, - numCols, - ) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, 
ptrNewColStats, ptrBias, numRows, numCols) post_call(prev_device) return out @@ -2549,7 +2057,9 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) return CSRSparseTensor( @@ -2563,11 +2073,14 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) - + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) @deprecated def coo_zeros(*args, **kwargs): @@ -2579,20 +2092,12 @@ def double_quant(*args, **kwargs): return backend.double_quant(*args, **kwargs) -def transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None -): +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1], transpose - ) - else: - new_state = (state[0], to_order) # (shape, order) + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2603,7 +2108,7 @@ def transform( dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == "col32": + if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2624,9 +2129,7 @@ def transform( elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError( - f"Transform function not implemented: From {from_order} to {to_order}" - ) + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') post_call(prev_device) @@ -2635,7 +2138,9 @@ def transform( def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2662,21 +2167,7 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo( - ptr, - ptrRowidx, - ptrColidx, - ptrValues, - cnnz, - crowsA, - ccolsA, - ccolsB, - cldb, - ptrB, - cldc, - ptrC, - ct.c_bool(transposed_B), - ) + lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) return out @@ -2903,7 +2394,9 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == 
"cuda" - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2913,7 +2406,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == "col_turing": + if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2921,10 +2414,7 @@ def extract_outliers(A, SA, idx): return out - def pipeline_test(A, batch_size): out = torch.zeros_like(A) - lib.cpipeline_test( - get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size) - ) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) return out