Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
get_balanced_memory = None

if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -1747,6 +1747,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`):
An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such
as Jukebox that have several heads in different places and not necessarily at the last position.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
Expand Down Expand Up @@ -1839,6 +1842,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make more sense to have this be a class variable of PreTrainedModel (like the no_split variable used for big model inference)? I'm afraid the user won't know what to set this to and it looks like it's something we should automatically handle?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on that, but this argument is optional because the function get_keys_to_not_convert should automatically take care of that except for some models like Jukebox where it is a bit trickier due to its architecture.
In this case the user will just have to manually set which modules should be kept in their native precision and specify them in the kwargs, so I feel like it is a bit easier than having it as an attribute of PreTrainedModel, because you would need to open a PR to add the feature.

subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

Expand Down Expand Up @@ -2142,7 +2146,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_key_to_not_convert(model)
if no_load_in_8bit_modules is None:
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = no_load_in_8bit_modules
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)

if isinstance(device_map, str):
Expand Down Expand Up @@ -2174,12 +2181,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)

if load_in_8bit:
# The LM head can stay on disk / CPU
# The LM head / tied weights or any last module can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
the quantized model. If you have set a value for `max_memory` you should increase that. To have
an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
"""
)
del device_map_without_lm_head

if from_tf:
Expand Down
21 changes: 15 additions & 6 deletions src/transformers/utils/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold, modules_to_not_convert)

if isinstance(module, nn.Linear) and name != modules_to_not_convert:
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
with init_empty_weights():
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
Expand All @@ -126,10 +126,12 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
return model


def get_key_to_not_convert(model):
def get_keys_to_not_convert(model):
r"""
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons.
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.

Parameters:
model (`torch.nn.Module`):
Expand All @@ -139,7 +141,9 @@ def get_key_to_not_convert(model):
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager
tied_model.tie_weights()
has_tied_params = len(find_tied_parameters(tied_model)) > 0

tied_keys = list(find_tied_parameters(tied_model).values())
has_tied_params = len(tied_keys) > 0

# Check if it is a base model
is_base_model = not hasattr(model, model.base_model_prefix)
Expand All @@ -150,5 +154,10 @@ def get_key_to_not_convert(model):

# otherwise they have an attached head
list_modules = list(model.named_parameters())
last_name = list_modules[-1][0]
return last_name.split(".")[0]
list_last_module = [list_modules[-1][0]]

# add the last module to the tied weights, unless it is already among them
intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection)

return [module_name.split(".")[0] for module_name in list_untouched]