-
Notifications
You must be signed in to change notification settings - Fork 26.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: remove most decoder-only LLMs prepare_inputs_for_generation
#33870
base: main
Are you sure you want to change the base?
Conversation
Hey! 🤗 Thanks for your contribution to the Before merging this pull request, slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here. |
@@ -350,47 +350,69 @@ def prepare_inputs_for_generation( | |||
attention_mask: Optional[torch.LongTensor] = None, | |||
inputs_embeds: Optional[torch.FloatTensor] = None, | |||
cache_position: Optional[torch.LongTensor] = None, | |||
position_ids: Optional[torch.LongTensor] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not all models expect this one. We now inspect the signature to determine whether we need to generate them on the fly
use_cache: bool = True, | ||
num_logits_to_keep: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are moved to kwargs
. We now forward kwargs
to the model inputs :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is up to mark working efficiently
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow, so much code killed, thanks!
# Overwritten -- model logic breaks when `inputs_embeds` are passed from this function | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious: does that mean blenderbot cannot generate from inputs embeds and it cannot be fixed? I see many models touched here didn't pass further inputs embeds, so that mean after this PR all of them will support generation from embeddings. So interesting to see why this model failed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see many models touched here didn't pass further inputs embeds, so that mean after this PR all of them will support generation from embeddings.
Precisely! Many models will get this feature for free as part of these deletions 💛
Just curious: does that mean blenderbot cannot generate from inputs embeds and it cannot be fixed?
No clue, I didn't dive deeper :) Failed in inputs_embeds
tests -> pasted this comment. I don't think these combos of model/feature are worth the dive, so I left this low-information (but better than nothing) note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the test was just flaky! I've added flakiness protection to the failing test and deleted a few more cases :)
@unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results") | ||
def test_static_cache_matches_dynamic(self): | ||
pass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was marked flaky for VLMs in one of the other PRs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this PR, it becomes a failure all the times 👀 I have no idea why (didn't dive)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super sad, i started diving a while ago and that seems related to paligemma's weird masking for prefix/suffix. I'll see if I can get time to spot the bug
@@ -2837,7 +2837,7 @@ def test_inputs_embeds_matches_input_ids(self): | |||
|
|||
def test_inputs_embeds_matches_input_ids_with_generate(self): | |||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() | |||
for model_class in self.all_model_classes: | |||
for model_class in self.all_generative_model_classes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this test calls generate
)
if ( | ||
attention_mask is not None | ||
and kwargs.get("position_ids") is None | ||
and "position_ids" in set(inspect.signature(self.forward).parameters.keys()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
quick Q, how fast is this / is it slowing down generation?
- we can store the inspect result if needed otherwise!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not too bad, but can be improved, yes. On my machine, this adds 0.024ms per generated token (small, but not negligible). If we cache the inspect.signature
, we reduce it by 100x.
We actually make several inspect.signature(foward)
calls in generate
and other bits of the codebase, I think it makes sense to store the inspect as a cached model property (e.g. model.forward_signature
). WDYT? If you agree, I'll open a follow-up PR with this change
For completeness, script to measure the impact of caching this call:
import time
import inspect
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# Fresh inspect
all_times = []
for _ in range(1000):
start = time.time()
"position_ids" in set(inspect.signature(model.forward).parameters.keys())
all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
# Cached inspect
signature_keys = set(inspect.signature(model.forward).parameters.keys())
all_times = []
for _ in range(1000):
start = time.time()
"position_ids" in signature_keys
all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense
|
||
# 4. Create missing `position_ids` on the fly |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice
): | ||
position_ids = attention_mask.long().cumsum(-1) - 1 | ||
position_ids.masked_fill_(attention_mask == 0, 1) | ||
kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seen in other PRs, that it needed to be sliced to seq_length no? -seq_len:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, slicing happens in the code block after this one. That code block abstracts slicing to other input names (e.g. token_type_ids
needs to be sliced exactly like position_ids
-- and we can add more to this list as needed 🤗 )
for key, value in kwargs.items(): | ||
if key not in model_inputs: | ||
model_inputs[key] = value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure this is super efficient TBH!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its run time is negligible, even if kwargs
contains a handful of entries (usually it will only contain one or two). At most 0.001 ms per call :P
On the plus side, this code block will allow us to generalize this function to VLMs 😉 I think that's worth the super small cost.
import time
import torch
all_times = []
for _ in range(1000):
model_inputs = {str(i): i for i in range(10)}
kwargs = {'a': 1, 'b': 2, 'c': torch.zeros((100, 100)), "0": 12, "1": 3546}
start = time.time()
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value
all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay good for me, let's fix generate tests if related
What does this PR do?
Part of step 6 in #32685
Follow-up to #33677
This PR:
GenerationMixin.prepare_inputs_for_generation
so as to handle models WITHOUT theCache
refactor, preparetoken_type_ids
, and forward arbitrary kwargs✅ slow tests were ran on
llama
andgpt2