2323
2424import torch
2525import torch .utils .checkpoint
26- from torch import nn
2726from transformers .models .llava_next .modeling_llava_next import (
2827 LlavaNextCausalLMOutputWithPast ,
2928 LlavaNextForConditionalGeneration ,
@@ -84,7 +83,6 @@ def forward(
8483 use_cache = use_cache ,
8584 output_attentions = output_attentions ,
8685 output_hidden_states = output_hidden_states ,
87- return_dict = True ,
8886 cache_position = cache_position ,
8987 # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here
9088 # logits_to_keep=logits_to_keep,
@@ -94,31 +92,15 @@ def forward(
9492 ** kwargs ,
9593 )
9694
97- if inputs_embeds .shape [1 ] != 1 and pixel_values is not None and self .text_tokens_pos is not None :
98- batch_size , seq_len = self .text_tokens_pos .shape
99- batch_indices = torch .arange (batch_size ).repeat_interleave (seq_len )
100- logits = outputs [0 ][batch_indices , self .text_tokens_pos .reshape (- 1 ), :].reshape (
101- batch_size , seq_len , - 1
102- )
103- else :
104- logits = outputs [0 ]
95+ hidden_states = outputs [0 ]
96+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
97+ slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
98+ logits = self .lm_head (hidden_states [:, slice_indices , :])
10599
106100 loss = None
107101 if labels is not None :
108- # Shift so that tokens < n predict n
109- if attention_mask is not None :
110- # we use the input attention mask to shift the logits and labels, because it is 2D.
111- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
112- shift_attention_mask = attention_mask [:, - (logits .shape [1 ] - 1 ) :].to (logits .device )
113- shift_logits = logits [..., :- 1 , :][shift_attention_mask .to (logits .device ) != 0 ].contiguous ()
114- shift_labels = labels [..., 1 :][shift_attention_mask .to (labels .device ) != 0 ].contiguous ()
115- else :
116- shift_logits = logits [..., :- 1 , :].contiguous ()
117- shift_labels = labels [..., 1 :].contiguous ()
118- # Flatten the tokens
119- loss_fct = nn .CrossEntropyLoss ()
120- loss = loss_fct (
121- shift_logits .view (- 1 , shift_logits .size (- 1 )), shift_labels .view (- 1 ).to (shift_logits .device )
102+ loss = self .loss_function (
103+ logits = logits , labels = labels , vocab_size = self .config .text_config .vocab_size , ** kwargs
122104 )
123105
124106 return LlavaNextCausalLMOutputWithPast (
@@ -328,15 +310,15 @@ def prepare_inputs_for_generation(
328310 image_feature = torch .cat (
329311 (
330312 image_feature ,
331- self .image_newline [:, None , None ].expand (* image_feature .shape [:- 1 ], 1 ),
313+ self .model . image_newline [:, None , None ].expand (* image_feature .shape [:- 1 ], 1 ),
332314 ),
333315 dim = - 1 ,
334316 )
335317 image_feature = image_feature .flatten (1 , 2 ).transpose (0 , 1 )
336318 image_feature = torch .cat ((base_image_feature , image_feature ), dim = 0 )
337319 else :
338320 image_feature = image_feature [0 ]
339- image_feature = torch .cat ((image_feature , self .image_newline [None ]), dim = 0 )
321+ image_feature = torch .cat ((image_feature , self .model . image_newline [None ]), dim = 0 )
340322 new_image_features .append (image_feature )
341323 if legacy_processing :
342324 image_features = torch .stack (new_image_features , dim = 0 )
0 commit comments