This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree 3 files changed +30
-2
lines changed
3 files changed +30
-2
lines changed Original file line number Diff line number Diff line change @@ -21,6 +21,9 @@ def is_true(value):
21
21
MODEL_REVISION = os .getenv ("MODEL_REVISION" , "" )
22
22
MODEL_CACHE_DIR = os .getenv ("MODEL_CACHE_DIR" , "models" )
23
23
MODEL_LOAD_IN_8BIT = is_true (os .getenv ("MODEL_LOAD_IN_8BIT" , "" ))
24
+ MODEL_LOAD_IN_4BIT = is_true (os .getenv ("MODEL_LOAD_IN_4BIT" , "" ))
25
+ MODEL_4BIT_QUANT_TYPE = os .getenv ("MODEL_4BIT_QUANT_TYPE" , "fp4" )
26
+ MODEL_4BIT_DOUBLE_QUANT = is_true (os .getenv ("MODEL_4BIT_DOUBLE_QUANT" , "" ))
24
27
MODEL_LOCAL_FILES_ONLY = is_true (os .getenv ("MODEL_LOCAL_FILES_ONLY" , "" ))
25
28
MODEL_TRUST_REMOTE_CODE = is_true (os .getenv ("MODEL_TRUST_REMOTE_CODE" , "" ))
26
29
MODEL_HALF_PRECISION = is_true (os .getenv ("MODEL_HALF_PRECISION" , "" ))
Original file line number Diff line number Diff line change 20
20
from . import MODEL_REVISION
21
21
from . import MODEL_CACHE_DIR
22
22
from . import MODEL_LOAD_IN_8BIT
23
+ from . import MODEL_LOAD_IN_4BIT
24
+ from . import MODEL_4BIT_QUANT_TYPE
25
+ from . import MODEL_4BIT_DOUBLE_QUANT
23
26
from . import MODEL_LOCAL_FILES_ONLY
24
27
from . import MODEL_TRUST_REMOTE_CODE
25
28
from . import MODEL_HALF_PRECISION
42
45
revision = MODEL_REVISION ,
43
46
cache_dir = MODEL_CACHE_DIR ,
44
47
load_in_8bit = MODEL_LOAD_IN_8BIT ,
48
+ load_in_4bit = MODEL_LOAD_IN_4BIT ,
49
+ quant_type = MODEL_4BIT_QUANT_TYPE ,
50
+ double_quant = MODEL_4BIT_DOUBLE_QUANT ,
45
51
local_files_only = MODEL_LOCAL_FILES_ONLY ,
46
52
trust_remote_code = MODEL_TRUST_REMOTE_CODE ,
47
53
half_precision = MODEL_HALF_PRECISION ,
Original file line number Diff line number Diff line change 12
12
MinNewTokensLengthLogitsProcessor ,
13
13
TemperatureLogitsWarper ,
14
14
TopPLogitsWarper ,
15
+ BitsAndBytesConfig
15
16
)
16
17
17
18
from .choice import map_choice
@@ -302,6 +303,9 @@ def load_model(
302
303
revision = None ,
303
304
cache_dir = None ,
304
305
load_in_8bit = False ,
306
+ load_in_4bit = False ,
307
+ quant_type = "fp4" ,
308
+ double_quant = False ,
305
309
local_files_only = False ,
306
310
trust_remote_code = False ,
307
311
half_precision = False ,
@@ -319,12 +323,27 @@ def load_model(
319
323
320
324
# Set device mapping and quantization options if CUDA is available.
321
325
if torch .cuda .is_available ():
326
+ # Set quantization options if specified.
327
+ quant_config = None
328
+ if load_in_8bit and load_in_4bit :
329
+ raise ValueError ("Only one of load_in_8bit and load_in_4bit can be True" )
330
+ if load_in_8bit :
331
+ quant_config = BitsAndBytesConfig (
332
+ load_in_8bit = True ,
333
+ )
334
+ elif load_in_4bit :
335
+ quant_config = BitsAndBytesConfig (
336
+ load_in_4bit = True ,
337
+ bnb_4bit_quant_type = quant_type ,
338
+ bnb_4bit_use_double_quant = double_quant ,
339
+ bnb_4bit_compute_dtype = torch .bfloat16 ,
340
+ )
322
341
kwargs = kwargs .copy ()
323
342
kwargs ["device_map" ] = "auto"
324
- kwargs ["load_in_8bit " ] = load_in_8bit
343
+ kwargs ["quantization_config " ] = quant_config
325
344
326
345
# Cast all parameters to float16 if quantization is enabled.
327
- if half_precision or load_in_8bit :
346
+ if half_precision or load_in_8bit or load_in_4bit :
328
347
kwargs ["torch_dtype" ] = torch .float16
329
348
330
349
# Support both decoder-only and encoder-decoder models.
You can’t perform that action at this time.
0 commit comments