Commit e9736b2

fix saving custom code

Former-commit-id: 3f8f40b
1 parent c61de6f · commit e9736b2

2 files changed: 89 additions & 24 deletions

src/llmtuner/tuner/core/loader.py (6 additions & 6 deletions)
@@ -11,7 +11,7 @@
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
     finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     r"""
     Loads pretrained model and tokenizer.
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
     )

     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
         model.__class__.register_for_auto_class()
+    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()

     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
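The fix here is twofold: the type check now runs before the `auto_map` lookup, and the lookup uses `getattr(config, "auto_map", {})` (or `tokenizer.init_kwargs.get("auto_map", {})` for the tokenizer) so an object without the attribute no longer raises. Registering for an auto class is what makes a later `save_pretrained` copy the custom `*.py` files into the checkpoint, which is the "saving custom code" the commit title refers to. A minimal sketch of the same pattern outside the trainer; the repo id and output path are placeholders, not taken from this commit:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/model-with-custom-code"  # placeholder: any repo shipping custom code

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Mirrors the patched loader.py: register only when the object actually
# advertises an auto_map entry for its auto class.
if "AutoConfig" in getattr(config, "auto_map", {}):
    config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
    model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
    tokenizer.__class__.register_for_auto_class()

# With the classes registered, saving also writes the custom code files
# (e.g. modeling_*.py, tokenization_*.py) next to the weights.
model.save_pretrained("output/checkpoint")
tokenizer.save_pretrained("output/checkpoint")
```

Reading the tokenizer's `auto_map` from `tokenizer.init_kwargs` rather than from the model config also covers repos whose custom tokenizer is declared only in `tokenizer_config.json`.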

tests/modeling_baichuan.py (83 additions & 18 deletions)
@@ -300,6 +300,45 @@ def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value

+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to
+        tuple(tuple([batch_size, num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+

 class BaichuanModel(BaichuanPreTrainedModel):
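These two helpers exist because Baichuan's ALiBi attention stores its KV cache fused as `[batch_size * num_heads, ...]`, while the generic generation utilities (beam search, contrastive search) expect the standard `[batch_size, num_heads, ...]` layout. A standalone round-trip check of the same reshapes on dummy tensors; the shapes are arbitrary:

```python
import torch

batch, heads, head_dim, seq = 2, 4, 8, 5

# Baichuan layout: key is [batch * heads, head_dim, seq],
# value is [batch * heads, seq, head_dim].
key = torch.randn(batch * heads, head_dim, seq)
value = torch.randn(batch * heads, seq, head_dim)
baichuan_cache = ((key, value),)  # a single layer

# Equivalent of _convert_to_standard_cache: split the fused batch dim.
standard = tuple(
    (k.view(batch, heads, head_dim, seq), v.view(batch, heads, seq, head_dim))
    for k, v in baichuan_cache
)
assert standard[0][0].shape == (batch, heads, head_dim, seq)

# Equivalent of _convert_to_baichuan_cache: fuse it back. The round trip
# is exact, and view() never copies, so both directions are cheap.
back = tuple(
    (k.view(batch * heads, head_dim, seq), v.view(batch * heads, seq, head_dim))
    for k, v in standard
)
assert torch.equal(back[0][0], key) and torch.equal(back[0][1], value)
```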

@@ -318,9 +357,9 @@ def __init__(self, config: BaichuanConfig):

     def get_input_embeddings(self):
         return self.embed_tokens
-
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value
+        self.embed_tokens = value

     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ def custom_forward(*inputs):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+

 class BaichuanForCausalLM(BaichuanPreTrainedModel):
@@ -498,7 +537,7 @@ def set_decoder(self, decoder):

     def get_decoder(self):
         return self.model
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )
+        )

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,33 +598,59 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]

+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}

         model_inputs.update(
-            {
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-            }
-        )
+            }
+        )
         return model_inputs

-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
         )
-
+        return self._convert_to_baichuan_cache(reordered_past)

     def quantize(self, bits: int):
         try:
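Two pieces work together in the hunk above. `prepare_inputs_for_generation` tells the two cache layouts apart by comparing the cache's leading dimension with the batch size, since a standard cache has `key.shape[0] == batch_size` while the fused Baichuan cache has `batch_size * num_heads`. `_reorder_cache` then standardizes the cache, copies `beam_idx` onto every device that holds cache tensors (which matters when layers are sharded across GPUs), reorders along the batch dimension, and converts back. A sketch of both checks on dummy tensors; the shapes are arbitrary:

```python
import torch

batch, heads, head_dim, seq = 2, 4, 8, 5
input_ids = torch.ones(batch, 1, dtype=torch.long)

standard_key = torch.randn(batch, heads, head_dim, seq)      # standard layout
fused_key = standard_key.view(batch * heads, head_dim, seq)  # Baichuan layout

# The layout test from prepare_inputs_for_generation: unambiguous
# whenever num_heads > 1, since batch * heads > batch.
assert standard_key.shape[0] == input_ids.shape[0]
assert fused_key.shape[0] != input_ids.shape[0]

# Beam reordering on the standard layout: index_select along the batch dim.
beam_idx = torch.tensor([1, 0])
reordered = standard_key.index_select(0, beam_idx)
assert torch.equal(reordered[0], standard_key[1])
```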
@@ -594,7 +659,7 @@ def quantize(self, bits: int):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ def quantize(self, bits: int):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
-        return self
+        return self

     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
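The two `quantize` hunks above are whitespace-only cleanup, but they make the method's shape visible: it replaces each attention `W_pack` and MLP projection with a `QLinear` of the requested bit width and ends with `return self`, so calls chain. A hypothetical usage sketch, assuming this custom modeling code is what `trust_remote_code` loads and that the `QLinear` dependency is installed; the repo id is an assumption, not taken from this commit:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat",  # assumption: any Baichuan checkpoint
    trust_remote_code=True,
)
model = model.quantize(8)  # swaps projections for 8-bit QLinear layers, returns self
```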
