This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit a368310

Allow loading model with 4-bit quantization.
For details on 4-bit options, see: https://huggingface.co/blog/4bit-transformers-bitsandbytes

1 parent 5ef5ef0 commit a368310
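For context, the loading pattern the linked post describes, and which this commit wires into Basaran, looks roughly like the sketch below; the checkpoint name is illustrative only and not taken from the commit:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # The 4-bit options this commit surfaces: quantization type ("fp4" or
    # "nf4"), nested/double quantization, and the compute dtype used when
    # the quantized weights are dequantized for matmuls.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # "facebook/opt-350m" is an illustrative checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        device_map="auto",
        quantization_config=bnb_config,
    )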

File tree

3 files changed (+30 -2 lines):

  basaran/__init__.py
  basaran/__main__.py
  basaran/model.py

basaran/__init__.py (+3)

@@ -21,6 +21,9 @@ def is_true(value):
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
+MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
+MODEL_4BIT_QUANT_TYPE = os.getenv("MODEL_4BIT_QUANT_TYPE", "fp4")
+MODEL_4BIT_DOUBLE_QUANT = is_true(os.getenv("MODEL_4BIT_DOUBLE_QUANT", ""))
 MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
 MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
 MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))
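Note that is_true appears only in this hunk's header; its body is outside the diff. A minimal sketch consistent with how it is called here, where the set of accepted truthy strings is an assumption:

    def is_true(value):
        # Assumption: case-insensitive match against common truthy strings;
        # the real implementation in basaran/__init__.py may differ.
        return value.strip().lower() in ("1", "true", "yes", "on")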

basaran/__main__.py (+6)

@@ -20,6 +20,9 @@
 from . import MODEL_REVISION
 from . import MODEL_CACHE_DIR
 from . import MODEL_LOAD_IN_8BIT
+from . import MODEL_LOAD_IN_4BIT
+from . import MODEL_4BIT_QUANT_TYPE
+from . import MODEL_4BIT_DOUBLE_QUANT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION

@@ -42,6 +45,9 @@
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
+    load_in_4bit=MODEL_LOAD_IN_4BIT,
+    quant_type=MODEL_4BIT_QUANT_TYPE,
+    double_quant=MODEL_4BIT_DOUBLE_QUANT,
     local_files_only=MODEL_LOCAL_FILES_ONLY,
     trust_remote_code=MODEL_TRUST_REMOTE_CODE,
     half_precision=MODEL_HALF_PRECISION,
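Taken together with the new variables in basaran/__init__.py, enabling 4-bit loading at launch would look roughly like this; the exact strings is_true accepts are an assumption:

    import os

    # Hypothetical launch configuration for 4-bit loading.
    os.environ["MODEL_LOAD_IN_4BIT"] = "true"      # assumed truthy string
    os.environ["MODEL_4BIT_QUANT_TYPE"] = "nf4"    # default remains "fp4"
    os.environ["MODEL_4BIT_DOUBLE_QUANT"] = "true"

    # The server is then started as usual, e.g. `python -m basaran`.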

basaran/model.py (+21 -2)

@@ -12,6 +12,7 @@
     MinNewTokensLengthLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
+    BitsAndBytesConfig,
 )

 from .choice import map_choice

@@ -302,6 +303,9 @@ def load_model(
     revision=None,
     cache_dir=None,
     load_in_8bit=False,
+    load_in_4bit=False,
+    quant_type="fp4",
+    double_quant=False,
     local_files_only=False,
     trust_remote_code=False,
     half_precision=False,

@@ -319,12 +323,27 @@ def load_model(

     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
+        # Set quantization options if specified.
+        quant_config = None
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit and load_in_4bit can be True")
+        if load_in_8bit:
+            quant_config = BitsAndBytesConfig(
+                load_in_8bit=True,
+            )
+        elif load_in_4bit:
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type=quant_type,
+                bnb_4bit_use_double_quant=double_quant,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
         kwargs = kwargs.copy()
         kwargs["device_map"] = "auto"
-        kwargs["load_in_8bit"] = load_in_8bit
+        kwargs["quantization_config"] = quant_config

     # Cast all parameters to float16 if quantization is enabled.
-    if half_precision or load_in_8bit:
+    if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16

     # Support both decoder-only and encoder-decoder models.
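The new branch logic is easy to exercise in isolation. Below is a sketch restating it as a standalone helper for testing; the name make_quant_config is hypothetical and not part of the commit:

    import torch
    from transformers import BitsAndBytesConfig

    def make_quant_config(load_in_8bit=False, load_in_4bit=False,
                          quant_type="fp4", double_quant=False):
        # Mirrors the commit's selection logic: 8-bit and 4-bit are
        # mutually exclusive; only the 4-bit path carries extra options.
        if load_in_8bit and load_in_4bit:
            raise ValueError("Only one of load_in_8bit and load_in_4bit can be True")
        if load_in_8bit:
            return BitsAndBytesConfig(load_in_8bit=True)
        if load_in_4bit:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=quant_type,
                bnb_4bit_use_double_quant=double_quant,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        return None

    # Example: NF4 with nested quantization.
    config = make_quant_config(load_in_4bit=True, quant_type="nf4",
                               double_quant=True)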
