Conversation

@jerryzh168 (Contributor) commented May 2, 2025:

Summary:
att

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_input_output_embeddings

Summary:
att

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding
@github-actions github-actions bot marked this pull request as draft May 2, 2025 23:46
github-actions bot commented May 2, 2025:

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@jerryzh168 (Contributor Author) commented:

cc @MekkCyber @SunMarc please take a look, just a small change to include_embedding

@jerryzh168 jerryzh168 marked this pull request as ready for review May 3, 2025 00:22
@github-actions github-actions bot requested review from MekkCyber and SunMarc May 3, 2025 00:22
@MekkCyber (Contributor) left a comment:

Thanks for the PR @jerryzh168, I have a small concern about the output embedding quantization.

Comment on lines +191 to +195
output_emb = model.get_output_embeddings()
output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
self.modules_to_not_convert = [
    x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
]
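For reference, input_emb_names is presumably defined just above the quoted lines; a minimal sketch of how such a list could be collected (illustrative only, not the exact code from this PR):

def input_embedding_names(model):
    # Return the fully qualified names that resolve to the input embedding module.
    input_emb = model.get_input_embeddings()
    return [name for name, module in model.named_modules() if module is input_emb]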
@MekkCyber (Contributor):

I'm not sure it's a good idea to quantize the lm_head when the include_embedding flag is set 🤔, it's a bit misleading.

@jerryzh168 (Contributor Author):

also:

def get_input_embeddings(self):
    return self.model.embed_tokens

def set_input_embeddings(self, value):
    self.model.embed_tokens = value

def get_output_embeddings(self):
    return self.lm_head

@MekkCyber (Contributor):

Yes it is, but it's still an nn.Linear, not an nn.Embedding.
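As a quick illustration of the distinction (a minimal sketch using gpt2 as an example checkpoint, not code from this PR): the input embedding is an nn.Embedding while the output embedding (lm_head) is an nn.Linear, even when the two are tied to the same weight tensor.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
print(type(model.get_input_embeddings()))   # torch.nn.modules.sparse.Embedding
print(type(model.get_output_embeddings()))  # torch.nn.modules.linear.Linear
# gpt2 ties the two modules to one weight tensor, so this prints True:
print(model.get_output_embeddings().weight is model.get_input_embeddings().weight)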

@SunMarc (Member):

About the embeddings and lm_head, there are some edge cases we need to be aware of.
If they are tied:

  1. If we quantize the embeddings, the lm_head will also be quantized unless we break the tied weights. This reduces memory consumption, but quality will also be reduced.
  2. If we decide to untie the weights, quantize the embeddings, and keep the lm_head as is, memory consumption will increase (because of the separate lm_head), but maybe we get a latency improvement? Maybe you also want to quantize the lm_head differently?

Do we have a specific use case for 2)? I think this is what you wanted to do @jerryzh168.

@jerryzh168 (Contributor Author) commented May 6, 2025:

Yeah, we have a use case in ExecuTorch where we quantize both the input embedding and the lm_head, and we quantize them differently. The way we are doing it right now is:

(1) manually break ties
(2) quantize the input embedding and lm_head separately

see details in https://huggingface.co/pytorch/Phi-4-mini-instruct-8da4w#quantization-recipe

quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
quantization_config = TorchAoConfig(
    quant_type=quant_config,
    include_embedding=True,
    untie_embedding_weights=True,
    modules_to_not_convert=[],
)

Right now we need to set modules_to_not_convert, and this PR will allow us to remove modules_to_not_convert.

Also I feel we might be able to remove the untie_embedding_weights flag now since we have an alternative solution.

Please also take a look at our solution for manually untying the weights, it might be useful to have an API for it as well.
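For illustration, a minimal sketch of the manual untying step ((1) above), assuming a standard transformers causal LM; the helper name is made up and this is not the exact code from the recipe:

import torch.nn as nn

def untie_lm_head(model):
    # If lm_head shares its weight tensor with the input embedding, give it
    # its own copy so the two can be quantized with different configs.
    input_emb = model.get_input_embeddings()
    output_emb = model.get_output_embeddings()
    if output_emb is not None and output_emb.weight is input_emb.weight:
        output_emb.weight = nn.Parameter(input_emb.weight.detach().clone())
        # Keep the config consistent so save/load does not re-tie the weights.
        model.config.tie_word_embeddings = False
    return model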

@jerryzh168 (Contributor Author):

@MekkCyber how about changing the name to include_input_output_embeddings, to be more specific about what we are referring to?

@MekkCyber (Contributor):

Yes, I think it’s fine as long as the user is aware that they’re quantizing the lm_head.

@jerryzh168 (Contributor Author):

Makes sense, just updated.

@SunMarc (Member) left a comment:

Left some feedback!


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber (Contributor) left a comment:

Thanks for adding this!

@jerryzh168 (Contributor Author) commented:

@MekkCyber @SunMarc can you merge this?

@MekkCyber MekkCyber enabled auto-merge (squash) May 16, 2025 09:12
@MekkCyber MekkCyber disabled auto-merge May 16, 2025 09:43
@MekkCyber MekkCyber merged commit 44fa04a into huggingface:main May 16, 2025
20 checks passed
@MekkCyber (Contributor) commented:

Done! Sorry for the delay.
