Skip to content

Commit dac60e7

Browse files
authored
Fix LlavaNext functional issues (#2333)
1 parent 761b41b commit dac60e7

File tree

1 file changed

+8
-26
lines changed

1 file changed

+8
-26
lines changed

optimum/habana/transformers/models/llava_next/modeling_llava_next.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import torch
2525
import torch.utils.checkpoint
26-
from torch import nn
2726
from transformers.models.llava_next.modeling_llava_next import (
2827
LlavaNextCausalLMOutputWithPast,
2928
LlavaNextForConditionalGeneration,
@@ -84,7 +83,6 @@ def forward(
8483
use_cache=use_cache,
8584
output_attentions=output_attentions,
8685
output_hidden_states=output_hidden_states,
87-
return_dict=True,
8886
cache_position=cache_position,
8987
# TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here
9088
# logits_to_keep=logits_to_keep,
@@ -94,31 +92,15 @@ def forward(
9492
**kwargs,
9593
)
9694

97-
if inputs_embeds.shape[1] != 1 and pixel_values is not None and self.text_tokens_pos is not None:
98-
batch_size, seq_len = self.text_tokens_pos.shape
99-
batch_indices = torch.arange(batch_size).repeat_interleave(seq_len)
100-
logits = outputs[0][batch_indices, self.text_tokens_pos.reshape(-1), :].reshape(
101-
batch_size, seq_len, -1
102-
)
103-
else:
104-
logits = outputs[0]
95+
hidden_states = outputs[0]
96+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
97+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
98+
logits = self.lm_head(hidden_states[:, slice_indices, :])
10599

106100
loss = None
107101
if labels is not None:
108-
# Shift so that tokens < n predict n
109-
if attention_mask is not None:
110-
# we use the input attention mask to shift the logits and labels, because it is 2D.
111-
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
112-
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
113-
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
114-
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
115-
else:
116-
shift_logits = logits[..., :-1, :].contiguous()
117-
shift_labels = labels[..., 1:].contiguous()
118-
# Flatten the tokens
119-
loss_fct = nn.CrossEntropyLoss()
120-
loss = loss_fct(
121-
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
102+
loss = self.loss_function(
103+
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
122104
)
123105

124106
return LlavaNextCausalLMOutputWithPast(
@@ -328,15 +310,15 @@ def prepare_inputs_for_generation(
328310
image_feature = torch.cat(
329311
(
330312
image_feature,
331-
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
313+
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
332314
),
333315
dim=-1,
334316
)
335317
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
336318
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
337319
else:
338320
image_feature = image_feature[0]
339-
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
321+
image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
340322
new_image_features.append(image_feature)
341323
if legacy_processing:
342324
image_features = torch.stack(new_image_features, dim=0)

0 commit comments

Comments (0)