[bnb] Small improvements on utils #18646
```diff
@@ -85,7 +85,7 @@
 get_balanced_memory = None

 if is_bitsandbytes_available():
-    from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
+    from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device

 logger = logging.get_logger(__name__)

```
```diff
@@ -1747,6 +1747,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
                 penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
                 (small models, fine-tuning).
+            no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`):
+                An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
+                as Jukebox that have several heads in different places and not necessarily at the last position.
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
```
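For anyone who wants to try the new argument, here is a minimal usage sketch; the checkpoint and the module name are illustrative placeholders, not taken from this PR:

```python
from transformers import AutoModelForCausalLM

# Hypothetical example: keep the listed modules in full precision instead of
# converting them to 8-bit. Checkpoint and module names are placeholders.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",
    load_in_8bit=True,
    no_load_in_8bit_modules=["lm_head"],  # skip 8-bit conversion for these modules
)
```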
```diff
@@ -1839,6 +1842,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         load_in_8bit = kwargs.pop("load_in_8bit", False)
         int8_threshold = kwargs.pop("int8_threshold", 6.0)
+        no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None)
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
```
```diff
@@ -2142,7 +2146,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

             # We never convert lm_head or any last modules for numerical stability reasons
-            modules_to_not_convert = get_key_to_not_convert(model)
+            if no_load_in_8bit_modules is None:
+                modules_to_not_convert = get_keys_to_not_convert(model)
+            else:
+                modules_to_not_convert = no_load_in_8bit_modules
             model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)

         if isinstance(device_map, str):
```
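Downstream, `replace_8bit_linear` uses that list to decide which layers to swap. As a rough mental model only (this is not the actual implementation in `.utils.bitsandbytes`), the skipping behaves along these lines:

```python
import torch.nn as nn

def preview_conversion(model: nn.Module, modules_to_not_convert: list) -> None:
    """Illustrative sketch, not the real replace_8bit_linear: walk the module
    tree and report which nn.Linear layers would be swapped for 8-bit layers,
    skipping every name listed in modules_to_not_convert."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            print(f"would convert {name} to an 8-bit linear layer")
```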
```diff
@@ -2174,12 +2181,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )

         if load_in_8bit:
-            # The LM head can stay on disk / CPU
+            # The LM head / tied weights or any last module can stay on disk / CPU
             device_map_without_lm_head = {
-                key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+                key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
             }
             if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
-                raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+                raise ValueError(
+                    """
+                    Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
+                    the quantized model. If you have set a value for `max_memory` you should increase that. To have
+                    an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
+                    """
+                )
             del device_map_without_lm_head

         if from_tf:
```
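The new message points users at two concrete remedies: raising `max_memory` and inspecting `model.hf_device_map`. A sketch of both, with a placeholder checkpoint and placeholder memory limits:

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any model loadable in 8-bit works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",
    load_in_8bit=True,
    # Give accelerate more GPU headroom so no module spills to CPU or disk.
    max_memory={0: "14GiB", "cpu": "30GiB"},
)

# Inspect where each module was dispatched; anything mapped to "cpu" or
# "disk" here is what triggers the new ValueError under load_in_8bit.
print(model.hf_device_map)
```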