Skip to content

Commit

Permalink
ok, this one doesn't work
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Oct 3, 2024
1 parent cf76af6 commit a09a4b2
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/transformers/models/ctrl/modeling_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,24 @@ def forward(
attentions=transformer_outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
# Overwritten -- inputs_embeds not working properly

# only last tokens for inputs_ids if past is defined in kwargs
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}

@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
Expand Down

0 comments on commit a09a4b2

Please sign in to comment.