@@ -300,6 +300,45 @@ def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
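+        # e.g. (illustrative sizes) a key of shape [80, 128, 10] with batch_size=2
+        # becomes [2, 40, 128, 10]; `view` only reshapes, no data is copied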
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
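+        # e.g. (illustrative sizes) a standard key of shape [2, 40, 128, 10] is
+        # viewed back as [80, 128, 10], the inverse of _convert_to_standard_cache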
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
 
@@ -318,9 +357,9 @@ def __init__(self, config: BaichuanConfig):
 
     def get_input_embeddings(self):
         return self.embed_tokens
-
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value
+        self.embed_tokens = value
 
     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ def custom_forward(*inputs):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
@@ -498,7 +537,7 @@ def set_decoder(self, decoder):
 
     def get_decoder(self):
         return self.model
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )
+        )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,33 +598,59 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
 
+            # the cache may be in the standard format (e.g. in contrastive search);
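+            # in the baichuan layout the leading dim is batch_size * num_heads, so a
+            # leading dim equal to the batch size signals the standard layout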
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
-            {
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-            }
-        )
+            }
+        )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
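+        # note: `beam_idx` has one entry per batch-beam row, so len(beam_idx) equals
+        # the effective batch size (batch_size * num_beams) of the standardized cache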
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
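+        # (under model parallelism, layers may sit on different devices, so the gather
+        # index is moved to each device once rather than per index_select call)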
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
         )
-
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
@@ -594,7 +659,7 @@ def quantize(self, bits: int):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ def quantize(self, bits: int):
                 weight=layer.mlp.up_proj.weight,
                 bias=None,
             )
-        return self
+        return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int = 0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens