@@ -32,17 +32,40 @@ def prediction_step(
3232 Subclass and override to inject custom behavior.
3333 """
3434 prompt_len , label_len = inputs ["input_ids" ].size (- 1 ), inputs ["labels" ].size (- 1 )
35- if self .tokenizer .padding_side == "right" : # pads the labels to the same length as the inputs
36- inputs ["labels" ] = torch .cat ((inputs ["labels" ], torch .zeros_like (inputs ["input_ids" ])[:, label_len :]), dim = - 1 )
37- else :
38- inputs ["labels" ] = torch .cat ((torch .zeros_like (inputs ["input_ids" ])[:, label_len :], inputs ["labels" ]), dim = - 1 )
35+ if prompt_len > label_len :
36+ inputs ["labels" ] = self ._pad_tensors_to_target_len (inputs ["labels" ], inputs ["input_ids" ])
37+ if label_len > prompt_len :
38+ inputs ["input_ids" ] = self ._pad_tensors_to_target_len (inputs ["input_ids" ], inputs ["labels" ])
39+
3940 loss , generated_tokens , labels = super ().prediction_step (
4041 model , inputs , prediction_loss_only = prediction_loss_only , ignore_keys = ignore_keys
4142 )
42- generated_tokens = generated_tokens [:, prompt_len :] if generated_tokens is not None else None
43+ generated_tokens = generated_tokens [:, max ( prompt_len , label_len ) :] if generated_tokens is not None else None
4344
4445 return (loss , generated_tokens , labels )
4546
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
    r"""
    Left-pads ``src_tensor`` so that it has the same shape as ``tgt_tensor``.

    Should only be called when predict_with_generate=True.

    The pad value is looked up in this order: the tokenizer's ``pad_token_id``,
    then the tokenizer's ``eos_token_id``; when no usable tokenizer is attached,
    ``self.model.config.pad_token_id`` is used instead.

    Args:
        src_tensor: Tensor whose last dimension is to be extended.
        tgt_tensor: Tensor providing the output shape, dtype and device. It is
            not modified.

    Returns:
        A new tensor shaped like ``tgt_tensor``, filled with the pad token id,
        with ``src_tensor`` copied into its right-most columns.

    Raises:
        ValueError: If the tokenizer is not left-padding, or if no pad token id
            can be determined.
    """
    tokenizer = self.tokenizer
    if tokenizer is not None and hasattr(tokenizer, "pad_token_id"):
        # Raise explicitly instead of `assert`: assertions are removed under
        # `python -O`, which would silently produce wrongly aligned tensors.
        if tokenizer.padding_side != "left":
            raise ValueError("This method only accepts left-padded tensor.")
        # If PAD token is not defined at least EOS token has to be defined
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        if pad_token_id is None:
            # Without this check the failure would be an opaque TypeError later on.
            raise ValueError("Either pad_token_id or eos_token_id must be set in the tokenizer, in order to pad tensors")
    else:
        pad_token_id = self.model.config.pad_token_id
        if pad_token_id is None:
            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

    padded_tensor = torch.full_like(tgt_tensor, pad_token_id)
    src_len = src_tensor.shape[-1]
    # Guard the empty case: `-0:` would select every column instead of none.
    if src_len > 0:
        padded_tensor[:, -src_len:] = src_tensor  # adopt left-padding
    return padded_tensor
68+
4669 def save_predictions (
4770 self ,
4871 predict_results : PredictionOutput
0 commit comments