10 changes: 7 additions & 3 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -185,10 +185,14 @@ def _process_model_before_weight_loading(
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
if self.quantization_config.include_embedding:
if self.quantization_config.include_input_output_embeddings:
input_emb = model.get_input_embeddings()
input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
output_emb = model.get_output_embeddings()
output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
self.modules_to_not_convert = [
x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
]
Comment on lines +191 to +195
Contributor:
I'm not sure if it's a good idea to quantize the lm_head when the flag include_embedding is set 🤔; it's a bit misleading.

Contributor Author:
also:

def get_input_embeddings(self):
    return self.model.embed_tokens

def set_input_embeddings(self, value):
    self.model.embed_tokens = value

def get_output_embeddings(self):
    return self.lm_head

Contributor:
Yes it is, but it's still an nn.Linear, not an nn.Embedding.

Member:
About embeddings and lm_head, there are some edge cases we need to be aware of.
If they are tied:

  1. if we quantize the embeddings, the lm_head will also be quantized unless we break the tied weights. This will reduce memory consumption, but quality will also be reduced.
  2. if we decide to remove the tied weights and quantize the embeddings / keep the lm_head as is, memory consumption will increase (due to the lm_head), but maybe we get a latency improvement? Maybe you also want to quantize the lm_head differently?

Do we have a specific use case for 2)? I think this is what you wanted to do @jerryzh168.
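
A minimal sketch of what the tie means in practice (assuming a standard causal LM such as gpt2, where tie_word_embeddings is True): the input embedding and the lm_head share a single weight tensor, so quantizing one side quantizes the other unless the tie is broken first.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example tied-weight checkpoint (assumption)
embed = model.get_input_embeddings()
lm_head = model.get_output_embeddings()
# Both modules point at the same tensor, so there is only one weight to quantize
# until the tie is broken.
print(embed.weight is lm_head.weight)  # True while the weights are tied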

Contributor Author (@jerryzh168), May 6, 2025:
Yeah, we have a use case in ExecuTorch where we quantize both the input embedding and the lm_head, and we quantize them differently. The way we are doing it right now is:

(1) manually break ties
(2) quantize the input embedding and lm_head separately

see details in https://huggingface.co/pytorch/Phi-4-mini-instruct-8da4w#quantization-recipe

quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])

Right now we need to set modules_to_not_convert, and this PR will allow us to remove it.

Also, I feel we might be able to remove the untie_embedding_weights flag now, since we have an alternative solution.

Please also take a look at our solution for manually untying the weights; it might be useful to have an API for it as well.
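
A minimal sketch of what that manual untying might look like (an assumption based on the description above, not the exact ExecuTorch recipe; untie_lm_head is a hypothetical helper):

import torch

def untie_lm_head(model):
    # Give lm_head its own copy of the shared weight so it can be quantized
    # separately from the input embedding.
    lm_head = model.get_output_embeddings()
    lm_head.weight = torch.nn.Parameter(lm_head.weight.detach().clone())
    # Keep the config consistent so the tie is not re-applied on save/reload.
    model.config.tie_word_embeddings = False
    return model

After this, the AOPerModuleConfig above can map model.embed_tokens and lm_head to different configs without one overwriting the other.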

Contributor Author:
@MekkCyber how about changing the name to include_input_output_embeddings to be more specific about what we are referring to?

Contributor:
Yes, I think it’s fine as long as the user is aware that they’re quantizing the lm_head.
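
If a user wants the flag only for the input embedding, one possible opt-out (a sketch, assuming AOPerModuleConfig skips modules mapped to None, as the "_default": None pattern in the test below suggests; import paths assume torchao 0.11+) is to map lm_head to None explicitly:

import torch
from torchao.quantization import AOPerModuleConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import MappingType
from transformers import TorchAoConfig

embedding_config = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0), mapping_type=MappingType.ASYMMETRIC)
# lm_head is listed but mapped to None, so only the input embedding gets quantized.
quant_config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config, "lm_head": None})
quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True)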

Contributor Author:
makes sense, just updated

return

def check_quantized_param(
@@ -213,7 +217,7 @@ def check_quantized_param(
# we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name)
_QUANTIZABLE = [torch.nn.Linear]
if self.quantization_config.include_embedding:
if self.quantization_config.include_input_output_embeddings:
_QUANTIZABLE.append(torch.nn.Embedding)
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")

6 changes: 3 additions & 3 deletions src/transformers/utils/quantization_config.py
@@ -1554,7 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
quant_type: Union[str, "AOBaseConfig"] # noqa: F821
modules_to_not_convert: Optional[List]
quant_type_kwargs: Dict[str, Any]
include_embedding: bool
include_input_output_embeddings: bool
untie_embedding_weights: bool

"""This is a config class for torchao quantization/sparsity techniques.
@@ -1617,15 +1617,15 @@ def __init__(
self,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List] = None,
include_embedding: bool = False,
include_input_output_embeddings: bool = False,
untie_embedding_weights: bool = False,
**kwargs,
):
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
self.include_embedding = include_embedding
self.include_input_output_embeddings = include_input_output_embeddings
self.untie_embedding_weights = untie_embedding_weights
self.post_init()

11 changes: 7 additions & 4 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -201,7 +201,7 @@ def test_int8_dynamic_activation_int8_weight_quant(self):
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

@require_torchao_version_greater_or_equal("0.11.0")
def test_include_embedding(self):
def test_include_input_output_embeddings(self):
weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
@@ -210,16 +210,19 @@ def test_include_embedding(self):
granularity=granularity,
mapping_type=mapping_type,
)
config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
# need set `include_embedding` to True
quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
config = AOPerModuleConfig(
{"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
)
# need set `include_input_output_embeddings` to True
quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure embedding is quantized
self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
self.assertTrue(isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)