Generate: remove most decoder-only LLMs prepare_inputs_for_generation #33870

Open · wants to merge 11 commits into main

Conversation

@gante (Member) commented Oct 1, 2024

What does this PR do?

Part of step 6 in #32685
Follow-up to #33677

This PR:

  1. Revises GenerationMixin.prepare_inputs_for_generation so that it handles models WITHOUT the Cache refactor, prepares token_type_ids, and forwards arbitrary kwargs (a rough sketch follows below)
  2. Because of 1., we can remove this function from most decoder-only LLMs 🧹🤗🧹 All decoder-only LLMs were checked
  3. Adds a comment on each remaining overwrite in decoder-only LLMs, for our future selves

✅ Slow tests were run on llama and gpt2
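
For illustration, a rough sketch of the revised behavior (simplified, not the actual implementation; it just mirrors the snippets discussed in the review below):

import inspect

def prepare_inputs_sketch(model, input_ids, attention_mask=None, **kwargs):
    # Sketch only: build the model inputs, create `position_ids` on the fly
    # when the model's forward accepts it, and forward remaining kwargs untouched.
    model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

    if (
        attention_mask is not None
        and kwargs.get("position_ids") is None
        and "position_ids" in inspect.signature(model.forward).parameters
    ):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs["position_ids"] = position_ids

    # e.g. `use_cache`, `num_logits_to_keep`, `token_type_ids` flow through here
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value
    return model_inputs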

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run-slow] followed by a comma-separated list of all the models to be tested, e.g. [run_slow] model_to_test_1, model_to_test_2
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests

(For maintainers) The documentation for slow tests CI on PRs is here.

@@ -350,47 +350,69 @@ def prepare_inputs_for_generation(
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@gante (Member Author):

Not all models expect this one. We now inspect the forward signature to determine whether we need to create position_ids on the fly.

Comment on lines -354 to -356
use_cache: bool = True,
num_logits_to_keep: Optional[int] = None,
@gante (Member Author):

these are moved to kwargs. We now forward kwargs to the model inputs :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@akshit397a left a comment:

This is up to the mark and working efficiently.

@zucchini-nlp (Member) left a comment:

Wow, so much code killed, thanks!

Comment on lines 1583 to 1584
# Overwritten -- model logic breaks when `inputs_embeds` are passed from this function

Member:

Just curious: does that mean blenderbot cannot generate from inputs_embeds and it cannot be fixed? I see many models touched here didn't forward inputs_embeds, so that means after this PR all of them will support generation from embeddings. It's interesting to see why this model failed.

@gante (Member Author):

> I see many models touched here didn't forward inputs_embeds, so that means after this PR all of them will support generation from embeddings.

Precisely! Many models will get this feature for free as part of these deletions 💛

> Just curious: does that mean blenderbot cannot generate from inputs_embeds and it cannot be fixed?

No clue, I didn't dive deeper :) It failed the inputs_embeds tests -> I pasted this comment. I don't think these model/feature combos are worth the dive, so I left this low-information (but better than nothing) note.

@gante (Member Author):

Actually the test was just flaky! I've added flakiness protection to the failing test and deleted a few more cases :)

Comment on lines +317 to +320
@unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results")
def test_static_cache_matches_dynamic(self):
pass

Member:

I think this was marked flaky for VLMs in one of the other PRs

@gante (Member Author):

With this PR, it becomes a failure every time 👀 I have no idea why (didn't dive).

Member:

Super sad, I started diving a while ago and it seems related to paligemma's weird masking for prefix/suffix. I'll see if I can get time to spot the bug.

@@ -2837,7 +2837,7 @@ def test_inputs_embeds_matches_input_ids(self):

def test_inputs_embeds_matches_input_ids_with_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_generative_model_classes:
@gante (Member Author):

(this test calls generate)

if (
attention_mask is not None
and kwargs.get("position_ids") is None
and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
Collaborator:

Quick Q: how fast is this / is it slowing down generation?

  • Otherwise, we can store the inspect result if needed!

@gante (Member Author) commented Oct 4, 2024:

It's not too bad, but can be improved, yes. On my machine, this adds 0.024ms per generated token (small, but not negligible). If we cache the inspect.signature, we reduce it by 100x.

We actually make several inspect.signature(forward) calls in generate and other bits of the codebase, so I think it makes sense to store the result as a cached model property (e.g. model.forward_signature). WDYT? If you agree, I'll open a follow-up PR with this change.

For completeness, script to measure the impact of caching this call:

import time
import inspect
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Fresh inspect
all_times = []
for _ in range(1000):
    start = time.time()
    "position_ids" in set(inspect.signature(model.forward).parameters.keys())
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))

# Cached inspect
signature_keys = set(inspect.signature(model.forward).parameters.keys())
all_times = []
for _ in range(1000):
    start = time.time()
    "position_ids" in signature_keys
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
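
For reference, the cached property could look roughly like this (a sketch only; forward_signature is just the hypothetical name floated above, not an existing attribute):

import inspect
from functools import cached_property

class ForwardSignatureCacheSketch:
    # Sketch: compute the forward signature once per instance, so repeated
    # checks inside `generate` become plain set lookups.

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        ...  # stand-in forward, only here to give the sketch a signature to inspect

    @cached_property
    def forward_signature(self):
        return set(inspect.signature(self.forward).parameters.keys())

model = ForwardSignatureCacheSketch()
print("position_ids" in model.forward_signature)  # True; computed once, then cached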

Collaborator:

makes sense


# 4. Create missing `position_ids` on the fly
Collaborator:

nice

):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
Collaborator:

Seen in other PRs that it needed to be sliced to seq_length, no? (-seq_len:)

@gante (Member Author):

Yes, slicing happens in the code block after this one. That code block abstracts slicing to other input names (e.g. token_type_ids needs to be sliced exactly like position_ids -- and we can add more to this list as needed 🤗 )
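
For illustration, a minimal sketch of that shared slicing step (the helper name and slice are assumptions based on this thread, not the real code):

def slice_sequence_inputs(model_inputs, kwargs, seq_len):
    # Sketch: every input that grows with the sequence is sliced the same way,
    # keeping only the last `seq_len` positions; the tuple can grow as needed.
    for name in ("position_ids", "token_type_ids"):
        tensor = kwargs.get(name)
        if tensor is not None:
            model_inputs[name] = tensor[:, -seq_len:]
    return model_inputs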

Comment on lines +452 to +454
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value
Collaborator:

not sure this is super efficient TBH!

@gante (Member Author) commented Oct 4, 2024:

Its run time is negligible, even if kwargs contains a handful of entries (usually it will only contain one or two). At most 0.001 ms per call :P

On the plus side, this code block will allow us to generalize this function to VLMs 😉 I think that's worth the super small cost.

import time
import torch

all_times = []
for _ in range(1000):

    model_inputs = {str(i): i for i in range(10)}
    kwargs = {'a': 1, 'b': 2, 'c': torch.zeros((100, 100)), "0": 12, "1": 3546}

    start = time.time()
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))

@ArthurZucker (Collaborator) left a comment:

Okay good for me, let's fix generate tests if related
