
Commit 81e6345

matthewdouglas and akx authored
LLM.int8() Refactoring: Part 1 (#1401)
* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation
* Fix unintended change
* New naive mm_dequant kernel for row-major; cleanup
* fix
* int8 refactor: initial sparse decomp, cleanup
* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup
* int8: inference optimizations, some cleanup
* int8: more tests passing, cleanup
* int8 - more cleanup, most tests passing
* int8: specify CUDA stream for int8 ops
* perf: reduce overhead from getting cudaStream ptr
* Mark some functions for deprecation.
* int8 sparse decomp: small perf improvement
* update setup.py
* Update bitsandbytes/autograd/_functions.py (Co-authored-by: Aarni Koskela <[email protected]>)
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* Update bitsandbytes/research/autograd/_functions.py (Co-authored-by: Aarni Koskela <[email protected]>)
* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn
* int8 cleanup
* Ignore ruff rule ISC001 (incompatible with formatter)
* add comment
* int8 more cleanup
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* int8: rename / deprecate old fn signatures
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* type annotation
* format update
* Update bitsandbytes/research/autograd/_functions.py (Co-authored-by: Aarni Koskela <[email protected]>)
* cleanup
* Add comment to explain division optimization
* more cleanup
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* Update bitsandbytes/functional.py (Co-authored-by: Aarni Koskela <[email protected]>)
* cleanup
* Type annotations, cleanup
* remove unused kernels; improved type annotations
* small perf optimization for single-GPU systems
* small perf optimization for single-GPU systems
* update docstrings
* Improve docs and tests
* Update docstring
* Update test
* add benchmarking script
* test cleanup: add deprecated marker, move benchmarks out
* Add int8 dequant function; misc improvements
* int8 matmul fallback for inner dims not divisible by 4
* improve register usage of kInt8VectorQuant - especially for A100/H100
* disable fail-fast for package build
* maxwell compat
* ptxas verbose
* docs update
* doc update
* backward fix
* Bugfix sparse decomp
* Int8 fix for PEFT OLoRA init
* Fix test for deprecated spmm_coo
* test improvement
* doc update
* typo
* doc cleanup
* docs
* add inference benchmark script
* Add benchmarks, doc update

---------

Co-authored-by: Aarni Koskela <[email protected]>
1 parent 7dca700 commit 81e6345

39 files changed: +2626 −2323 lines changed
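As background for the diffs below: the LLM.int8() matmul path that this refactor reorganizes can be emulated in a few lines of plain PyTorch. The sketch here is illustrative only; the function name and the pure-PyTorch emulation are not from this commit, and the real implementation quantizes with CUDA kernels and runs the int8 GEMM with int32 accumulation through cuBLASLt (igemmlt).

```python
# Illustrative sketch only -- not the bitsandbytes implementation.
import torch


def llm_int8_matmul_sketch(A: torch.Tensor, W: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    """Emulate C = A @ W.T the LLM.int8() way.

    A: (tokens, in_features) float activations; W: (out_features, in_features) float weight.
    """
    # Sparse decomposition: feature columns containing outliers stay in high precision.
    if threshold > 0:
        outlier_cols = (A.abs().amax(dim=0) > threshold).nonzero().flatten()
    else:  # threshold 0.0 disables the decomposition (the plain "int8" configuration)
        outlier_cols = torch.empty(0, dtype=torch.long)
    A_low = A.clone()
    A_low[:, outlier_cols] = 0

    # Absmax int8 quantization: per token (row) for A, per output feature (row) for W.
    scale_a = A_low.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    scale_w = W.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    A_i8 = torch.round(A_low / scale_a).to(torch.int8)
    W_i8 = torch.round(W / scale_w).to(torch.int8)

    # int8 GEMM with int32 accumulation (emulated here on CPU), then dequantization
    # by the outer product of the two scale vectors.
    C_i32 = A_i8.to(torch.int32) @ W_i8.to(torch.int32).T
    out = C_i32.to(A.dtype) * (scale_a * scale_w.T)

    # Add back the high-precision contribution of the outlier columns.
    if outlier_cols.numel():
        out = out + A[:, outlier_cols] @ W[:, outlier_cols].T
    return out
```

The `int8` and `int8-decomp` configurations in the benchmarking script added by this commit correspond to thresholds of 0.0 (decomposition disabled) and 6.0, respectively.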

.github/scripts/build-cuda.sh

+15 −15

@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
 [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
 [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
-for NO_CUBLASLT in ON OFF; do
-    if [ "${build_os:0:6}" == ubuntu ]; then
-        image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
-        echo "Using image $image"
-        docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
-            "apt-get update \
-            && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
-            && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
-            && cmake --build ."
-    else
-        pip install cmake==3.28.3
-        cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
-        cmake --build . --config Release
-    fi
-done
+
+if [ "${build_os:0:6}" == ubuntu ]; then
+    image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
+    echo "Using image $image"
+    docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
+        "apt-get update \
+        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
+        && cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
+        && cmake --build ."
+else
+    pip install cmake==3.28.3
+    cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
+    cmake --build . --config Release
+fi
+

 output_dir="output/${build_os}/${build_arch}"
 mkdir -p "${output_dir}"

.github/workflows/python-package.yml

+1

@@ -60,6 +60,7 @@ jobs:
   ##
   build-shared-libs-cuda:
     strategy:
+      fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest]
         arch: [x86_64, aarch64]

.gitignore

+2

@@ -22,9 +22,11 @@ CMakeFiles/
 bitsandbytes.dir/
 Debug/
 Release/
+cmake-build-*/

 # IDE local files
 .vs/
+.idea/

 # Distribution / packaging
 .Python

CMakeLists.txt

+1 −13

@@ -4,7 +4,6 @@
 # For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
 # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
-# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #   is whatever CMake finds on your path.
 # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     if(APPLE)
         message(FATAL_ERROR "CUDA is not supported on macOS" )
     endif()
-    option(NO_CUBLASLT "Disable CUBLAS" OFF)
     set(BUILD_CUDA ON)
     set(BUILD_MPS OFF)
-    message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -166,9 +163,6 @@ if(BUILD_CUDA)
     list(APPEND SRC_FILES ${CUDA_FILES})

     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
-    if(NO_CUBLASLT)
-        string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
-    endif()
     add_compile_definitions(BUILD_CUDA)
 elseif(BUILD_MPS)
     if(NOT APPLE)
@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
-    if(NO_CUBLASLT)
-        target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
-    else()
-        target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
-    endif()
-
+    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
     set_target_properties(bitsandbytes
         PROPERTIES
             CUDA_SEPARABLE_COMPILATION ON

benchmarking/README.md

+159

@@ -0,0 +1,159 @@
# Benchmarking

## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

See the example script in [inference_benchmark.py](inference_benchmark.py).

### Results (as of v0.45.0)

Our overall benchmarking results compared with v0.44.1 provide the following insights:

#### LLM.int8()
* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%.
* **H100**: With our benchmarking of Llama 3.1 70B, we observed the new LLM.int8() to consistently outperform NF4 at batch size >= 8.

#### NF4/FP4
* **Turing/Ampere/Ada**: With a batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.
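For reference, the configuration labels used in the tables below map onto 🤗 Transformers quantization settings roughly as sketched here; the model id is a placeholder, and the exact settings used for benchmarking live in [inference_benchmark.py](inference_benchmark.py).

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "<model-id>"  # placeholder

# INT8: LLM.int8() with the sparse decomposition disabled (threshold 0.0).
int8 = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=0.0)

# INT8+Decomp: LLM.int8() with fp16 decomposition of outlier features (threshold 6.0).
int8_decomp = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

# NF4 and NF4+DQ: 4-bit NormalFloat, optionally with double quantization of the scales.
nf4 = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
nf4_dq = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=int8_decomp, device_map="auto")
```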
Summaries with the benchmarking results are provided below.

#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>

#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>

<details>
<summary>Qwen 2.5 14B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>

#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>

<details>
<summary>Qwen 2.5 32B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>

<details>
<summary>Llama 3.1 70B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>

#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).

For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`

In the RTX 4090 setting, the CUDA 12.4 build of PyTorch is used. In the other settings we used the CUDA 12.1 build.
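To check that a local environment matches the pinned versions listed above, a small standard-library-only sketch (package names assumed to match the distributions above) is:

```python
from importlib.metadata import PackageNotFoundError, version

# Print the installed version of each pinned dependency, or note its absence.
for pkg in ("transformers", "accelerate", "tokenizers", "torch", "bitsandbytes"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```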

benchmarking/inference_benchmark.py

+134

@@ -0,0 +1,134 @@
"""
Inference benchmarking tool.

Requirements:
    transformers
    accelerate
    bitsandbytes
    optimum-benchmark

Usage: python inference_benchmark.py model_id

options:
  -h, --help            show this help message and exit
  --configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
  --bf16
  --fp16
  --nf4
  --nf4-dq
  --int8
  --int8-decomp
  --batches BATCHES [BATCHES ...]
  --input-length INPUT_LENGTH
  --out-dir OUT_DIR
"""

import argparse
from pathlib import Path

from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch

BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8

WEIGHTS_CONFIGS = {
    "fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
    "bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
    "nf4": {
        "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": False,
            "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
        },
    },
    "nf4-dq": {
        "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
        },
    },
    "int8-decomp": {
        "torch_dtype": "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_8bit": True,
            "llm_int8_threshold": 6.0,
        },
    },
    "int8": {
        "torch_dtype": "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_8bit": True,
            "llm_int8_threshold": 0.0,
        },
    },
}

if __name__ == "__main__":
    setup_logging(level="INFO")

    parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")

    parser.add_argument("model_id", type=str, help="The model checkpoint to use.")

    parser.add_argument(
        "--configs",
        nargs="+",
        choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
        default=["nf4", "int8", "int8-decomp"],
    )
    parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
    parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
    parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
    parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
    parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
    parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")

    parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
    parser.add_argument("--input-length", type=int, default=64)

    parser.add_argument("--out-dir", type=str, default="reports")

    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for batch_size in args.batches:
        print(f"Benchmarking batch size: {batch_size}")
        for config in args.configs:
            launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
            scenario_config = InferenceConfig(
                latency=True,
                memory=True,
                input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
            )
            backend_config = PyTorchConfig(
                device="cuda",
                device_ids="0",
                device_map="auto",
                no_weights=False,
                model=args.model_id,
                **WEIGHTS_CONFIGS[config],
            )
            benchmark_config = BenchmarkConfig(
                name=f"benchmark-{config}-bsz{batch_size}",
                scenario=scenario_config,
                launcher=launcher_config,
                backend=backend_config,
            )

            out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"

            benchmark_report = Benchmark.launch(benchmark_config)
            benchmark_report.log()
            benchmark_report.save_json(out_path)