Commit 44fa04a

Include output embedding as well with include_embedding flag (#37935)
* Include output embedding as well with `include_embedding` flag

  Summary: att

  Test Plan:
  python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* format

* rename include_embedding to include_input_output_embeddings

---------

Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 34c1e29 commit 44fa04a

File tree

3 files changed: +17 -10 lines changed

  src/transformers/quantizers/quantizer_torchao.py
  src/transformers/utils/quantization_config.py
  tests/quantization/torchao_integration/test_torchao.py

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 7 additions & 3 deletions
@@ -185,10 +185,14 @@ def _process_model_before_weight_loading(
         self.modules_to_not_convert = self.get_modules_to_not_convert(
             model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             input_emb = model.get_input_embeddings()
             input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
-            self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names]
+            output_emb = model.get_output_embeddings()
+            output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            self.modules_to_not_convert = [
+                x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
+            ]
         return

     def check_quantized_param(

@@ -213,7 +217,7 @@ def check_quantized_param(
         # we only quantize the weight of nn.Linear and nn.Embedding
         module, tensor_name = get_module_from_name(model, param_name)
         _QUANTIZABLE = [torch.nn.Linear]
-        if self.quantization_config.include_embedding:
+        if self.quantization_config.include_input_output_embeddings:
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
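For readers unfamiliar with the lookup the first hunk relies on: `get_input_embeddings()` and `get_output_embeddings()` return module objects, and the quantizer resolves them back to their dotted names via `named_modules()` identity comparison so both embeddings can be dropped from `modules_to_not_convert`. Below is a minimal, self-contained sketch of that pattern; the `TinyLM` toy model is hypothetical and not part of the repo.

import torch
from torch import nn

# Hypothetical toy model standing in for a transformers PreTrainedModel; only the
# two accessor methods used by the quantizer are implemented here.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head


model = TinyLM()
modules_to_not_convert = ["embed_tokens", "lm_head"]

# Resolve the embedding modules back to their dotted names by object identity,
# mirroring what _process_model_before_weight_loading does above.
input_emb = model.get_input_embeddings()
output_emb = model.get_output_embeddings()
emb_names = [
    name
    for name, module in model.named_modules()
    if module is input_emb or module is output_emb
]

# With include_input_output_embeddings=True, both embeddings are removed from the
# skip list so they stay eligible for quantization.
modules_to_not_convert = [x for x in modules_to_not_convert if x not in emb_names]
print(modules_to_not_convert)  # -> []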

src/transformers/utils/quantization_config.py

Lines changed: 3 additions & 3 deletions
@@ -1554,7 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     quant_type: Union[str, "AOBaseConfig"]  # noqa: F821
     modules_to_not_convert: Optional[List]
     quant_type_kwargs: Dict[str, Any]
-    include_embedding: bool
+    include_input_output_embeddings: bool
     untie_embedding_weights: bool

     """This is a config class for torchao quantization/sparsity techniques.

@@ -1617,15 +1617,15 @@ def __init__(
         self,
         quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List] = None,
-        include_embedding: bool = False,
+        include_input_output_embeddings: bool = False,
         untie_embedding_weights: bool = False,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
         self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
-        self.include_embedding = include_embedding
+        self.include_input_output_embeddings = include_input_output_embeddings
         self.untie_embedding_weights = untie_embedding_weights
         self.post_init()
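For reference, end-user code enabling this after the rename would look roughly like the sketch below. It mirrors the updated test; the torchao import paths and the IntxWeightOnlyConfig constructor are assumptions based on torchao >= 0.11.0, and the model id is a placeholder.

import torch
from torchao.quantization import AOPerModuleConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import MappingType

from transformers import AutoModelForCausalLM, TorchAoConfig

# int8 weight-only config applied only to the two embedding modules; every other
# module keeps the default (None, i.e. left unquantized).
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
    mapping_type=MappingType.ASYMMETRIC,
)
per_module = AOPerModuleConfig(
    {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
)

# The renamed flag opts both the input embedding and the output embedding (lm_head)
# into quantization; the old include_embedding flag only covered the input side.
quant_config = TorchAoConfig(quant_type=per_module, include_input_output_embeddings=True)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=quant_config,
)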

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 7 additions & 4 deletions
@@ -201,7 +201,7 @@ def test_int8_dynamic_activation_int8_weight_quant(self):
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

     @require_torchao_version_greater_or_equal("0.11.0")
-    def test_include_embedding(self):
+    def test_include_input_output_embeddings(self):
         weight_dtype = torch.int8
         granularity = PerAxis(0)
         mapping_type = MappingType.ASYMMETRIC

@@ -210,16 +210,19 @@ def test_include_embedding(self):
             granularity=granularity,
             mapping_type=mapping_type,
         )
-        config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
-        # need set `include_embedding` to True
-        quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
+        config = AOPerModuleConfig(
+            {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
+        )
+        # need set `include_input_output_embeddings` to True
+        quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             device_map=self.device,
             quantization_config=quant_config,
         )
         # making sure embedding is quantized
         self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)

         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
