@@ -1751,7 +1751,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
17511751 https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
17521752 Also make sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
17531753 not compiled and adapted for CPUs.
1754- int8_threshold (`float`, *optional*, defaults to 6):
1754+ load_in_8bit_threshold (`float`, *optional*, defaults to 6):
17551755 Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
17561756 described in `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
17571757 states value that is above this threshold will be considered an outlier and the operation on those
@@ -1761,6 +1761,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
17611761 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
17621762 penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
17631763 (small models, fine-tuning).
1764+ load_in_8bit_skip_modules (`List[str]`, *optional*):
1765+ An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
1766+ as Jukebox that have several heads in different places and not necessarily at the last position.
17641767 subfolder (`str`, *optional*, defaults to `""`):
17651768 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
17661769 specify the folder name here.
@@ -1852,7 +1855,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
18521855 offload_folder = kwargs .pop ("offload_folder" , None )
18531856 offload_state_dict = kwargs .pop ("offload_state_dict" , False )
18541857 load_in_8bit = kwargs .pop ("load_in_8bit" , False )
1855- int8_threshold = kwargs .pop ("int8_threshold" , 6.0 )
1858+ load_in_8bit_threshold = kwargs .pop ("load_in_8bit_threshold" , 6.0 )
1859+ load_in_8bit_skip_modules = kwargs .pop ("load_in_8bit_skip_modules" , None )
18561860 subfolder = kwargs .pop ("subfolder" , "" )
18571861 commit_hash = kwargs .pop ("_commit_hash" , None )
18581862
@@ -2156,13 +2160,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
21562160 model = cls (config , * model_args , ** model_kwargs )
21572161
21582162 if load_in_8bit :
2159- from .utils .bitsandbytes import get_key_to_not_convert , replace_8bit_linear
2163+ from .utils .bitsandbytes import get_keys_to_not_convert , replace_8bit_linear
21602164
21612165 logger .info ("Detected 8-bit loading: activating 8-bit loading for this model" )
21622166
2163- # We never convert lm_head or any last modules for numerical stability reasons
2164- modules_to_not_convert = get_key_to_not_convert (model )
2165- model = replace_8bit_linear (model , threshold = int8_threshold , modules_to_not_convert = modules_to_not_convert )
2167+ # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
2168+ if load_in_8bit_skip_modules is None :
2169+ modules_to_not_convert = get_keys_to_not_convert (model )
2170+ else :
2171+ modules_to_not_convert = load_in_8bit_skip_modules
2172+ model = replace_8bit_linear (
2173+ model , threshold = load_in_8bit_threshold , modules_to_not_convert = modules_to_not_convert
2174+ )
21662175
21672176 if isinstance (device_map , str ):
21682177 if model ._no_split_modules is None :
@@ -2193,12 +2202,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
21932202 )
21942203
21952204 if load_in_8bit :
2196- # The LM head can stay on disk / CPU
2205+ # The LM head / tied weights or any last module can stay on disk / CPU
21972206 device_map_without_lm_head = {
2198- key : device_map [key ] for key in device_map .keys () if key != modules_to_not_convert
2207+ key : device_map [key ] for key in device_map .keys () if key not in modules_to_not_convert
21992208 }
22002209 if "cpu" in device_map_without_lm_head .values () or "disk" in device_map_without_lm_head .values ():
2201- raise ValueError ("8-bit operations on `bitsandbytes` are not supported under CPU!" )
2210+ raise ValueError (
2211+ """
2212+ Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
2213+ the quantized model. If you have set a value for `max_memory` you should increase that. To have
2214+ an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
2215+ """
2216+ )
22022217 del device_map_without_lm_head
22032218
22042219 if from_tf :
0 commit comments